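# bytez-spark / app.py
# Committed by diwash-barla — commit 965b251 ("Update app.py", verified), 18.6 kB.
# (Header copied from the Hugging Face file view; the "raw", "history" and
# "blame" entries there were page links, not part of the program.)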
import os
import json
import uuid
import threading
import requests
import subprocess
import base64
from flask import Flask, render_template, request, jsonify, send_from_directory, abort # 'render_template' जोड़ा गया और 'render_template_string' हटाया गया
import queue
from groq import Groq
import itertools
# ---------- CONFIG ----------
def _collect_numbered_env(prefix):
    """Return the non-empty values of ``{prefix}1`` .. ``{prefix}99``, in numeric order."""
    values = []
    for n in range(1, 100):
        value = os.environ.get(f"{prefix}{n}")
        if value:
            values.append(value)
    return values


# Bytez API keys (btz1..btz99); the RUN section starts one worker thread per key.
API_KEYS = _collect_numbered_env("btz")
if not API_KEYS:
    raise KeyError("No 'btzN' environment variables found (e.g., btz1, btz2). Please set them in Hugging Face secrets.")

# Single Groq key used for prompt enhancement and distillation.
GROQ_API_KEY = os.environ.get("GROQ_KEY")
if not GROQ_API_KEY:
    raise KeyError("GROQ_KEY environment variable not found.")

# Gemini keys (gmni1..gmni99), handed out round-robin via ``gemini_key_cycler``.
GEMINI_API_KEYS = _collect_numbered_env("gmni")
if not GEMINI_API_KEYS:
    raise KeyError("No 'gmniN' environment variables found (e.g., gmni1, gmni2). Please set them in Hugging Face secrets.")
gemini_key_cycler = itertools.cycle(GEMINI_API_KEYS)

# Used when a /start request does not name a Bytez model.
DEFAULT_BYTEZ_MODEL = "ali-vilab/text-to-video-ms-1.7b"
# ---------- MODEL HUNTER FUNCTIONS ----------
def find_best_groq_model(api_key):
    """Pick a Groq chat model id for ``api_key``.

    Lists the available models and returns the first id containing one of the
    preferred keywords, checked in priority order. If none match, any id that
    does not look like a guard/TTS/prompt model is used.

    Args:
        api_key: Groq API key used to list models.

    Returns:
        The chosen model id, or the hardcoded ``"llama-3.1-8b-instant"`` if
        listing fails or no usable model is found (this function never raises).
    """
    try:
        print("🤖 Hunting for the best Groq model...")
        client = Groq(api_key=api_key)
        # Sorted so the choice is deterministic. Iterating a set picks an
        # arbitrary match when several ids share a keyword.
        available_ids = sorted({model.id for model in client.models.list().data})
        preferred_keywords = ["llama-3.1", "llama3", "gemma2", "mixtral", "gemma"]
        for keyword in preferred_keywords:
            for model_id in available_ids:
                if keyword in model_id:
                    print(f"🎯 Groq Target locked: {model_id}")
                    return model_id
        for model_id in available_ids:
            if all(k not in model_id for k in ["guard", "tts", "prompt"]):
                print(f"✅ Groq Fallback found: {model_id}")
                return model_id
        raise ValueError("No usable Groq models found.")
    except Exception as e:
        # Deliberate best effort: startup must not fail just because model
        # discovery did.
        print(f"🛑 Groq hunt failed: {e}. Using hardcoded fallback.")
        return "llama-3.1-8b-instant"
def find_best_gemini_vision_model(api_key):
    """Pick a Gemini model name for image-to-prompt generation.

    Lists models that support ``generateContent``. It keeps those whose name
    contains ``"vision"``, ``"flash"`` or ``"1.5"``, then prefers the names in
    ``preferred_models`` (in order). Otherwise it takes the first candidate.

    Args:
        api_key: Gemini API key used to list models.

    Returns:
        A bare model name (without the ``models/`` prefix). Falls back to
        ``"gemini-1.5-flash-latest"`` on any failure; this function never raises.
    """
    try:
        print("🤖 Hunting for the best Gemini Vision model...")
        url = "https://generativelanguage.googleapis.com/v1beta/models"
        # Send the key in a header, not the query string. HTTPError messages
        # quote the request URL and are printed below, which would leak the key.
        response = requests.get(
            url,
            headers={"x-goog-api-key": api_key},
            params={"pageSize": 1000},  # default page is only 50 models
            timeout=15,
        )
        response.raise_for_status()
        models = response.json().get("models", [])
        vision_models = [
            m["name"]
            for m in models
            if any("generateContent" in s for s in m.get("supportedGenerationMethods", []))
            and ("vision" in m["name"] or "flash" in m["name"] or "1.5" in m["name"])
        ]
        preferred_models = ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest", "gemini-pro-vision"]
        for preferred in preferred_models:
            for model_path in vision_models:
                if preferred in model_path:
                    model_name = model_path.split('/')[-1]
                    print(f"🎯 Gemini Target locked: {model_name}")
                    return model_name
        if vision_models:
            model_name = vision_models[0].split('/')[-1]
            print(f"✅ Gemini Fallback found: {model_name}")
            return model_name
        raise ValueError("No usable Gemini Vision models found.")
    except Exception as e:
        # Deliberate best effort: startup must not fail just because model
        # discovery did.
        print(f"🛑 Gemini hunt failed: {e}. Using hardcoded fallback.")
        return "gemini-1.5-flash-latest"
# ---------- INITIALIZATION ----------
# Resolve model names once at import time. Both hunters catch their own errors
# and return a hardcoded fallback, so API problems do not abort startup here.
GROQ_MODEL = find_best_groq_model(GROQ_API_KEY)
GEMINI_VISION_MODEL = find_best_gemini_vision_model(GEMINI_API_KEYS[0])
print(
    f"✅ Loaded {len(API_KEYS)} Bytez keys and {len(GEMINI_API_KEYS)} Gemini keys. "
    f"Using Groq: {GROQ_MODEL} and Gemini: {GEMINI_VISION_MODEL}"
)

# Generated clips and final videos are written under OUTPUT_FOLDER;
# "static" holds the manifest and service worker served by the routes.
OUTPUT_FOLDER = "output"
for _folder in (OUTPUT_FOLDER, "static"):
    os.makedirs(_folder, exist_ok=True)
del _folder
# ---------- APP & STATE ----------
app = Flask(__name__)

# Shared status of the current job; read by the /progress route and updated
# by the request handler, the job thread and the worker threads.
progress = dict(
    active=False,
    step=0,
    total=0,
    status="idle",
    message="",
    error=None,
    video_relpath=None,
    live_log=[],
)

# Clip work items: (prompt, idx, num_clips, bytez_model), consumed by worker().
job_queue = queue.Queue()

# idx -> saved clip path for the current job; guard every access with clips_lock.
generated_clips_dict = {}
clips_lock = threading.Lock()
# ---------- HELPER FUNCTIONS ----------
# One shared lock for the ``progress`` dict. A lock created inside each call is
# never contended by anyone else, so it would serialize nothing.
_progress_lock = threading.Lock()
_MAX_LIVE_LOG = 20


def set_progress(log_message=None, **kwargs):
    """Update the shared ``progress`` dict and optionally append a log line.

    Args:
        log_message: If truthy, appended to ``progress["live_log"]`` as
            ``"> {log_message}"``. The log is then trimmed to its newest 20 lines.
        **kwargs: Keys/values merged into ``progress``. They are applied before
            the log line, so passing ``live_log=[]`` with a message starts a
            fresh log.
    """
    with _progress_lock:
        progress.update(kwargs)
        if log_message:
            live_log = progress["live_log"]
            live_log.append(f"> {log_message}")
            # Drop the oldest entries beyond the cap (no-op when within it).
            del live_log[:-_MAX_LIVE_LOG]


def get_progress_copy():
    """Return a snapshot of ``progress`` that is safe to read or serialize.

    ``live_log`` is copied as well, so callers never share the list that
    other threads keep appending to.
    """
    with _progress_lock:
        snapshot = dict(progress)
        snapshot["live_log"] = list(progress.get("live_log", []))
        return snapshot
def get_prompt_from_gemini(image_data, user_text, mime_type, api_key):
    """Ask Gemini to turn an image (plus optional user text) into a video prompt.

    Args:
        image_data: Raw image bytes; base64-encoded into the request.
        user_text: Optional extra instruction from the user (may be empty).
        mime_type: MIME type sent alongside the image data.
        api_key: Gemini API key for this call.

    Returns:
        ``(prompt, None)`` on success, or ``(None, error_message)`` on failure.
        Error messages are surfaced to the end user by the caller, so the key
        goes in a header rather than the URL (request exceptions quote the URL).
    """
    print(f"🧠 Contacting Gemini (Creative Director) using key ...{api_key[-4:]}")
    try:
        encoded_image = base64.b64encode(image_data).decode('utf-8')
        gemini_api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_VISION_MODEL}:generateContent"
        instruction_text = "You are a master AI Art Director. Your goal is to create a **vivid, rich, and visually descriptive** prompt for a text-to-video AI. Analyze the image and the user's text. Create a prompt that captures the essence, colors, and details of the scene. Focus entirely on what can be SEEN. Describe textures, lighting, and specific objects. The more visual detail, the better. Your output must be a powerful and inspiring guide for the artist."
        parts = [{"text": instruction_text}]
        if user_text:
            parts.append({"text": f"\nUser's instruction: '{user_text}'"})
        parts.append({"inline_data": {"mime_type": mime_type, "data": encoded_image}})
        payload = {"contents": [{"parts": parts}]}
        headers = {'Content-Type': 'application/json', 'x-goog-api-key': api_key}
        response = requests.post(gemini_api_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()

        candidates = result.get('candidates') or []
        if not candidates:
            block_reason = (result.get('promptFeedback') or {}).get('blockReason')
            if block_reason:
                return None, f"Gemini API Blocked: {block_reason}"
            return None, "Gemini API returned an empty response."

        # A candidate stopped early (e.g. for safety) can carry no content parts.
        content_parts = (candidates[0].get('content') or {}).get('parts') or []
        generated_prompt = "".join(p.get('text', '') for p in content_parts).strip()
        if not generated_prompt:
            finish_reason = candidates[0].get('finishReason', 'UNKNOWN')
            return None, f"Gemini returned no text (finishReason: {finish_reason})."

        print(f"🎬 Gemini's (Creative) prompt: {generated_prompt}")
        return generated_prompt, None
    except requests.exceptions.RequestException as e:
        return None, f"Gemini API Request Error: {e}"
    except Exception as e:
        return None, f"An unexpected error occurred with Gemini: {e}"
def generate_visual_blueprint_with_groq(user_prompt, api_key, model_name):
    """Have Groq rewrite a creative prompt into a keyword-rich "blueprint".

    Args:
        user_prompt: The prompt to enhance (user text or Gemini's output).
        api_key: Groq API key.
        model_name: Groq chat model id.

    Returns:
        ``(blueprint, None)`` on success, or ``(None, "Groq API Error: ...")``
        on any failure, including an empty model reply.
    """
    print("🎨 Contacting Groq (Quality Enhancer)...")
    try:
        client = Groq(api_key=api_key)
        system_prompt = "You are a prompt engineering expert, a 'Quality Enhancer' for a text-to-video AI. Your job is to take an incoming creative prompt and make it technically perfect for the AI artist. **Your primary rule: ADD powerful keywords.** Append terms like 'photorealistic, 4k, ultra realistic, sharp focus, hyper-detailed, vibrant colors, cinematic lighting' to the prompt. Refine the description to be even more visually precise. Do NOT simplify, ENHANCE. Example Input: 'A bush of pink roses.' Example Output: 'A dense bush of vibrant pink roses, hyper-detailed petals, sharp focus, cinematic lighting, photorealistic, 4k, ultra realistic.' Output ONLY the final, enhanced description."
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        # ``content`` may be None. Treat that, and whitespace-only output, as a
        # failure instead of handing an empty blueprint downstream.
        visual_blueprint = (response.choices[0].message.content or "").strip()
        if not visual_blueprint:
            return None, "Groq API Error: the model returned an empty response."
        print(f"✨ Groq's (Enhanced) Blueprint: {visual_blueprint}")
        return visual_blueprint, None
    except Exception as e:
        return None, f"Groq API Error: {e}"
def distill_prompt_for_short_video(enhanced_prompt, api_key, model_name):
    """Condense an enhanced prompt into a short one suited to a 3-4 s clip.

    Best effort: on any Groq error, or if the model returns no text, the
    original ``enhanced_prompt`` is returned unchanged.

    Args:
        enhanced_prompt: The long prompt produced by the quality enhancer.
        api_key: Groq API key.
        model_name: Groq chat model id.

    Returns:
        ``(prompt, None)``. The error slot is always ``None``, so callers can
        use the same unpacking as the other pipeline steps.
    """
    print("🧪 Contacting Groq (Prompt Distiller)...")
    system_prompt = (
        "You are an expert video editor who specializes in ultra-short clips (3-4 seconds).\n"
        "Your task is to take a long, hyper-detailed prompt and distill it into a concise, powerful version perfect for a brief video.\n"
        "RULES:\n"
        "1. Identify the single most important subject and a very brief, clear action.\n"
        "2. Focus the entire prompt on ONE dynamic moment or a powerful, static scene. Do not describe multiple complex actions.\n"
        "3. Condense descriptive keywords. Instead of 'photorealistic, 4k, ultra realistic, hyper-detailed', choose the best two, like 'cinematic, hyper-detailed'.\n"
        "4. The final prompt must be significantly shorter but retain the core artistic vision.\n"
        "Output ONLY the final, short, distilled prompt.\n"
    )
    try:
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": enhanced_prompt},
            ],
        )
        # ``content`` may be None; normalize before stripping.
        distilled_prompt = (response.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"⚠️ Groq Distiller failed: {e}. Falling back to original enhanced prompt.")
        return enhanced_prompt, None
    if not distilled_prompt:
        # An empty distilled prompt would leave only the appended keywords.
        print("⚠️ Groq Distiller returned no text. Falling back to original enhanced prompt.")
        return enhanced_prompt, None
    print(f"✨ Distiller's final prompt: {distilled_prompt}")
    return distilled_prompt, None
def generate_clip(prompt, idx, api_key, bytez_model):
    """Render one clip with Bytez and save it as an ``.mp4`` in OUTPUT_FOLDER.

    Args:
        prompt: Text prompt for the video model.
        idx: Clip index; used in the file name.
        api_key: Bytez API key for this worker.
        bytez_model: Bytez model id.

    Returns:
        ``(filepath, None)`` on success or ``(None, error_message)``. Model
        output is accepted as raw ``bytes`` or as a URL string beginning with
        ``http``. This function does not raise for SDK or download errors.
    """
    try:
        # SDK import and construction are inside the try so their failures are
        # reported like run() failures. Letting them escape would kill the
        # calling worker thread.
        from bytez import Bytez
        model = Bytez(api_key).model(bytez_model)
        out = model.run(prompt)
    except Exception as e:
        print(f"🛑 Error during model.run() with key ...{api_key[-4:]}: {e}")
        return None, f"Model Error (Key ...{api_key[-4:]}): {e}"

    filename = f"clip_{idx}_{uuid.uuid4().hex}.mp4"
    filepath = os.path.join(OUTPUT_FOLDER, filename)
    try:
        if isinstance(out, bytes):
            with open(filepath, "wb") as f:
                f.write(out)
        elif isinstance(out, str) and out.startswith('http'):
            # Stream to disk instead of holding the whole video in memory.
            with requests.get(out, timeout=300, stream=True) as r:
                r.raise_for_status()
                with open(filepath, "wb") as f:
                    for chunk in r.iter_content(chunk_size=64 * 1024):
                        f.write(chunk)
        else:
            return None, f"Unexpected output type from model: {type(out)}"
    except Exception as e:
        # Don't leave a truncated clip behind.
        if os.path.exists(filepath):
            os.remove(filepath)
        return None, f"Failed to save or download the generated clip: {e}"
    return filepath, None
def process_and_merge_clips(clip_files):
    """Concatenate clips with ffmpeg, then attempt a "cinematic" re-encode.

    Args:
        clip_files: Ordered list of clip paths to join.

    Returns:
        ``None`` if ``clip_files`` is empty. Otherwise the path of the
        re-encoded video, or of the raw concatenation if the re-encode fails.

    Raises:
        Whatever ``subprocess.run`` raises for the concatenation step (e.g.
        ``subprocess.CalledProcessError``); any partial output is removed first.
    """
    if not clip_files:
        return None

    list_file = os.path.join(OUTPUT_FOLDER, f"clips_{uuid.uuid4().hex}.txt")
    with open(list_file, "w", encoding="utf-8") as f:
        for c in clip_files:
            # Concat-demuxer quoting: inside '...', a literal ' is written as '\''.
            escaped = os.path.abspath(c).replace("'", "'\\''")
            f.write(f"file '{escaped}'\n")

    raw_merge_path = os.path.join(OUTPUT_FOLDER, f"final_raw_{uuid.uuid4().hex}.mp4")
    merge_cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", raw_merge_path]
    try:
        subprocess.run(merge_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception:
        if os.path.exists(raw_merge_path):
            os.remove(raw_merge_path)
        raise
    finally:
        if os.path.exists(list_file):
            os.remove(list_file)

    cinematic_path = os.path.join(OUTPUT_FOLDER, f"final_cinematic_{uuid.uuid4().hex}.mp4")
    vf_filters = "scale=1280:720:flags=lanczos,eq=contrast=1.1:brightness=0.05:saturation=1.15,unsharp=5:5:0.8:3:3:0.4"
    process_cmd = ["ffmpeg", "-i", raw_merge_path, "-vf", vf_filters, "-c:v", "libx264", "-preset", "veryfast", "-crf", "23", "-c:a", "copy", cinematic_path]
    try:
        print("🔧 Attempting cinematic post-processing...")
        subprocess.run(process_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        print(f"⚠️ Cinematic processing failed: {e}. Falling back to raw video.")
        # Remove the half-written re-encode; the raw merge is still valid.
        if os.path.exists(cinematic_path):
            os.remove(cinematic_path)
        return raw_merge_path
    print("✅ Cinematic processing successful!")
    if os.path.exists(raw_merge_path):
        os.remove(raw_merge_path)
    return cinematic_path
def worker(api_key):
    """Consume clip jobs from ``job_queue`` forever, using one Bytez key.

    Each item is ``(prompt, idx, num_clips_in_job, bytez_model)``. A successful
    clip is recorded in ``generated_clips_dict[idx]``; progress is updated
    either way. ``task_done()`` is always called, so ``job_queue.join()`` in
    the dispatcher can return.
    """
    while True:
        item = job_queue.get()
        try:
            prompt, idx, num_clips_in_job, bytez_model = item
            clip_path, err = generate_clip(prompt, idx, api_key, bytez_model)
            with clips_lock:
                if err:
                    print(f"⚠️ Clip {idx + 1} failed: {err}. Skipping.")
                    set_progress(log_message=f"⚠️ Artist #{idx + 1} failed. Skipping.")
                else:
                    generated_clips_dict[idx] = clip_path
                    set_progress(log_message=f"✅ Artist #{idx + 1} finished clip.")
                successful_clips = len(generated_clips_dict)
                p_data = get_progress_copy()
                # Steps before the clips: the Groq stage, plus Gemini for image jobs.
                base_step = 1 + (1 if p_data.get('is_image_job') else 0)
                set_progress(step=base_step + successful_clips, message=f"Generated {successful_clips} of {num_clips_in_job} clips...")
        except Exception as e:
            # Keep the thread alive. An uncaught exception would end this loop
            # and permanently drop this key's worker from the pool.
            print(f"🛑 Worker (key ...{api_key[-4:]}) hit an unexpected error: {e}")
            set_progress(log_message="⚠️ An artist hit an unexpected error. Skipping.")
        finally:
            job_queue.task_done()
def generate_video_job(prompt, num_clips, bytez_model):
    """Run one video job: fan out clip generation, merge, publish the result.

    Meant to run on a background thread. Progress and the final status
    (``"done"`` or ``"error"``) are reported through ``set_progress``. At the
    end, ``active`` is always reset to False and the individual clip files are
    deleted, whether or not the job succeeded.
    """
    clips_to_delete = []
    try:
        with clips_lock:
            generated_clips_dict.clear()
        total_steps = get_progress_copy().get("total", num_clips + 2)
        # Steps before the clips belong to the prompt pipeline; the last one is the merge.
        base_step = total_steps - num_clips - 1
        set_progress(
            active=True,
            step=base_step,
            message=f"Queueing {num_clips} clips...",
            log_message=f"🚀 Dispatching job to {num_clips} artists...",
        )
        for clip_index in range(num_clips):
            job_queue.put((prompt, clip_index, num_clips, bytez_model))
        job_queue.join()  # wait until every queued clip has been marked done

        with clips_lock:
            ordered_paths = [generated_clips_dict[key] for key in sorted(generated_clips_dict)]
            clips_to_delete.extend(ordered_paths)
        ok_count = len(ordered_paths)
        print(f"✅ Generation phase complete. {ok_count} out of {num_clips} clips succeeded.")
        set_progress(log_message=f"👍 {ok_count} clips are ready.")
        if ok_count == 0:
            raise RuntimeError("All clips failed to generate. Cannot create a video.")

        set_progress(
            step=total_steps - 1,
            message=f"Merging {ok_count} successful clips...",
            log_message="🔧 Merging and post-processing...",
        )
        merged_path = process_and_merge_clips(ordered_paths)
        if not merged_path:
            raise RuntimeError("Failed to merge the generated clips.")
        relative_path = os.path.relpath(merged_path, start=os.getcwd())
        set_progress(
            step=total_steps,
            status="done",
            message=f"✅ Video ready! ({ok_count}/{num_clips} clips succeeded)",
            video_relpath=relative_path,
            log_message="🎉 Mission Accomplished!",
        )
    except Exception as e:
        set_progress(status="error", error=str(e), message="Generation failed.", log_message=f"🛑 Mission Failure: {e}")
    finally:
        set_progress(active=False)
        for leftover in clips_to_delete:
            if os.path.exists(leftover):
                os.remove(leftover)
# ---------- FLASK ROUTES ----------
@app.route("/", methods=["GET"])
def home():
    """Render ``index.html`` from the Flask ``templates`` folder."""
    page = "index.html"
    return render_template(page)
# Serializes the "is a job running? if not, claim the slot" step across
# concurrent /start requests.
_start_lock = threading.Lock()


@app.route("/start", methods=["POST"])
def start():
    """Validate a generation request, build the prompt, and launch the job.

    Form fields:
        prompt: Text prompt (optional if an image is given).
        image: Optional image upload; if present, Gemini describes it first.
        num_clips: Number of clips, clamped to 1..20 (default 3).
        bytez_model: Bytez model id (defaults to ``DEFAULT_BYTEZ_MODEL``).

    The prompt pipeline is [Gemini image analysis] -> Groq enhancement ->
    Groq distillation. Clip rendering runs on a background thread.

    Returns:
        JSON ``{"ok": True, ...}`` on success. Status 400 for invalid input,
        429 if a job is already running, 500 if the prompt pipeline fails.
    """
    user_prompt = request.form.get("prompt", "").strip()
    image_file = request.files.get("image")
    try:
        requested_clips = int(request.form.get("num_clips", 3))
    except (TypeError, ValueError):
        return jsonify({"error": "num_clips must be an integer."}), 400
    num_clips = max(1, min(requested_clips, 20))
    bytez_model = request.form.get("bytez_model", "").strip() or DEFAULT_BYTEZ_MODEL
    if not user_prompt and not image_file:
        return jsonify({"error": "Prompt or image cannot be empty."}), 400

    is_image_job = bool(image_file)
    total_steps = num_clips + 2 + (1 if is_image_job else 0)

    # Atomic check-and-claim. The live log is reset only after the claim
    # succeeds, so a rejected request cannot wipe the running job's log.
    with _start_lock:
        if get_progress_copy().get("active", False):
            return jsonify({"error": "A video is already being generated. Please wait."}), 429
        set_progress(
            active=True,
            status="running",
            message="Initializing...",
            step=0,
            total=total_steps,
            is_image_job=is_image_job,
            error=None,           # clear the previous job's error ...
            video_relpath=None,   # ... and its result
            live_log=[],
        )

    def fail(err, log_message):
        """Record a pipeline failure, release the job slot, and build a 500 reply."""
        set_progress(status="error", error=err, active=False, log_message=log_message)
        return jsonify({"error": err}), 500

    initial_prompt = user_prompt
    if is_image_job:
        try:
            image_data = image_file.read()
            mime_type = image_file.mimetype
            set_progress(log_message="🧠 Director (Gemini) analyzing image...")
            selected_gemini_key = next(gemini_key_cycler)
            gemini_prompt, err = get_prompt_from_gemini(image_data, user_prompt, mime_type, selected_gemini_key)
        except Exception as e:
            return fail(f"Failed to process image: {e}", f"🛑 Image Error: {e}")
        if err:
            return fail(err, f"🛑 Gemini Failure: {err}")
        initial_prompt = gemini_prompt
        set_progress(log_message=f"🎬 Gemini's Idea: \"{initial_prompt[:80]}...\"")

    try:
        current_step = 1 if is_image_job else 0
        set_progress(status="running", message="Creating blueprint...", active=True, step=current_step, log_message="🎨 Quality Enhancer (Groq) creating blueprint...")
        visual_blueprint, err = generate_visual_blueprint_with_groq(initial_prompt, GROQ_API_KEY, GROQ_MODEL)
        if err:
            return fail(err, f"🛑 Groq Enhancer Failure: {err}")
        set_progress(log_message=f"✨ Enhanced Blueprint: \"{visual_blueprint[:80]}...\"")

        set_progress(log_message="🧪 Distilling prompt for short video clip...")
        distilled_prompt, _ = distill_prompt_for_short_video(visual_blueprint, GROQ_API_KEY, GROQ_MODEL)
        set_progress(log_message=f"🎯 Final Distilled Prompt: \"{distilled_prompt[:80]}...\"")

        # NOTE(review): these "negative" keywords are appended to the positive
        # prompt, which may steer the model *toward* them. If the Bytez model
        # accepts a separate negative prompt, pass them there instead — confirm
        # against the Bytez SDK.
        negative_keywords = "blurry, deformed, ugly, bad anatomy, watermark, noise, grain, low quality, distortion, glitch, pixelated, artifacts"
        final_prompt = f"{distilled_prompt}, {negative_keywords}"
        print(f"🚀 Final Prompt for Bytez (Model: {bytez_model}): {final_prompt}")

        job_thread = threading.Thread(target=generate_video_job, args=(final_prompt, num_clips, bytez_model), daemon=True)
        job_thread.start()
    except Exception as e:
        # Without this, an unexpected error would leave ``active`` stuck at True
        # and every later request would get a 429.
        return fail(f"Failed to start the job: {e}", f"🛑 Start Failure: {e}")
    return jsonify({"ok": True, "message": "Job started."})
@app.route("/progress", methods=["GET"])
def get_progress_endpoint():
    """Return a snapshot of the current job's progress as JSON."""
    snapshot = get_progress_copy()
    return jsonify(snapshot)
@app.route(f"/{OUTPUT_FOLDER}/<path:filename>")
def serve_output_file(filename):
    """Serve a generated file from OUTPUT_FOLDER.

    Names containing ``..`` or starting with ``/`` are rejected with a 400
    before ``send_from_directory`` is reached.
    """
    looks_unsafe = filename.startswith("/") or ".." in filename
    if looks_unsafe:
        abort(400)
    return send_from_directory(OUTPUT_FOLDER, filename)
@app.route('/manifest.json')
def serve_manifest():
    """Serve ``static/manifest.json`` at the site root."""
    folder, asset = 'static', 'manifest.json'
    return send_from_directory(folder, asset)
@app.route('/service-worker.js')
def serve_sw():
    """Serve ``static/service-worker.js`` at the site root."""
    folder, asset = 'static', 'service-worker.js'
    return send_from_directory(folder, asset)
# ---------- RUN ----------
# Start one daemon worker thread per Bytez key at import time.
# NOTE(review): the queue and progress state live in this process only. If the
# server runs several worker processes, each gets its own copy — confirm the
# deployment uses a single process.
print(f"Starting {len(API_KEYS)} worker threads for this process...")
for key in API_KEYS:
    threading.Thread(target=worker, args=(key,), daemon=True).start()