# bytez-spark / app.py — Hugging Face Space file by diwash-barla
# Commit c960247 (verified): "Old prompt" · 17.5 kB
import os
import json
import uuid
import threading
import requests
import subprocess
import base64
from flask import Flask, render_template, request, jsonify, send_from_directory, abort
import queue
from groq import Groq
import itertools
# ---------- CONFIG ----------
def _collect_numbered_env(prefix):
    """Return the non-empty values of ``{prefix}1`` .. ``{prefix}99`` in numeric order.

    Gaps are allowed: setting ``btz1`` and ``btz3`` (but not ``btz2``) yields two keys.
    """
    found = []
    for number in range(1, 100):
        value = os.environ.get(f"{prefix}{number}")
        if value:
            found.append(value)
    return found


# Bytez keys: one worker thread is started per key.
API_KEYS = _collect_numbered_env("btz")
if not API_KEYS:
    raise KeyError("No 'btzN' environment variables found (e.g., btz1, btz2). Please set them in Hugging Face secrets.")

GROQ_API_KEY = os.environ.get("GROQ_KEY")
if not GROQ_API_KEY:
    raise KeyError("GROQ_KEY environment variable not found.")

# Gemini keys are handed out round-robin, one per image request.
GEMINI_API_KEYS = _collect_numbered_env("gmni")
if not GEMINI_API_KEYS:
    raise KeyError("No 'gmniN' environment variables found (e.g., gmni1, gmni2). Please set them in Hugging Face secrets.")
gemini_key_cycler = itertools.cycle(GEMINI_API_KEYS)

# Used when a request does not name a Bytez model.
DEFAULT_BYTEZ_MODEL = "ali-vilab/text-to-video-ms-1.7b"
# ---------- MODEL HUNTER FUNCTIONS ----------
def find_best_groq_model(api_key):
    """Pick a chat-capable Groq model id for the prompt "Quality Enhancer".

    Lists the models visible to ``api_key`` and drops guard / TTS /
    prompt-guard / speech-to-text (whisper) models, which cannot serve chat
    completions. From the rest it returns the first id (in sorted order, so
    the choice is stable across restarts) containing one of the preferred
    family keywords. Failing that, it returns the first remaining id. Any
    error (network, auth, nothing usable) falls back to
    ``"llama-3.1-8b-instant"``.

    Args:
        api_key: Groq API key.

    Returns:
        A Groq model id string.
    """
    excluded_keywords = ("guard", "tts", "prompt", "whisper")
    preferred_keywords = ["llama-3.1", "llama3", "gemma2", "mixtral", "gemma"]
    try:
        print("🤖 Hunting for the best Groq model...")
        client = Groq(api_key=api_key)
        models = client.models.list().data
        # Sorted list instead of iterating a set: set order of str depends on hash
        # randomisation, which made the selected model change between restarts.
        candidate_ids = sorted(
            model_id
            for model_id in {model.id for model in models}
            if not any(k in model_id for k in excluded_keywords)
        )
        for keyword in preferred_keywords:
            for model_id in candidate_ids:
                if keyword in model_id:
                    print(f"🎯 Groq Target locked: {model_id}")
                    return model_id
        if candidate_ids:
            model_id = candidate_ids[0]
            print(f"✅ Groq Fallback found: {model_id}")
            return model_id
        raise ValueError("No usable Groq models found.")
    except Exception as e:
        print(f"🛑 Groq hunt failed: {e}. Using hardcoded fallback.")
        return "llama-3.1-8b-instant"
