# vibe_sip/lightweight_conversational_llm.py
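"""Lightweight conversational LLM helper for the vibe_sip venue assistant.

Wraps HuggingFaceTB/SmolLM-1.7B-Instruct in a small class that prefers 4-bit
(NF4) quantization via bitsandbytes when it is available, falls back to a
plain float16 load otherwise, and builds short venue-recommendation replies
from a venue-context string and a user query.
"""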
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

# Try to import BitsAndBytesConfig, but don't fail if not available
try:
    from transformers import BitsAndBytesConfig
    QUANTIZATION_AVAILABLE = True
except ImportError:
    QUANTIZATION_AVAILABLE = False
    logging.warning("BitsAndBytesConfig not available, quantization disabled")


class LightweightConversationalLLM:
    def __init__(self, model_name="HuggingFaceTB/SmolLM-1.7B-Instruct"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.setup_model()

    def setup_model(self):
        try:
            # Try quantization first if available, fall back to regular loading
            if QUANTIZATION_AVAILABLE:
                try:
                    # Configure 4-bit quantization for memory efficiency
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.bfloat16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4"
                    )
                    # Load the model with quantization
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        quantization_config=quantization_config,
                        device_map="auto",
                        torch_dtype=torch.bfloat16,
                        trust_remote_code=True
                    )
                    logging.info(f"Successfully loaded {self.model_name} with 4-bit quantization")
                except Exception as quant_error:
                    logging.warning(f"4-bit quantization failed: {quant_error}")
                    logging.info("Falling back to regular model loading...")
                    # Fall back to regular loading without quantization
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        device_map="auto",
                        torch_dtype=torch.float16,  # Use float16 for memory efficiency
                        trust_remote_code=True,
                        low_cpu_mem_usage=True
                    )
                    logging.info(f"Successfully loaded {self.model_name} without quantization")
            else:
                # Load without quantization
                logging.info("Loading model without quantization (bitsandbytes not available)")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    device_map="auto",
                    torch_dtype=torch.float16,  # Use float16 for memory efficiency
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )
                logging.info(f"Successfully loaded {self.model_name} without quantization")

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            logging.warning(f"Failed to load {self.model_name}: {e}")
            self.model = None
            self.tokenizer = None

    def generate_response(self, venue_context, user_query, max_length=400):
        if not self.model or not self.tokenizer:
            return "I can help you find venues, but conversational features are currently unavailable."
        try:
            # Create a focused prompt for venue recommendations
            prompt = f"""You are a helpful Yerevan venue assistant. Based on the venue information provided, give a brief, friendly response.

Venue Context: {venue_context[:800]}...

User: {user_query}
Assistant:"""
            inputs = self.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
            # Keep the prompt tokens on the same device as the model so generation
            # doesn't fail when the model has been placed on a GPU via device_map
            inputs = inputs.to(self.model.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=max_length,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=3
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract only the assistant's response
            assistant_response = response.split("Assistant:")[-1].strip()
            return assistant_response[:max_length] if len(assistant_response) > max_length else assistant_response
        except Exception as e:
            logging.error(f"Error generating response: {e}")
            return "I found the venues you requested, but had trouble generating a conversational response."