import os
import json
import uuid
import queue
import base64
import threading
import itertools
import subprocess

import requests
from flask import Flask, render_template, request, jsonify, send_from_directory, abort
from groq import Groq
# ---------- CONFIG ----------
API_KEYS = [os.environ.get(f"btz{i}") for i in range(1, 100) if os.environ.get(f"btz{i}")]
if not API_KEYS:
    raise KeyError("No 'btzN' environment variables found (e.g., btz1, btz2). Please set them in Hugging Face secrets.")
GROQ_API_KEY = os.environ.get("GROQ_KEY")
if not GROQ_API_KEY:
    raise KeyError("GROQ_KEY environment variable not found.")
GEMINI_API_KEYS = [os.environ.get(f"gmni{i}") for i in range(1, 100) if os.environ.get(f"gmni{i}")]
if not GEMINI_API_KEYS:
    raise KeyError("No 'gmniN' environment variables found (e.g., gmni1, gmni2). Please set them in Hugging Face secrets.")
gemini_key_cycler = itertools.cycle(GEMINI_API_KEYS)
DEFAULT_BYTEZ_MODEL = "ali-vilab/text-to-video-ms-1.7b"
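# Expected secrets, inferred from the lookups above (names only; values are your own keys):
#   btz1 ... btz99   -> Bytez API keys; one worker thread is started per key
#   GROQ_KEY         -> a single Groq API key
#   gmni1 ... gmni99 -> Gemini API keys, cycled round-robin per image request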
# ---------- MODEL HUNTER FUNCTIONS ----------
def find_best_groq_model(api_key):
    try:
        print("🤖 Hunting for the best Groq model...")
        client = Groq(api_key=api_key)
        models = client.models.list().data
        available_ids = {model.id for model in models}
        preferred_keywords = ["llama-3.1", "llama3", "gemma2", "mixtral", "gemma"]
        for keyword in preferred_keywords:
            for model_id in available_ids:
                if keyword in model_id:
                    print(f"🎯 Groq Target locked: {model_id}")
                    return model_id
        for model_id in available_ids:
            if all(k not in model_id for k in ["guard", "tts", "prompt"]):
                print(f"✅ Groq Fallback found: {model_id}")
                return model_id
        raise ValueError("No usable Groq models found.")
    except Exception as e:
        print(f"🛑 Groq hunt failed: {e}. Using hardcoded fallback.")
        return "llama-3.1-8b-instant"
def find_best_gemini_vision_model(api_key):
    try:
        print("🤖 Hunting for the best Gemini Vision model...")
        url = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
        response = requests.get(url, timeout=15)
        response.raise_for_status()
        models = response.json().get("models", [])
        vision_models = [
            m["name"] for m in models
            if any("generateContent" in s for s in m.get("supportedGenerationMethods", []))
            and ("vision" in m["name"] or "flash" in m["name"] or "1.5" in m["name"])
        ]
        preferred_models = ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest", "gemini-pro-vision"]
        for preferred in preferred_models:
            for model_path in vision_models:
                if preferred in model_path:
                    model_name = model_path.split('/')[-1]
                    print(f"🎯 Gemini Target locked: {model_name}")
                    return model_name
        if vision_models:
            model_name = vision_models[0].split('/')[-1]
            print(f"✅ Gemini Fallback found: {model_name}")
            return model_name
        raise ValueError("No usable Gemini Vision models found.")
    except Exception as e:
        print(f"🛑 Gemini hunt failed: {e}. Using hardcoded fallback.")
        return "gemini-1.5-flash-latest"
# ---------- INITIALIZATION ----------
GROQ_MODEL = find_best_groq_model(GROQ_API_KEY)
GEMINI_VISION_MODEL = find_best_gemini_vision_model(GEMINI_API_KEYS[0])
print(f"✅ Loaded {len(API_KEYS)} Bytez keys and {len(GEMINI_API_KEYS)} Gemini keys. Using Groq: {GROQ_MODEL} and Gemini: {GEMINI_VISION_MODEL}")
OUTPUT_FOLDER = "output"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
os.makedirs("static", exist_ok=True)
| # ---------- APP & STATE ---------- | |
| app = Flask(__name__) | |
| progress = { "active": False, "step": 0, "total": 0, "status": "idle", "message": "", "error": None, "video_relpath": None, "live_log": [] } | |
| job_queue = queue.Queue() | |
| generated_clips_dict = {} | |
| clips_lock = threading.Lock() | |
# ---------- HELPER FUNCTIONS ----------
def set_progress(log_message=None, **kwargs):
    # The original code did `with threading.Lock():`, which creates a brand-new
    # lock on every call and therefore excludes nothing; the shared module-level
    # progress_lock above is needed for this critical section to work.
    with progress_lock:
        progress.update(kwargs)
        if log_message:
            progress["live_log"].append(f"> {log_message}")
            if len(progress["live_log"]) > 20:
                progress["live_log"].pop(0)

def get_progress_copy():
    with progress_lock:
        return progress.copy()
def get_prompt_from_gemini(image_data, user_text, mime_type, api_key):
    print(f"🧠 Contacting Gemini (Creative Director) using key ...{api_key[-4:]}")
    try:
        encoded_image = base64.b64encode(image_data).decode('utf-8')
        gemini_api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_VISION_MODEL}:generateContent?key={api_key}"
        instruction_text = (
            "You are a master AI Art Director. Your goal is to create a **vivid, rich, and visually "
            "descriptive** prompt for a text-to-video AI. Analyze the image and the user's text. Create a "
            "prompt that captures the essence, colors, and details of the scene. Focus entirely on what can "
            "be SEEN. Describe textures, lighting, and specific objects. The more visual detail, the better. "
            "Your output must be a powerful and inspiring guide for the artist."
        )
        parts = [{"text": instruction_text}]
        if user_text:
            parts.append({"text": f"\nUser's instruction: '{user_text}'"})
        parts.append({"inline_data": {"mime_type": mime_type, "data": encoded_image}})
        payload = {"contents": [{"parts": parts}]}
        headers = {'Content-Type': 'application/json'}
        response = requests.post(gemini_api_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()
        if 'candidates' not in result or not result['candidates']:
            if 'promptFeedback' in result and 'blockReason' in result['promptFeedback']:
                return None, f"Gemini API Blocked: {result['promptFeedback']['blockReason']}"
            return None, "Gemini API returned an empty response."
        generated_prompt = result['candidates'][0]['content']['parts'][0]['text'].strip()
        print(f"🎬 Gemini's (Creative) prompt: {generated_prompt}")
        return generated_prompt, None
    except requests.exceptions.RequestException as e:
        return None, f"Gemini API Request Error: {e}"
    except Exception as e:
        return None, f"An unexpected error occurred with Gemini: {e}"
def generate_visual_blueprint_with_groq(user_prompt, api_key, model_name):
    print("🎨 Contacting Groq (Quality Enhancer)...")
    try:
        client = Groq(api_key=api_key)
        system_prompt = (
            "You are a prompt engineering expert, a 'Quality Enhancer' for a text-to-video AI. Your job is "
            "to take an incoming creative prompt and make it technically perfect for the AI artist. "
            "**Your primary rule: ADD powerful keywords.** Append terms like 'photorealistic, 4k, ultra "
            "realistic, sharp focus, hyper-detailed, vibrant colors, cinematic lighting' to the prompt. "
            "Refine the description to be even more visually precise. Do NOT simplify, ENHANCE. "
            "Example Input: 'A bush of pink roses.' Example Output: 'A dense bush of vibrant pink roses, "
            "hyper-detailed petals, sharp focus, cinematic lighting, photorealistic, 4k, ultra realistic.' "
            "Output ONLY the final, enhanced description."
        )
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
        )
        visual_blueprint = response.choices[0].message.content.strip()
        print(f"✨ Groq's (Enhanced) Blueprint: {visual_blueprint}")
        return visual_blueprint, None
    except Exception as e:
        return None, f"Groq API Error: {e}"

# The distill_prompt_for_short_video function has been removed from here.
def generate_clip(prompt, idx, api_key, bytez_model):
    from bytez import Bytez
    sdk = Bytez(api_key)
    model = sdk.model(bytez_model)
    out = None
    err = None
    try:
        result = model.run(prompt)
        print(f"DEBUG: model.run() returned a '{type(result)}'. Full result: {result}")
        if isinstance(result, tuple) and len(result) >= 2:
            out, err = result[0], result[1]
        elif isinstance(result, tuple) and len(result) == 1:
            out = result[0]
        else:
            out = result
    except Exception as e:
        print(f"🛑 CRITICAL ERROR during model.run() call: {e}")
        return None, str(e)
    if err:
        return None, f"Model Error (Key ...{api_key[-4:]}): {err}"
    filename = f"clip_{idx}_{uuid.uuid4().hex}.mp4"
    filepath = os.path.join(OUTPUT_FOLDER, filename)
    try:
        if isinstance(out, bytes):
            with open(filepath, "wb") as f:
                f.write(out)
        elif isinstance(out, str) and out.startswith('http'):
            r = requests.get(out, timeout=300)
            r.raise_for_status()
            with open(filepath, "wb") as f:
                f.write(r.content)
        else:
            return None, f"Unexpected or empty output from model. Output type: {type(out)}"
    except Exception as e:
        return None, f"Failed to save or download the generated clip: {e}"
    return filepath, None
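# A minimal usage sketch (values are hypothetical). model.run() may hand back raw
# bytes or a download URL, which is why generate_clip handles both shapes:
#   path, err = generate_clip("a pink rose bush, 4k", 0, API_KEYS[0], DEFAULT_BYTEZ_MODEL)
#   if err:
#       print(f"generation failed: {err}")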
def process_and_merge_clips(clip_files):
    if not clip_files:
        return None
    list_file = os.path.join(OUTPUT_FOLDER, f"clips_{uuid.uuid4().hex}.txt")
    with open(list_file, "w") as f:
        for c in clip_files:
            f.write(f"file '{os.path.abspath(c)}'\n")
    raw_merge_path = os.path.join(OUTPUT_FOLDER, f"final_raw_{uuid.uuid4().hex}.mp4")
    merge_cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", raw_merge_path]
    try:
        subprocess.run(merge_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    finally:
        if os.path.exists(list_file):
            os.remove(list_file)
    cinematic_path = os.path.join(OUTPUT_FOLDER, f"final_cinematic_{uuid.uuid4().hex}.mp4")
    vf_filters = "scale=1280:720:flags=lanczos,eq=contrast=1.1:brightness=0.05:saturation=1.15,unsharp=5:5:0.8:3:3:0.4"
    process_cmd = ["ffmpeg", "-i", raw_merge_path, "-vf", vf_filters, "-c:v", "libx264", "-preset", "veryfast", "-crf", "23", "-c:a", "copy", cinematic_path]
    try:
        print("🔧 Attempting cinematic post-processing...")
        subprocess.run(process_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        print("✅ Cinematic processing successful!")
        if os.path.exists(raw_merge_path):
            os.remove(raw_merge_path)
        return cinematic_path
    except Exception as e:
        print(f"⚠️ Cinematic processing failed: {e}. Falling back to raw video.")
        return raw_merge_path
def worker(api_key):
    while True:
        try:
            prompt, idx, num_clips_in_job, bytez_model = job_queue.get()
            clip_path, err = generate_clip(prompt, idx, api_key, bytez_model)
            with clips_lock:
                if err:
                    print(f"⚠️ Clip {idx + 1} failed: {err}. Skipping.")
                    set_progress(log_message=f"⚠️ Artist #{idx + 1} failed. Skipping.")
                else:
                    generated_clips_dict[idx] = clip_path
                    set_progress(log_message=f"✅ Artist #{idx + 1} finished clip.")
                successful_clips = len(generated_clips_dict)
                p_data = get_progress_copy()
                base_step = 1 + (1 if p_data.get('is_image_job') else 0)
                set_progress(step=base_step + successful_clips, message=f"Generated {successful_clips} of {num_clips_in_job} clips...")
        finally:
            job_queue.task_done()
def generate_video_job(prompt, num_clips, bytez_model):
    temp_clip_paths = []
    try:
        with clips_lock:
            generated_clips_dict.clear()
        total_steps = get_progress_copy().get("total", num_clips + 2)
        base_step = total_steps - num_clips - 1
        set_progress(active=True, step=base_step, message=f"Queueing {num_clips} clips...", log_message=f"🚀 Dispatching job to {num_clips} artists...")
        for i in range(num_clips):
            job_queue.put((prompt, i, num_clips, bytez_model))
        job_queue.join()
        with clips_lock:
            sorted_clips = sorted(generated_clips_dict.items())
            clip_paths = [path for _, path in sorted_clips]
            temp_clip_paths.extend(clip_paths)
        successful_clips = len(clip_paths)
        print(f"✅ Generation phase complete. {successful_clips} out of {num_clips} clips succeeded.")
        set_progress(log_message=f"👍 {successful_clips} clips are ready.")
        if successful_clips == 0:
            raise RuntimeError("All clips failed to generate. Cannot create a video.")
        set_progress(step=total_steps - 1, message=f"Merging {successful_clips} successful clips...", log_message="🔧 Merging and post-processing...")
        final_abs_path = process_and_merge_clips(clip_paths)
        if not final_abs_path:
            raise RuntimeError("Failed to merge the generated clips.")
        final_rel_path = os.path.relpath(final_abs_path, start=os.getcwd())
        final_message = f"✅ Video ready! ({successful_clips}/{num_clips} clips succeeded)"
        set_progress(step=total_steps, status="done", message=final_message, video_relpath=final_rel_path, log_message="🎉 Mission Accomplished!")
    except Exception as e:
        set_progress(status="error", error=str(e), message="Generation failed.", log_message=f"🛑 Mission Failure: {e}")
    finally:
        set_progress(active=False)
        for clip in temp_clip_paths:
            if os.path.exists(clip):
                os.remove(clip)
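# Flow summary, for orientation: the /start route below builds the final prompt,
# generate_video_job fans one task per clip onto job_queue, the per-key worker
# threads (started at the bottom of the file) consume them in parallel, and
# process_and_merge_clips stitches whatever succeeded into the final video.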
# ---------- FLASK ROUTES ----------
# NOTE: the @app.route decorators were missing from this listing; the paths
# below are plausible reconstructions, not confirmed against the original app.
@app.route("/")
def home():
    return render_template("index.html")

@app.route("/start", methods=["POST"])
def start():
    set_progress(live_log=[])
    if get_progress_copy().get("active", False):
        return jsonify({"error": "A video is already being generated. Please wait."}), 429
    user_prompt = request.form.get("prompt", "").strip()
    image_file = request.files.get("image")
    num_clips = max(1, min(int(request.form.get("num_clips", 3)), 20))
    style = request.form.get("style", "none")
    bytez_model = request.form.get("bytez_model", "").strip()
    if not bytez_model:
        bytez_model = DEFAULT_BYTEZ_MODEL
    if not user_prompt and not image_file:
        return jsonify({"error": "Prompt or image cannot be empty."}), 400
    initial_prompt = user_prompt
    is_image_job = bool(image_file)
    total_steps = num_clips + 2 + (1 if is_image_job else 0)
    set_progress(is_image_job=is_image_job, total=total_steps)
    if is_image_job:
        try:
            image_data = image_file.read()
            mime_type = image_file.mimetype
            set_progress(status="running", message="Initializing...", error=None, active=True, step=0, log_message="🧠 Director (Gemini) analyzing image...")
            selected_gemini_key = next(gemini_key_cycler)
            gemini_prompt, err = get_prompt_from_gemini(image_data, user_prompt, mime_type, selected_gemini_key)
            if err:
                set_progress(status="error", error=err, active=False, log_message=f"🛑 Gemini Failure: {err}")
                return jsonify({"error": err}), 500
            initial_prompt = gemini_prompt
            set_progress(log_message=f"🎬 Gemini's Idea: \"{initial_prompt[:80]}...\"")
        except Exception as e:
            err_msg = f"Failed to process image: {e}"
            set_progress(status="error", error=err_msg, active=False, log_message=f"🛑 Image Error: {e}")
            return jsonify({"error": err_msg}), 500
    current_step = 1 if is_image_job else 0
    set_progress(status="running", message="Creating blueprint...", active=True, step=current_step, log_message="🎨 Quality Enhancer (Groq) creating blueprint...")
    visual_blueprint, err = generate_visual_blueprint_with_groq(initial_prompt, GROQ_API_KEY, GROQ_MODEL)
    if err:
        set_progress(status="error", error=err, active=False, log_message=f"🛑 Groq Enhancer Failure: {err}")
        return jsonify({"error": err}), 500
    set_progress(log_message=f"✨ Enhanced Blueprint: \"{visual_blueprint[:80]}...\"")
    ### --- This section was changed --- ###
    # The "Prompt Distiller" step was removed; the detailed prompt is now used directly.
    negative_keywords = "blurry, deformed, ugly, bad anatomy, watermark, noise, grain, low quality, distortion, glitch, pixelated, artifacts"
    # `visual_blueprint` is used as-is. Caveat: these "negative" terms are appended
    # to the positive prompt, so unless the model interprets them specially, a
    # dedicated negative-prompt parameter (if the API exposes one) would be safer.
    final_prompt = f"{visual_blueprint}, {negative_keywords}"
    print(f"🚀 Final Prompt for Bytez (Model: {bytez_model}): {final_prompt}")
    job_thread = threading.Thread(target=generate_video_job, args=(final_prompt, num_clips, bytez_model), daemon=True)
    job_thread.start()
    return jsonify({"ok": True, "message": "Job started."})

@app.route("/progress")
def get_progress_endpoint():
    return jsonify(get_progress_copy())

@app.route("/output/<path:filename>")
def serve_output_file(filename):
    if ".." in filename or filename.startswith("/"):
        abort(400)
    return send_from_directory(OUTPUT_FOLDER, filename)

@app.route("/manifest.json")
def serve_manifest():
    return send_from_directory('static', 'manifest.json')

@app.route("/service-worker.js")
def serve_sw():
    return send_from_directory('static', 'service-worker.js')
# ---------- RUN ----------
print(f"Starting {len(API_KEYS)} worker threads for this process...")
for api_key in API_KEYS:
    worker_thread = threading.Thread(target=worker, args=(api_key,), daemon=True)
    worker_thread.start()
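# No entry point appeared in the original listing. On Hugging Face Spaces the app
# is typically launched by the platform (e.g., via gunicorn), but for local testing
# a guard like the following would work; port 7860 is the Spaces default.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)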