def find_best_gemini_vision_model(api_key):
    """Pick a Gemini model name that supports ``generateContent`` with image input.

    Lists the models visible to ``api_key`` and keeps those that support
    ``generateContent`` and whose name contains "vision", "flash" or "1.5".
    It prefers, in order, gemini-1.5-flash-latest, gemini-1.5-pro-latest and
    gemini-pro-vision; otherwise it uses the first remaining candidate. Any
    failure falls back to ``"gemini-1.5-flash-latest"``.

    Args:
        api_key: Gemini API key used for the listing request.

    Returns:
        A bare model name (the ``models/`` prefix is stripped).
    """
    try:
        print("🤖 Hunting for the best Gemini Vision model...")
        # The key travels in a header, not as ``?key=``: requests copies the full URL
        # into HTTPError/ConnectionError messages, and those are printed below.
        response = requests.get(
            "https://generativelanguage.googleapis.com/v1beta/models",
            headers={"x-goog-api-key": api_key},
            params={"pageSize": 1000},  # the default page holds only 50 models
            timeout=15,
        )
        response.raise_for_status()
        models = response.json().get("models", [])
        vision_models = [
            m["name"]
            for m in models
            if any("generateContent" in s for s in m.get("supportedGenerationMethods", []))
            and ("vision" in m["name"] or "flash" in m["name"] or "1.5" in m["name"])
        ]
        preferred_models = ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest", "gemini-pro-vision"]
        for preferred in preferred_models:
            for model_path in vision_models:
                if preferred in model_path:
                    model_name = model_path.split('/')[-1]
                    print(f"🎯 Gemini Target locked: {model_name}")
                    return model_name
        if vision_models:
            model_name = vision_models[0].split('/')[-1]
            print(f"✅ Gemini Fallback found: {model_name}")
            return model_name
        raise ValueError("No usable Gemini Vision models found.")
    except Exception as e:
        print(f"🛑 Gemini hunt failed: {e}. Using hardcoded fallback.")
        return "gemini-1.5-flash-latest"
# ---------- INITIALIZATION ----------
# Resolve the LLM models once, at import time. Both hunters catch their own
# errors and return a hard-coded fallback, so start-up does not fail here.
GROQ_MODEL = find_best_groq_model(GROQ_API_KEY)
GEMINI_VISION_MODEL = find_best_gemini_vision_model(GEMINI_API_KEYS[0])
print(
    f"✅ Loaded {len(API_KEYS)} Bytez keys and {len(GEMINI_API_KEYS)} Gemini keys. "
    f"Using Groq: {GROQ_MODEL} and Gemini: {GEMINI_VISION_MODEL}"
)

# Working directories: generated clips/videos, and static assets
# (manifest.json, service-worker.js).
OUTPUT_FOLDER = "output"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
os.makedirs("static", exist_ok=True)
# ---------- APP & STATE ----------
app = Flask(__name__)

# Shared status of the single in-flight job; returned as-is by /progress.
progress = {
    "active": False,        # True from job start until the job (or its setup) ends
    "step": 0,              # current step, out of "total"
    "total": 0,
    "status": "idle",       # "idle" | "running" | "done" | "error"
    "message": "",          # human-readable status line
    "error": None,          # error text of a failed job
    "video_relpath": None,  # relative path of the finished video
    "live_log": [],         # most recent "> ..." log lines
}

# (prompt, clip_index, num_clips, bytez_model) tuples consumed by the worker threads.
job_queue = queue.Queue()

# clip_index -> saved clip path for the current job; guarded by clips_lock.
generated_clips_dict = {}
clips_lock = threading.Lock()
# ---------- HELPER FUNCTIONS ----------
# One shared lock for ``progress``. It is written by the request handler, the
# job thread and every worker thread, and read by the /progress endpoint. (The
# previous ``with threading.Lock():`` built a new lock per call and so
# excluded nothing.)
_progress_lock = threading.Lock()


def set_progress(log_message=None, **kwargs):
    """Update the shared ``progress`` dict under ``_progress_lock``.

    Args:
        log_message: Optional line appended to ``live_log`` as ``"> {log_message}"``;
            only the newest 20 lines are kept.
        **kwargs: Keys to overwrite in ``progress`` (e.g. ``status``, ``step``).
    """
    with _progress_lock:
        progress.update(kwargs)
        if log_message:
            live_log = progress["live_log"]
            live_log.append(f"> {log_message}")
            # Trim from the front so only the newest 20 entries survive.
            del live_log[:-20]


def get_progress_copy():
    """Return a snapshot of ``progress`` that is safe to use outside the lock.

    ``live_log`` is copied as well, so appends made later by other threads do
    not show up in (or race with serialisation of) the returned snapshot.
    """
    with _progress_lock:
        snapshot = dict(progress)
        if "live_log" in snapshot:
            snapshot["live_log"] = list(snapshot["live_log"])
        return snapshot
