import spaces  # Hugging Face Spaces helper; only needed for ZeroGPU decorators
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
from concurrent.futures import ThreadPoolExecutor  # used by the disabled parallel path below
import time
# Global model storage - all loaded simultaneously on L40S
models = {}
tokenizers = {}
model_loaded = {}
# Speed-optimized configurations for L40S
MODEL_CONFIGS = {
    "Llama-1 7B": {
        "model_id": "huggyllama/llama-7b",
        "load_in_4bit": False,  # Use full precision for speed
        "torch_dtype": torch.bfloat16,
        "device_map": {"": 0},  # Force to GPU 0
    },
    "Llama-2 7B Chat": {
        "model_id": "meta-llama/Llama-2-7b-chat-hf",
        "load_in_4bit": False,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": 0},
    },
    "Llama-3.2 3B": {
        "model_id": "meta-llama/Llama-3.2-3B-Instruct",
        "load_in_4bit": False,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": 0},
    },
}
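# Note: "load_in_4bit" above is currently informational only - load_all_models()
# never reads it. A minimal sketch of how it could be honored, assuming the
# bitsandbytes package is installed (hypothetical wiring, not in the original app):
#
#   from transformers import BitsAndBytesConfig
#   quant = BitsAndBytesConfig(load_in_4bit=True) if config["load_in_4bit"] else None
#   model = AutoModelForCausalLM.from_pretrained(
#       config["model_id"], quantization_config=quant, device_map=config["device_map"]
#   )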
def load_all_models():
    """Load all models simultaneously - L40S has enough VRAM"""
    global models, tokenizers, model_loaded
    print("Loading all models simultaneously on L40S...")
    start_time = time.time()
    for model_name, config in MODEL_CONFIGS.items():
        print(f"Loading {model_name}...")
        try:
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(config["model_id"])
            # Handle tokenizer setup differently for Llama-1
            if "Llama-1" in model_name:
                # Llama-1 doesn't have a pad token by default
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.padding_side = "left"  # Important for generation
            else:
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
            # Load model with basic optimizations (ZeroGPU compatible)
            model = AutoModelForCausalLM.from_pretrained(
                config["model_id"],
                torch_dtype=config["torch_dtype"],
                device_map=config["device_map"],
                trust_remote_code=True,
                use_cache=True,
                low_cpu_mem_usage=True,
            )
            models[model_name] = model
            tokenizers[model_name] = tokenizer
            model_loaded[model_name] = True
            print(f"✅ {model_name} loaded successfully")
        except Exception as e:
            print(f"❌ Error loading {model_name}: {e}")
            model_loaded[model_name] = False
    total_time = time.time() - start_time
    print(f"All models loaded in {total_time:.2f} seconds")
    print(f"GPU Memory used: {torch.cuda.memory_allocated()/1024**3:.2f}GB")
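# Rough VRAM budget behind the "enough VRAM" claim above (bf16 = 2 bytes/param):
# two 7B models ≈ 28 GB of weights plus a 3B model ≈ 6 GB, i.e. ~34 GB before
# KV caches - inside the L40S's 48 GB, but with limited headroom.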
def format_prompt(input_question, model_name):
    """Format prompt based on model type"""
    if "Llama-2" in model_name:
        system_msg = "You are a helpful assistant. Answer questions clearly and concisely."
        return f"<s>[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{input_question} [/INST]"
    elif "3.2" in model_name:
        messages = [
            {"role": "system", "content": "You are a helpful assistant. Answer questions clearly and concisely."},
            {"role": "user", "content": input_question},
        ]
        return tokenizers[model_name].apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:  # Llama-1
        # Better prompt format for Llama-1 to encourage stopping
        return f"Question: {input_question}\n\nResponse: "
def clean_llama1_response(response_text, original_prompt):
    """Clean up Llama-1 response to prevent loops and cut at natural stopping points"""
    # Remove the original prompt if it appears in the response
    if original_prompt in response_text:
        response_text = response_text.replace(original_prompt, "").strip()
    # Split by common conversation markers and take the first part
    stop_markers = [
        "\n\nHuman:", "\n\nUser:", "\n\nQuestion:",
        "\n\nQ:", "\n\nA:", "\nHuman:", "\nUser:",
        "Human:", "User:", "Question:", "###", "Answer:",
    ]
    for marker in stop_markers:
        if marker in response_text:
            response_text = response_text.split(marker)[0].strip()
            break
    # Remove repetitive patterns (simple heuristic: stop at the first repeated line)
    lines = response_text.split('\n')
    cleaned_lines = []
    seen_lines = set()
    for line in lines:
        line = line.strip()
        if line and line not in seen_lines:
            cleaned_lines.append(line)
            seen_lines.add(line)
        elif line in seen_lines:
            # Stop if we see repetition
            break
    response_text = '\n'.join(cleaned_lines)
    # Truncate if too long (another safety measure)
    if len(response_text) > 1000:
        response_text = response_text[:1000] + "..."
    return response_text.strip()
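# Example of the marker cut above: a raw Llama-1 completion such as
#   "Paris is the capital of France.\n\nQuestion: What about Spain?"
# is cut at "\n\nQuestion:" and returned as "Paris is the capital of France."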
def generate_single_response(model_name, input_question):
    """Generate response from a single model - optimized for speed"""
    if not model_loaded.get(model_name, False):
        return f"❌ {model_name} not available"
    try:
        model = models[model_name]
        tokenizer = tokenizers[model_name]
        # Format prompt
        formatted_prompt = format_prompt(input_question, model_name)
        # Tokenize with speed optimizations
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=False,
        ).to(model.device)
        # Generate with speed-focused settings (greedy decoding)
        with torch.no_grad():
            start_time = time.time()
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,  # avoids the missing-pad-token warning
            )
            generation_time = time.time() - start_time
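        # Sampling settings like these could replace greedy decoding if answers
        # feel too repetitive (hypothetical values, not the original config):
        #   outputs = model.generate(
        #       **inputs,
        #       max_new_tokens=512,
        #       do_sample=True,
        #       temperature=0.7,
        #       top_p=0.9,
        #       repetition_penalty=1.15,
        #   )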
        # Extract only the newly generated tokens (skip the prompt)
        response_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(response_tokens, skip_special_tokens=True)
        print(f"{model_name} generated in {generation_time:.2f}s")
        # Special cleaning for Llama-1
        if "Llama-1" in model_name:
            response = clean_llama1_response(response, formatted_prompt)
        return response
    except Exception as e:
        return f"❌ Error with {model_name}: {str(e)}"
def process_all_models_parallel(input_question):
    """Query all three models; currently runs sequentially (the parallel path
    below is kept for reference but disabled)"""
    if not input_question.strip():
        return "❌ Please enter a question", "❌ Please enter a question", "❌ Please enter a question"
    # Disabled parallel path, using a thread pool with one worker per model:
    # with ThreadPoolExecutor(max_workers=3) as executor:
    #     # Submit all tasks simultaneously
    #     futures = {
    #         executor.submit(generate_single_response, model_name, input_question): model_name
    #         for model_name in MODEL_CONFIGS.keys()
    #     }
    #     # Collect results as they complete
    #     results = {}
    #     for future in futures:
    #         model_name = futures[future]
    #         try:
    #             result = future.result(timeout=45)  # Longer timeout for Llama-1
    #             results[model_name] = result
    #         except Exception as e:
    #             results[model_name] = f"❌ Timeout or error for {model_name}: {str(e)}"
    #     return (
    #         results.get("Llama-1 7B", "❌ Error"),
    #         results.get("Llama-2 7B Chat", "❌ Error"),
    #         results.get("Llama-3.2 3B", "❌ Error"),
    #     )
    llama1_response = generate_single_response("Llama-1 7B", input_question)
    llama2_response = generate_single_response("Llama-2 7B Chat", input_question)
    llama3_response = generate_single_response("Llama-3.2 3B", input_question)
    return llama1_response, llama2_response, llama3_response
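# Note on the sequential choice above: all three models share one GPU, so the
# thread-pool version mostly serializes on the device anyway; running them one
# after another gives predictable latency without thread overhead.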
# def benchmark_models():
#     """Benchmark all models with a test question"""
#     test_question = "What is 2+2? Please provide a brief answer."
#     print("🏃‍♂️ Running benchmark...")
#     start_time = time.time()
#     results = process_all_models_parallel(test_question)
#     total_time = time.time() - start_time
#     print(f"Benchmark completed in {total_time:.2f}s")
#     return f"Benchmark completed! All models ready. Total time: {total_time:.2f}s"
def create_interface():
    """Create speed-optimized Gradio interface"""
    with gr.Blocks(title="Speed-Optimized Multi-Llama", theme=gr.themes.Glass()) as demo:
        gr.Markdown("NOTE: Llama-1 7B is NOT a chat model - behaviour in question-answering tasks is erratic!")
        # with gr.Row():
        #     benchmark_btn = gr.Button("🏃‍♂️ Run Benchmark", variant="secondary", size="sm")
        #     benchmark_output = gr.Textbox(label="Benchmark Results", visible=False)
        with gr.Row():
            question_input = gr.Textbox(
                label="Enter your question",
                placeholder="What is the meaning of life?",
                lines=2,
                max_lines=4,
            )
        with gr.Row():
            submit_btn = gr.Button("Ask All Models", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        # Real-time responses in columns for better UX
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🦙 Llama-1 7B")
                output1 = gr.Textbox(
                    label="Response",
                    interactive=False,
                    lines=8,
                    max_lines=15,
                    show_label=False,
                )
            with gr.Column():
                gr.Markdown("### 🦙 Llama-2 7B Chat")
                output2 = gr.Textbox(
                    label="Response",
                    interactive=False,
                    lines=8,
                    max_lines=15,
                    show_label=False,
                )
            with gr.Column():
                gr.Markdown("### 🦙 Llama-3.2 3B")
                output3 = gr.Textbox(
                    label="Response",
                    interactive=False,
                    lines=8,
                    max_lines=15,
                    show_label=False,
                )
        # Event handlers
        submit_btn.click(
            fn=process_all_models_parallel,
            inputs=[question_input],
            outputs=[output1, output2, output3],
        )
        clear_btn.click(
            fn=lambda: ("", "", "", ""),
            outputs=[question_input, output1, output2, output3],
        )
    return demo
# Load all models at startup
load_all_models()
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
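    # If the Space will serve concurrent users, Gradio's built-in request queue
    # could be enabled instead (optional; not used in the original app):
    #   demo.queue().launch()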