import subprocess subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) import os import sys import torch import datetime import numpy as np from PIL import Image import imageio import shutil import requests import base64 import io import spaces # --- Part 1: Auto-Setup (Clone Repo & Download Weights) --- REPO_URL = "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5.git" REPO_DIR = os.path.abspath("HunyuanVideo-1.5") MODEL_DIR = os.path.abspath("ckpts") # Repositories HF_MAIN_REPO = "tencent/HunyuanVideo-1.5" HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small" HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct" HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev" # Configuration TRANSFORMER_VERSION = "480p_i2v_distilled" DTYPE = torch.bfloat16 ENABLE_OFFLOADING = False def setup_environment(): print("=" * 50) print("Checking Environment & Dependencies...") # 1. Clone Code Repository if not os.path.exists(REPO_DIR): print(f"Cloning repository to {REPO_DIR}...") subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True) # 2. Add Repo to Python Path if REPO_DIR not in sys.path: sys.path.insert(0, REPO_DIR) # 3. Download Main Weights (Transformer, VAE, Scheduler) os.makedirs(MODEL_DIR, exist_ok=True) target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION) if not os.path.exists(target_transformer): print(f"Downloading Main Weights from {HF_MAIN_REPO}...") try: from huggingface_hub import snapshot_download allow_patterns = [ f"transformer/{TRANSFORMER_VERSION}/*", "vae/*", "scheduler/*", "tokenizer/*" ] snapshot_download( repo_id=HF_MAIN_REPO, local_dir=MODEL_DIR, allow_patterns=allow_patterns, local_dir_use_symlinks=False ) except Exception as e: print(f"Error downloading main weights: {e}") sys.exit(1) # 4. Download LLM Text Encoder (Qwen) llm_target = os.path.join(MODEL_DIR, "text_encoder", "llm") if not os.path.exists(llm_target) or not os.listdir(llm_target): print(f"Downloading LLM Text Encoder from {HF_LLM_REPO}...") try: from huggingface_hub import snapshot_download snapshot_download( repo_id=HF_LLM_REPO, local_dir=llm_target, local_dir_use_symlinks=False ) except Exception as e: print(f"Error downloading LLM: {e}") # 5. Download Vision Encoder (SigLIP) vision_target = os.path.join(MODEL_DIR, "vision_encoder", "siglip") if not os.path.exists(vision_target) or not os.listdir(vision_target): print(f"Downloading Vision Encoder from {HF_VISION_REPO}...") try: from huggingface_hub import snapshot_download snapshot_download( repo_id=HF_VISION_REPO, local_dir=vision_target, local_dir_use_symlinks=False ) except Exception as e: print(f"Error downloading Vision Encoder: {e}") # 6. Download & Restructure Glyph Weights glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2") glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt") if not os.path.exists(glyph_ckpt_target): print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...") try: from huggingface_hub import snapshot_download glyph_temp = os.path.join(MODEL_DIR, "glyph_temp") snapshot_download(repo_id=HF_GLYPH_REPO, local_dir=glyph_temp, local_dir_use_symlinks=False) os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True) os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True) # Move Assets src_assets = os.path.join(glyph_temp, "assets") if os.path.exists(src_assets): for f in os.listdir(src_assets): shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f)) # Move Model src_bin = os.path.join(glyph_temp, "pytorch_model.bin") if os.path.exists(src_bin): shutil.move(src_bin, glyph_ckpt_target) else: src_safe = os.path.join(glyph_temp, "model.safetensors") if os.path.exists(src_safe): shutil.move(src_safe, glyph_ckpt_target) shutil.rmtree(glyph_temp, ignore_errors=True) except Exception as e: print(f"Error setting up Glyph weights: {e}") print("Environment Ready.") print("=" * 50) setup_environment() # --- Part 2: Imports & Patching --- try: import hyvideo.commons import hyvideo.pipelines.hunyuan_video_pipeline from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline from hyvideo.commons.infer_state import initialize_infer_state # Import the specific I2V System Prompt from the repo from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt except ImportError as e: print(f"CRITICAL ERROR: {e}") sys.exit(1) import gradio as gr def dummy_get_gpu_memory(device=None): return 80 * 1024 * 1024 * 1024 print("🛠️ Applying ZeroGPU Monkey Patch...") hyvideo.commons.get_gpu_memory = dummy_get_gpu_memory hyvideo.pipelines.hunyuan_video_pipeline.get_gpu_memory = dummy_get_gpu_memory # --- Part 3: Prompt Rewrite Logic (External API) --- def encode_image_to_base64(pil_image): buffered = io.BytesIO() pil_image.save(buffered, format="JPEG") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") return f"data:image/jpeg;base64,{img_str}" def rewrite_prompt_external(user_prompt, pil_image): """Calls HF Router API to rewrite prompt using Qwen2.5-VL""" api_key = os.environ.get("HF_TOKEN") if not api_key: print("⚠️ No HF_TOKEN found. Skipping rewrite.") return user_prompt print("🧠 Rewriting prompt via API...") API_URL = "https://router.huggingface.co/v1/chat/completions" headers = {"Authorization": f"Bearer {api_key}"} # Combine the official Hunyuan System Prompt with the User Input # The system prompt string contains a {} placeholder for the user input full_instruction = i2v_rewrite_system_prompt.format(user_prompt) base64_img = encode_image_to_base64(pil_image) payload = { "messages": [ { "role": "user", "content": [ { "type": "text", "text": full_instruction }, { "type": "image_url", "image_url": { "url": base64_img } } ] } ], "model": "Qwen/Qwen2.5-VL-7B-Instruct", "max_tokens": 512, "temperature": 0.7 } try: response = requests.post(API_URL, headers=headers, json=payload, timeout=30) response.raise_for_status() data = response.json() rewritten = data["choices"][0]["message"]["content"] print(f"✅ Rewritten: {rewritten[:50]}...") return rewritten except Exception as e: print(f"❌ Rewrite failed: {e}") return user_prompt # --- Part 4: Model Initialization (CPU) --- class ArgsNamespace: def __init__(self): self.sage_blocks_range = "0-53" self.no_cache_block_id = "0-0" self.use_sageattn = False self.enable_torch_compile = False self.enable_cache = False self.cache_type = "deepcache" self.cache_start_step = 11 self.cache_end_step = 45 self.total_steps = 50 self.cache_step_interval = 4 initialize_infer_state(ArgsNamespace()) print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION})...") try: pipe = HunyuanVideo_1_5_Pipeline.create_pipeline( pretrained_model_name_or_path=MODEL_DIR, transformer_version=TRANSFORMER_VERSION, enable_offloading=ENABLE_OFFLOADING, enable_group_offloading=ENABLE_OFFLOADING, transformer_dtype=DTYPE, device=torch.device('cpu') ) pipe.to('cuda') print("✅ Model loaded into CPU RAM.") except Exception as e: print(f"❌ Failed to load model: {e}") sys.exit(1) def save_video_tensor(video_tensor, path, fps=24): if isinstance(video_tensor, list): video_tensor = video_tensor[0] if video_tensor.ndim == 5: video_tensor = video_tensor[0] vid = (video_tensor * 255).clamp(0, 255).to(torch.uint8) vid = vid.permute(1, 2, 3, 0).cpu().numpy() imageio.mimwrite(path, vid, fps=fps) # --- Part 5: Inference --- @spaces.GPU(duration=120) def generate(input_image, prompt, length, steps, shift, seed, guidance, do_rewrite, progress=gr.Progress(track_tqdm=True)): if pipe is None: raise gr.Error("Pipeline not initialized!") if input_image is None: raise gr.Error("Reference image required.") # Process Input Image if isinstance(input_image, np.ndarray): pil_image = Image.fromarray(input_image).convert("RGB") else: pil_image = input_image.convert("RGB") # 1. Prompt Rewrite (if enabled) actual_prompt = prompt if do_rewrite: actual_prompt = rewrite_prompt_external(prompt, pil_image) # 2. Setup Generator if seed == -1: seed = torch.randint(0, 1000000, (1,)).item() generator = torch.Generator(device="cpu").manual_seed(int(seed)) print(f"🚀 GPU Inference: {actual_prompt[:30]}... | Seed: {seed}") try: pipe.execution_device = torch.device("cuda") output = pipe( prompt=actual_prompt, height=480, width=854, aspect_ratio="16:9", video_length=int(length), num_inference_steps=int(steps), guidance_scale=float(guidance), flow_shift=float(shift), reference_image=pil_image, seed=int(seed), generator=generator, output_type="pt", enable_sr=False, return_dict=True ) except Exception as e: print(f"Error: {e}") raise gr.Error(f"Inference Failed: {e}") timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.makedirs("outputs", exist_ok=True) output_path = f"outputs/gen_{timestamp}.mp4" save_video_tensor(output.videos, output_path) return output_path, actual_prompt # --- Part 6: UI --- css = '''.gradio-container .app { max-width: 900px !important; margin: 0 auto; } .dark .progress-text{color: white !important}''' def create_ui(): with gr.Blocks(title="HunyuanVideo 1.5 I2V") as demo: gr.Markdown(f"# 🎬 HunyuanVideo 1.5 I2V 480p distilled demo") gr.Markdown(f"This is a demo for HunyuanVideo 1.5 I2v {TRANSFORMER_VERSION}, released together with a collection of 10 other checkpoints (text-to-video, 720p, upscalers). Check out the [HunyuanVideo-1.5 model page](https://huggingface.co/tencent/HunyuanVideo-1.5) for more") with gr.Row(): with gr.Column(): img = gr.Image(label="Reference", type="pil") prompt = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2) rewrite_chk = gr.Checkbox(label="Enable Prompt Rewrite (Strongly Recommended)", value=True) with gr.Accordion("Advanced Options", open=False): with gr.Row(): steps = gr.Slider(2, 50, value=6, step=1, label="Steps") guidance = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="Guidance") with gr.Row(): shift = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Shift") length = gr.Slider(1, 129, value=61, step=4, label="Length") seed = gr.Number(value=-1, label="Seed", precision=0, info="-1 is a random seed") btn = gr.Button("Generate", variant="primary") with gr.Column(): out = gr.Video(label="Result", autoplay=True) final_prompt_box = gr.Textbox(label="Actual Prompt Used", interactive=False) btn.click( generate, inputs=[img, prompt, length, steps, shift, seed, guidance, rewrite_chk], outputs=[out, final_prompt_box] ) return demo if __name__ == "__main__": ui = create_ui() ui.queue().launch(server_name="0.0.0.0", share=True, css=css)