def get_prompt_from_gemini(image_data, user_text, mime_type, api_key):
    """Ask Gemini ("Creative Director") to turn an image into a text-to-video prompt.

    Args:
        image_data: Raw image bytes; sent base64-encoded as inline data.
        user_text: Optional extra instruction from the user (may be empty).
        mime_type: MIME type of ``image_data`` as reported by the upload.
        api_key: Gemini API key for this request (only its last 4 chars are logged).

    Returns:
        ``(prompt, None)`` on success, or ``(None, error_message)`` on failure.
        The error message may be shown to the end user, so it must never
        contain the key.
    """
    print(f"🧠 Contacting Gemini (Creative Director) using key ...{api_key[-4:]}")
    try:
        encoded_image = base64.b64encode(image_data).decode('utf-8')
        # No ``?key=`` in the URL: requests copies the URL into its exception
        # messages, which are returned to the caller (and from /start to the browser).
        gemini_api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_VISION_MODEL}:generateContent"
        instruction_text = "You are a master AI Art Director. Your goal is to create a **vivid, rich, and visually descriptive** prompt for a text-to-video AI. Analyze the image and the user's text. Create a prompt that captures the essence, colors, and details of the scene. Focus entirely on what can be SEEN. Describe textures, lighting, and specific objects. The more visual detail, the better. Your output must be a powerful and inspiring guide for the artist."
        parts = [{"text": instruction_text}]
        if user_text:
            parts.append({"text": f"\nUser's instruction: '{user_text}'"})
        parts.append({"inline_data": {"mime_type": mime_type, "data": encoded_image}})
        payload = {"contents": [{"parts": parts}]}
        headers = {'Content-Type': 'application/json', 'x-goog-api-key': api_key}
        response = requests.post(gemini_api_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()
        candidates = result.get('candidates') or []
        if not candidates:
            block_reason = (result.get('promptFeedback') or {}).get('blockReason')
            if block_reason:
                return None, f"Gemini API Blocked: {block_reason}"
            return None, "Gemini API returned an empty response."
        candidate = candidates[0]
        # A candidate that stopped early (e.g. a safety stop) may carry no content/parts.
        content_parts = (candidate.get('content') or {}).get('parts') or []
        generated_prompt = "".join(part.get('text', '') for part in content_parts).strip()
        if not generated_prompt:
            finish_reason = candidate.get('finishReason', 'unknown')
            return None, f"Gemini returned no text (finishReason: {finish_reason})."
        print(f"🎬 Gemini's (Creative) prompt: {generated_prompt}")
        return generated_prompt, None
    except requests.exceptions.RequestException as e:
        return None, f"Gemini API Request Error: {e}"
    except Exception as e:
        return None, f"An unexpected error occurred with Gemini: {e}"
def generate_visual_blueprint_with_groq(user_prompt, api_key, model_name):
    """Have Groq ("Quality Enhancer") rewrite a prompt into a keyword-rich video prompt.

    Args:
        user_prompt: The creative prompt to enhance (user text or Gemini output).
        api_key: Groq API key.
        model_name: Groq chat model id.

    Returns:
        ``(blueprint, None)`` on success, or ``(None, error_message)`` on any
        failure. An empty completion counts as a failure; otherwise an empty
        prompt would be sent on to the video model.
    """
    print("🎨 Contacting Groq (Quality Enhancer)...")
    try:
        client = Groq(api_key=api_key)
        system_prompt = "You are a prompt engineering expert, a 'Quality Enhancer' for a text-to-video AI. Your job is to take an incoming creative prompt and make it technically perfect for the AI artist. **Your primary rule: ADD powerful keywords.** Append terms like 'photorealistic, 4k, ultra realistic, sharp focus, hyper-detailed, vibrant colors, cinematic lighting' to the prompt. Refine the description to be even more visually precise. Do NOT simplify, ENHANCE. Example Input: 'A bush of pink roses.' Example Output: 'A dense bush of vibrant pink roses, hyper-detailed petals, sharp focus, cinematic lighting, photorealistic, 4k, ultra realistic.' Output ONLY the final, enhanced description."
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        # ``content`` can be None; treat that the same as an empty string.
        visual_blueprint = (response.choices[0].message.content or "").strip()
    except Exception as e:
        return None, f"Groq API Error: {e}"
    if not visual_blueprint:
        return None, "Groq API Error: the model returned an empty response."
    print(f"✨ Groq's (Enhanced) Blueprint: {visual_blueprint}")
    return visual_blueprint, None
# NOTE: the former distill_prompt_for_short_video function was removed from here.
def generate_clip(prompt, idx, api_key, bytez_model):
    """Generate one video clip with Bytez and save it under ``OUTPUT_FOLDER``.

    The SDK's ``model.run()`` result may be an ``(output, error)`` tuple, a
    1-tuple, or a bare output. The output must be non-empty MP4 bytes or an
    ``http(s)`` URL, which is then downloaded.

    Args:
        prompt: Final text prompt for the text-to-video model.
        idx: Clip index; embedded in the output filename.
        api_key: Bytez API key of the calling worker.
        bytez_model: Bytez model id.

    Returns:
        ``(filepath, None)`` on success, or ``(None, error_message)`` on failure.
        SDK import/setup/run errors are also returned rather than raised, so
        one bad clip cannot kill the worker thread that called this.
    """
    try:
        # SDK import and construction live inside the try: failures here used to
        # escape to the worker loop and terminate its thread.
        from bytez import Bytez
        model = Bytez(api_key).model(bytez_model)
        result = model.run(prompt)
    except Exception as e:
        print(f"🛑 CRITICAL ERROR during model.run() call: {e}")
        return None, str(e)

    # Log only a short repr: the result can be the entire video as raw bytes.
    result_repr = repr(result)
    if len(result_repr) > 300:
        result_repr = result_repr[:300] + "..."
    print(f"DEBUG: model.run() returned a '{type(result)}'. Result (truncated): {result_repr}")

    if isinstance(result, tuple):
        out = result[0] if result else None
        err = result[1] if len(result) >= 2 else None
    else:
        out, err = result, None
    if err:
        return None, f"Model Error (Key ...{api_key[-4:]}): {err}"

    is_video_bytes = isinstance(out, bytes) and len(out) > 0
    is_video_url = isinstance(out, str) and out.startswith('http')
    if not (is_video_bytes or is_video_url):
        return None, f"Unexpected or empty output from model. Output type: {type(out)}"

    filename = f"clip_{idx}_{uuid.uuid4().hex}.mp4"
    filepath = os.path.join(OUTPUT_FOLDER, filename)
    try:
        if is_video_bytes:
            with open(filepath, "wb") as f:
                f.write(out)
        else:
            # Stream to disk instead of buffering the whole video in memory.
            with requests.get(out, timeout=300, stream=True) as r:
                r.raise_for_status()
                with open(filepath, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
    except Exception as e:
        # Remove a partially written clip; nothing else would ever delete it.
        try:
            os.remove(filepath)
        except OSError:
            pass
        return None, f"Failed to save or download the generated clip: {e}"
    return filepath, None
def _discard_file(path):
    """Delete ``path`` if present, ignoring filesystem errors (best-effort cleanup)."""
    try:
        os.remove(path)
    except OSError:
        pass


def process_and_merge_clips(clip_files):
    """Concatenate clips with ffmpeg, then apply a "cinematic" re-encode.

    1. Stream-copy concatenation of ``clip_files`` via the concat demuxer.
    2. Re-encode to 1280x720 H.264 with contrast/brightness/saturation and
       sharpening filters. If this step fails, the raw merge is returned instead.

    Args:
        clip_files: Clip paths in playback order.

    Returns:
        Path of the final video (under ``OUTPUT_FOLDER``), or ``None`` when
        ``clip_files`` is empty.

    Raises:
        RuntimeError: if the concatenation step fails; the message ends with the
            tail of ffmpeg's stderr so the failure is diagnosable.
    """
    if not clip_files:
        return None
    list_file = os.path.join(OUTPUT_FOLDER, f"clips_{uuid.uuid4().hex}.txt")
    with open(list_file, "w", encoding="utf-8") as f:
        for clip in clip_files:
            # Concat-list quoting: inside '...' a literal quote must be written as '\''.
            quoted = os.path.abspath(clip).replace("'", "'\\''")
            f.write(f"file '{quoted}'\n")

    raw_merge_path = os.path.join(OUTPUT_FOLDER, f"final_raw_{uuid.uuid4().hex}.mp4")
    merge_cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", raw_merge_path]
    try:
        # stdin=DEVNULL: by default ffmpeg reads stdin for interactive commands.
        subprocess.run(merge_cmd, check=True, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        _discard_file(raw_merge_path)
        stderr_tail = (e.stderr or b"").decode("utf-8", errors="replace").strip()[-500:]
        raise RuntimeError(f"ffmpeg concat failed (exit code {e.returncode}): {stderr_tail}") from e
    finally:
        _discard_file(list_file)

    cinematic_path = os.path.join(OUTPUT_FOLDER, f"final_cinematic_{uuid.uuid4().hex}.mp4")
    vf_filters = "scale=1280:720:flags=lanczos,eq=contrast=1.1:brightness=0.05:saturation=1.15,unsharp=5:5:0.8:3:3:0.4"
    process_cmd = ["ffmpeg", "-y", "-i", raw_merge_path, "-vf", vf_filters, "-c:v", "libx264", "-preset", "veryfast", "-crf", "23", "-c:a", "copy", cinematic_path]
    try:
        print("🔧 Attempting cinematic post-processing...")
        subprocess.run(process_cmd, check=True, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        print(f"⚠️ Cinematic processing failed: {e}. Falling back to raw video.")
        _discard_file(cinematic_path)  # don't leave a half-written output behind
        return raw_merge_path
    print("✅ Cinematic processing successful!")
    _discard_file(raw_merge_path)
    return cinematic_path
def worker(api_key):
    """Clip-generation worker loop bound to one Bytez API key; never returns.

    Takes ``(prompt, idx, num_clips_in_job, bytez_model)`` jobs from
    ``job_queue`` and records each successful clip path in
    ``generated_clips_dict[idx]``. It also updates the shared progress and
    always calls ``task_done()``, so ``job_queue.join()`` in the dispatcher
    returns.

    All exceptions are caught inside the loop. Previously an exception from
    ``generate_clip`` escaped ``while True`` and silently killed the thread; if
    every worker died, ``job_queue.join()`` blocked forever.
    """
    while True:
        job = job_queue.get()
        try:
            prompt, idx, num_clips_in_job, bytez_model = job
            try:
                clip_path, err = generate_clip(prompt, idx, api_key, bytez_model)
            except Exception as e:
                clip_path, err = None, f"Unhandled error while generating clip: {e}"
            with clips_lock:
                if err:
                    print(f"⚠️ Clip {idx + 1} failed: {err}. Skipping.")
                    set_progress(log_message=f"⚠️ Artist #{idx + 1} failed. Skipping.")
                else:
                    generated_clips_dict[idx] = clip_path
                    set_progress(log_message=f"✅ Artist #{idx + 1} finished clip.")
                successful_clips = len(generated_clips_dict)
                p_data = get_progress_copy()
                # Clip steps come after the prompt-building steps (Groq, plus Gemini for image jobs).
                base_step = 1 + (1 if p_data.get('is_image_job') else 0)
                set_progress(step=base_step + successful_clips, message=f"Generated {successful_clips} of {num_clips_in_job} clips...")
        except Exception as e:
            print(f"🛑 Worker (key ...{api_key[-4:]}) failed to handle a job: {e}")
        finally:
            job_queue.task_done()
def generate_video_job(prompt, num_clips, bytez_model):
    """Run one video job: fan out clip generation, wait, merge, publish the result.

    Queues ``num_clips`` clip jobs for the worker threads, waits for all of
    them, and merges the clips that succeeded. The outcome is reported through
    ``set_progress``: ``status="done"`` with ``video_relpath``, or
    ``status="error"``. ``active`` is always reset to False at the end, and the
    intermediate clip files are deleted on a best-effort basis.

    Args:
        prompt: Final prompt used for every clip.
        num_clips: Number of clips to request.
        bytez_model: Bytez model id used for every clip.
    """
    temp_clip_paths = []
    try:
        with clips_lock:
            generated_clips_dict.clear()
        # ``total`` is set by /start; the clip phase starts right after the
        # prompt-building steps that precede it.
        total_steps = get_progress_copy().get("total", num_clips + 2)
        base_step = total_steps - num_clips - 1
        set_progress(active=True, step=base_step, message=f"Queueing {num_clips} clips...", log_message=f"🚀 Dispatching job to {num_clips} artists...")
        for i in range(num_clips):
            job_queue.put((prompt, i, num_clips, bytez_model))
        job_queue.join()  # returns once every queued clip has been marked task_done()

        with clips_lock:
            clip_paths = [path for _, path in sorted(generated_clips_dict.items())]
        temp_clip_paths.extend(clip_paths)
        successful_clips = len(clip_paths)
        print(f"✅ Generation phase complete. {successful_clips} out of {num_clips} clips succeeded.")
        set_progress(log_message=f"👍 {successful_clips} clips are ready.")
        if successful_clips == 0:
            raise RuntimeError("All clips failed to generate. Cannot create a video.")

        set_progress(step=total_steps - 1, message=f"Merging {successful_clips} successful clips...", log_message="🔧 Merging and post-processing...")
        final_abs_path = process_and_merge_clips(clip_paths)
        if not final_abs_path:
            raise RuntimeError("Failed to merge the generated clips.")
        final_rel_path = os.path.relpath(final_abs_path, start=os.getcwd())
        final_message = f"✅ Video ready! ({successful_clips}/{num_clips} clips succeeded)"
        set_progress(step=total_steps, status="done", message=final_message, video_relpath=final_rel_path, log_message="🎉 Mission Accomplished!")
    except Exception as e:
        set_progress(status="error", error=str(e), message="Generation failed.", log_message=f"🛑 Mission Failure: {e}")
    finally:
        set_progress(active=False)
        # Best-effort cleanup: one undeletable file must not abort the rest.
        for clip in temp_clip_paths:
            try:
                os.remove(clip)
            except FileNotFoundError:
                pass
            except OSError as e:
                print(f"⚠️ Could not delete temporary clip {clip}: {e}")
# ---------- FLASK ROUTES ----------
@app.route("/", methods=["GET"])
def home():
    """Render the front-end page, ``index.html``."""
    page_template = "index.html"
    return render_template(page_template)
_start_lock = threading.Lock()  # makes start()'s "is a job active?" check and the claim atomic


def _fail_start(error_message, log_message):
    """Record a failed start, release the job slot, and build the 500 JSON response."""
    set_progress(status="error", error=error_message, active=False, log_message=log_message)
    return jsonify({"error": error_message}), 500


def _prepare_and_launch(user_prompt, image_file, is_image_job, num_clips, bytez_model):
    """Build the final prompt (Gemini for image jobs, then Groq) and start the job thread.

    Called by start() only after it has claimed the job slot. Every failure
    path releases the slot again through ``_fail_start``.
    """
    initial_prompt = user_prompt
    if is_image_job:
        try:
            image_data = image_file.read()
            mime_type = image_file.mimetype
            set_progress(log_message="🧠 Director (Gemini) analyzing image...")
            selected_gemini_key = next(gemini_key_cycler)
            gemini_prompt, err = get_prompt_from_gemini(image_data, user_prompt, mime_type, selected_gemini_key)
        except Exception as e:
            return _fail_start(f"Failed to process image: {e}", f"🛑 Image Error: {e}")
        if err:
            return _fail_start(err, f"🛑 Gemini Failure: {err}")
        initial_prompt = gemini_prompt
        set_progress(log_message=f"🎬 Gemini's Idea: \"{initial_prompt[:80]}...\"")

    current_step = 1 if is_image_job else 0
    set_progress(status="running", message="Creating blueprint...", active=True, step=current_step, log_message="🎨 Quality Enhancer (Groq) creating blueprint...")
    visual_blueprint, err = generate_visual_blueprint_with_groq(initial_prompt, GROQ_API_KEY, GROQ_MODEL)
    if err:
        return _fail_start(err, f"🛑 Groq Enhancer Failure: {err}")
    set_progress(log_message=f"✨ Enhanced Blueprint: \"{visual_blueprint[:80]}...\"")

    # The enhanced blueprint is sent to Bytez unchanged (the former "prompt
    # distiller" step was removed). Negative keywords such as "blurry, deformed,
    # low quality" are deliberately NOT appended: generate_clip() passes a single
    # positive prompt to model.run(), so appending them would steer the model
    # *towards* those artefacts.
    final_prompt = visual_blueprint
    print(f"🚀 Final Prompt for Bytez (Model: {bytez_model}): {final_prompt}")
    job_thread = threading.Thread(target=generate_video_job, args=(final_prompt, num_clips, bytez_model), daemon=True)
    job_thread.start()
    return jsonify({"ok": True, "message": "Job started."})


@app.route("/start", methods=["POST"])
def start():
    """Validate a generation request, claim the single job slot, and launch the job.

    Form fields:
      - ``prompt``: text.
      - ``image``: optional upload.
      - ``num_clips``: clamped to 1..20; default 3.
      - ``bytez_model``: defaults to ``DEFAULT_BYTEZ_MODEL``.

    A ``style`` field, if sent, is ignored.

    Returns:
        JSON ``{"ok": True, "message": ...}`` once the background job has
        started. Otherwise ``{"error": ...}`` with status 400 (invalid input),
        429 (a job is already active) or 500 (image/Gemini/Groq failure).
    """
    user_prompt = request.form.get("prompt", "").strip()
    image_file = request.files.get("image")
    try:
        num_clips = max(1, min(int(request.form.get("num_clips", 3)), 20))
    except (TypeError, ValueError):
        return jsonify({"error": "num_clips must be an integer."}), 400
    bytez_model = request.form.get("bytez_model", "").strip() or DEFAULT_BYTEZ_MODEL
    if not user_prompt and not image_file:
        return jsonify({"error": "Prompt or image cannot be empty."}), 400

    is_image_job = bool(image_file)
    total_steps = num_clips + 2 + (1 if is_image_job else 0)
    with _start_lock:
        # Check before touching any state: a rejected request must not wipe the
        # running job's live log.
        if get_progress_copy().get("active", False):
            return jsonify({"error": "A video is already being generated. Please wait."}), 429
        # Claim the slot and clear leftovers of the previous job (log, error,
        # video path) before any slow network call.
        set_progress(
            active=True, status="running", message="Initializing...", step=0, total=total_steps,
            is_image_job=is_image_job, error=None, video_relpath=None, live_log=[],
        )
    try:
        return _prepare_and_launch(user_prompt, image_file, is_image_job, num_clips, bytez_model)
    except Exception as e:
        # Never leave the slot claimed: active=True with no job behind it would
        # make every later /start fail with 429.
        return _fail_start(f"Failed to start the job: {e}", f"🛑 Start Failure: {e}")
@app.route("/progress", methods=["GET"])
def get_progress_endpoint():
    """Return a snapshot of the current job's progress as JSON."""
    current_state = get_progress_copy()
    return jsonify(current_state)
@app.route(f"/{OUTPUT_FOLDER}/<path:filename>")
def serve_output_file(filename):
    """Serve a generated file from OUTPUT_FOLDER.

    Absolute paths and any name containing ".." are rejected with 400 before
    Flask's own path checks in ``send_from_directory``.
    """
    looks_unsafe = filename.startswith("/") or ".." in filename
    if looks_unsafe:
        abort(400)
    return send_from_directory(OUTPUT_FOLDER, filename)
@app.route('/manifest.json')
def serve_manifest():
    """Serve ``static/manifest.json`` at the site root."""
    static_dir, asset_name = 'static', 'manifest.json'
    return send_from_directory(static_dir, asset_name)
@app.route('/service-worker.js')
def serve_sw():
    """Serve ``static/service-worker.js`` at the site root.

    Presumably exposed at ``/`` rather than under ``/static/`` so that the
    service worker's scope covers the whole site; confirm before moving it.
    """
    static_dir, asset_name = 'static', 'service-worker.js'
    return send_from_directory(static_dir, asset_name)
# ---------- RUN ----------
# At import time, start one daemon worker per Bytez key. Clips are generated in
# parallel, and each worker always uses its own key. Every process that imports
# this module starts its own pool.
print(f"Starting {len(API_KEYS)} worker threads for this process...")
for api_key in API_KEYS:
    worker_thread = threading.Thread(target=worker, args=(api_key,))
    worker_thread.daemon = True  # idle workers must not keep the interpreter alive
    worker_thread.start()