# knowledge-cutoff / knowledge_cutoff_demo.py
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
from concurrent.futures import ThreadPoolExecutor  # used only by the commented-out parallel path
import time
# Global model storage - all loaded simultaneously on L40S
models = {}
tokenizers = {}
model_loaded = {}
# Speed-optimized configurations for L40S
MODEL_CONFIGS = {
"Llama-1 7B": {
"model_id": "huggyllama/llama-7b",
"load_in_4bit": False, # Use full precision for speed
"torch_dtype": torch.bfloat16,
"device_map": {"": 0}, # Force to GPU 0
},
"Llama-2 7B Chat": {
"model_id": "meta-llama/Llama-2-7b-chat-hf",
"load_in_4bit": False,
"torch_dtype": torch.bfloat16,
"device_map": {"": 0},
},
"Llama-3.2 3B": {
"model_id": "meta-llama/Llama-3.2-3B-Instruct",
"load_in_4bit": False,
"torch_dtype": torch.bfloat16,
"device_map": {"": 0},
}
}
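# Rough VRAM budget (an assumption, not measured here): at bf16 each 7B model is ~14 GB of
# weights and the 3B model is ~6 GB, so all three fit within the L40S's 48 GB with headroom
# left over for KV cache and activations.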
def load_all_models():
"""Load all models simultaneously - L40S has enough VRAM"""
global models, tokenizers, model_loaded
print("Loading all models simultaneously on L40S...")
start_time = time.time()
for model_name, config in MODEL_CONFIGS.items():
print(f"Loading {model_name}...")
try:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(config["model_id"])
# Handle tokenizer setup differently for Llama-1
if "Llama-1" in model_name:
# Llama-1 doesn't have a pad token by default
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # Important for generation
else:
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load model with basic optimizations (ZeroGPU compatible)
model = AutoModelForCausalLM.from_pretrained(
config["model_id"],
torch_dtype=config["torch_dtype"],
device_map=config["device_map"],
trust_remote_code=True,
use_cache=True,
low_cpu_mem_usage=True,
)
models[model_name] = model
tokenizers[model_name] = tokenizer
model_loaded[model_name] = True
print(f"βœ… {model_name} loaded successfully")
except Exception as e:
print(f"❌ Error loading {model_name}: {e}")
model_loaded[model_name] = False
total_time = time.time() - start_time
print(f"All models loaded in {total_time:.2f} seconds")
print(f"GPU Memory used: {torch.cuda.memory_allocated()/1024**3:.2f}GB")
def format_prompt(input_question, model_name):
"""Format prompt based on model type"""
if "Llama-2" in model_name:
system_msg = "You are a helpful assistant. Answer questions clearly and concisely."
return f"<s>[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{input_question} [/INST]"
elif "3.2" in model_name:
messages = [
{"role": "system", "content": "You are a helpful assistant. Answer questions clearly and concisely."},
{"role": "user", "content": input_question}
]
return tokenizers[model_name].apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else: # Llama-1
# Better prompt format for Llama-1 to encourage stopping
return f"Question: {input_question}\n\nResponse: "
def clean_llama1_response(response_text, original_prompt):
"""Clean up Llama-1 response to prevent loops and cut at natural stopping points"""
# Remove the original prompt if it appears in the response
if original_prompt in response_text:
response_text = response_text.replace(original_prompt, "").strip()
# Split by common conversation markers and take first part
stop_markers = [
"\n\nHuman:", "\n\nUser:", "\n\nQuestion:",
"\n\nQ:", "\n\nA:", "\nHuman:", "\nUser:",
"Human:", "User:", "Question:", "###", "Answer:"
]
for marker in stop_markers:
if marker in response_text:
response_text = response_text.split(marker)[0].strip()
break
# Remove repetitive patterns (simple heuristic)
lines = response_text.split('\n')
cleaned_lines = []
seen_lines = set()
for line in lines:
line = line.strip()
if line and line not in seen_lines:
cleaned_lines.append(line)
seen_lines.add(line)
elif line in seen_lines:
# Stop if we see repetition
break
response_text = '\n'.join(cleaned_lines)
# Truncate if too long (another safety measure)
if len(response_text) > 1000:
response_text = response_text[:1000] + "..."
return response_text.strip()
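# Illustrative (hypothetical) example of the cleanup above: a raw Llama-1 completion such as
#   "Paris.\n\nQuestion: What is the capital of Spain?\n\nResponse: Madrid."
# gets cut at the first "\n\nQuestion:" marker, leaving just "Paris."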
def generate_single_response(model_name, input_question):
"""Generate response from a single model - optimized for speed"""
if not model_loaded.get(model_name, False):
return f"❌ {model_name} not available"
try:
model = models[model_name]
tokenizer = tokenizers[model_name]
# Format prompt
formatted_prompt = format_prompt(input_question, model_name)
# Tokenize with speed optimizations
inputs = tokenizer(
formatted_prompt,
return_tensors="pt",
max_length=512,
truncation=True,
padding=False
).to(model.device)
        # Generate with speed-focused settings; pass pad_token_id explicitly to avoid
        # the pad-token warning transformers emits during open-ended generation
        with torch.no_grad():
            start_time = time.time()
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,
            )
        generation_time = time.time() - start_time
        print(f"{model_name}: generation took {generation_time:.2f}s")
# Extract response
response_tokens = outputs[0][inputs['input_ids'].shape[1]:]
response = tokenizer.decode(response_tokens, skip_special_tokens=True)
# Special cleaning for Llama-1
if "Llama-1" in model_name:
response = clean_llama1_response(response, formatted_prompt)
return response
except Exception as e:
return f"❌ Error with {model_name}: {str(e)}"
@spaces.GPU
def process_all_models_parallel(input_question):
"""Process all models in parallel for maximum speed"""
    if not input_question.strip():
        return "❌ Please enter a question", "❌ Please enter a question", "❌ Please enter a question"
# start_time = time.time()
# # Use ThreadPoolExecutor for parallel processing
# with ThreadPoolExecutor(max_workers=3) as executor:
# # Submit all tasks simultaneously
# futures = {
# executor.submit(generate_single_response, model_name, input_question): model_name
# for model_name in MODEL_CONFIGS.keys()
# }
# # Collect results as they complete
# results = {}
# for future in futures:
# model_name = futures[future]
# try:
# result = future.result(timeout=45) # Longer timeout for Llama-1
# results[model_name] = result
# except Exception as e:
# results[model_name] = f"❌ Timeout or error for {model_name}: {str(e)}"
# total_time = time.time() - start_time
# # Add total timing to first response
# llama1_response = results.get("Llama-1 7B", "❌ Error")
# return (
# llama1_response,
# results.get("Llama-2 7B Chat", "❌ Error"),
# results.get("Llama-3.2 3B", "❌ Error")
# )
llama1_response = generate_single_response("Llama-1 7B", input_question)
llama2_response = generate_single_response("Llama-2 7B Chat", input_question)
llama3_response = generate_single_response("Llama-3.2 3B", input_question)
return llama1_response, llama2_response, llama3_response
# def benchmark_models():
# """Benchmark all models with a test question"""
# test_question = "What is 2+2? Please provide a brief answer."
# print("πŸƒβ€β™‚οΈ Running benchmark...")
# start_time = time.time()
# results = process_all_models_parallel(test_question)
# total_time = time.time() - start_time
# print(f"Benchmark completed in {total_time:.2f}s")
# return f"Benchmark completed! All models ready. Total time: {total_time:.2f}s"
def create_interface():
"""Create speed-optimized Gradio interface"""
with gr.Blocks(title="Speed-Optimized Multi-Llama", theme=gr.themes.Glass()) as demo:
gr.Markdown("NOTE: Llama-1 7b is NOT a chat model - behaviour in Question-Answering tasks is erratic!")
# with gr.Row():
# benchmark_btn = gr.Button("πŸƒβ€β™‚οΈ Run Benchmark", variant="secondary", size="sm")
# benchmark_output = gr.Textbox(label="Benchmark Results", visible=False)
with gr.Row():
question_input = gr.Textbox(
label="Enter your question",
placeholder="What is the meaning of life?",
lines=2,
max_lines=4
)
with gr.Row():
submit_btn = gr.Button("Ask All Models", variant="primary", size="lg")
clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
# Real-time responses in columns for better UX
with gr.Row():
with gr.Column():
gr.Markdown("### πŸ¦™ Llama-1 7B")
output1 = gr.Textbox(
label="Response",
interactive=False,
lines=8,
max_lines=15,
show_label=False
)
with gr.Column():
gr.Markdown("### πŸ¦™ Llama-2 7B Chat")
output2 = gr.Textbox(
label="Response",
interactive=False,
lines=8,
max_lines=15,
show_label=False
)
with gr.Column():
gr.Markdown("### πŸ¦™ Llama-3.2 3B")
output3 = gr.Textbox(
label="Response",
interactive=False,
lines=8,
max_lines=15,
show_label=False
)
# Event handlers
submit_btn.click(
fn=process_all_models_parallel,
inputs=[question_input],
outputs=[output1, output2, output3],
)
clear_btn.click(
fn=lambda: ("", "", "", ""),
outputs=[question_input, output1, output2, output3]
)
return demo
# Load all models at startup
load_all_models()
if __name__ == "__main__":
demo = create_interface()
    demo.launch()
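    # launch() defaults are fine on Hugging Face Spaces; for a local run you could pass explicit
    # options, e.g. demo.launch(server_name="0.0.0.0", server_port=7860). Those are standard
    # Gradio parameters, shown here only as an illustration.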