import os
import sys
import subprocess

# Install a prebuilt FlashAttention wheel. FLASH_ATTENTION_SKIP_CUDA_BUILD skips
# compiling the CUDA kernels, which would fail on ZeroGPU since no GPU is attached
# at startup. Merge with os.environ so pip still sees PATH, proxies, tokens, etc.
subprocess.run(
    "pip install flash-attn==2.7.4.post1 --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import torch
import datetime
import numpy as np
from PIL import Image
import imageio
import shutil
import requests
import base64
import io
import spaces
# --- Part 1: Auto-Setup (Clone Repo & Download Weights) ---
REPO_URL = "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5.git"
REPO_DIR = os.path.abspath("HunyuanVideo-1.5")
MODEL_DIR = os.path.abspath("ckpts")

# Repositories
HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev"

# Configuration
TRANSFORMER_VERSION = "480p_i2v_distilled"
DTYPE = torch.bfloat16
ENABLE_OFFLOADING = False
def setup_environment():
    print("=" * 50)
    print("Checking Environment & Dependencies...")
    from huggingface_hub import snapshot_download

    # 1. Clone the code repository
    if not os.path.exists(REPO_DIR):
        print(f"Cloning repository to {REPO_DIR}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # 2. Add the repo to the Python path
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)

    # 3. Download main weights (transformer, VAE, scheduler, tokenizer)
    os.makedirs(MODEL_DIR, exist_ok=True)
    target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
    if not os.path.exists(target_transformer):
        print(f"Downloading main weights from {HF_MAIN_REPO}...")
        try:
            allow_patterns = [
                f"transformer/{TRANSFORMER_VERSION}/*",
                "vae/*",
                "scheduler/*",
                "tokenizer/*",
            ]
            snapshot_download(
                repo_id=HF_MAIN_REPO,
                local_dir=MODEL_DIR,
                allow_patterns=allow_patterns,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading main weights: {e}")
            sys.exit(1)
    # 4. Download the LLM text encoder (Qwen)
    llm_target = os.path.join(MODEL_DIR, "text_encoder", "llm")
    if not os.path.exists(llm_target) or not os.listdir(llm_target):
        print(f"Downloading LLM text encoder from {HF_LLM_REPO}...")
        try:
            snapshot_download(
                repo_id=HF_LLM_REPO,
                local_dir=llm_target,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading LLM: {e}")

    # 5. Download the vision encoder (SigLIP)
    vision_target = os.path.join(MODEL_DIR, "vision_encoder", "siglip")
    if not os.path.exists(vision_target) or not os.listdir(vision_target):
        print(f"Downloading vision encoder from {HF_VISION_REPO}...")
        try:
            snapshot_download(
                repo_id=HF_VISION_REPO,
                local_dir=vision_target,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading vision encoder: {e}")
    # 6. Download & restructure Glyph weights
    glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
    glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
    if not os.path.exists(glyph_ckpt_target):
        print(f"Downloading & structuring Glyph weights from {HF_GLYPH_REPO}...")
        try:
            glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
            snapshot_download(repo_id=HF_GLYPH_REPO, local_dir=glyph_temp, local_dir_use_symlinks=False)
            os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
            os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)

            # Move assets
            src_assets = os.path.join(glyph_temp, "assets")
            if os.path.exists(src_assets):
                for f in os.listdir(src_assets):
                    shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f))

            # Move the model checkpoint (.bin or .safetensors, whichever the repo ships)
            src_bin = os.path.join(glyph_temp, "pytorch_model.bin")
            if os.path.exists(src_bin):
                shutil.move(src_bin, glyph_ckpt_target)
            else:
                src_safe = os.path.join(glyph_temp, "model.safetensors")
                if os.path.exists(src_safe):
                    shutil.move(src_safe, glyph_ckpt_target)
            shutil.rmtree(glyph_temp, ignore_errors=True)
        except Exception as e:
            print(f"Error setting up Glyph weights: {e}")

    print("Environment Ready.")
    print("=" * 50)


setup_environment()
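# For reference, setup_environment() should leave ckpts/ looking roughly like
# this (derived from the download targets above; not an exhaustive listing):
#
#   ckpts/
#   ├── transformer/480p_i2v_distilled/
#   ├── vae/
#   ├── scheduler/
#   ├── tokenizer/
#   ├── text_encoder/
#   │   ├── llm/                      # Qwen2.5-VL-7B-Instruct
#   │   └── Glyph-SDXL-v2/
#   │       ├── assets/
#   │       └── checkpoints/byt5_model.pt
#   └── vision_encoder/siglip/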
# --- Part 2: Imports & Patching ---
try:
    import hyvideo.commons
    import hyvideo.pipelines.hunyuan_video_pipeline
    from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
    from hyvideo.commons.infer_state import initialize_infer_state

    # Import the official I2V rewrite system prompt from the repo
    from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
except ImportError as e:
    print(f"CRITICAL ERROR: {e}")
    sys.exit(1)

import gradio as gr


def dummy_get_gpu_memory(device=None):
    # Pretend there is always 80 GiB of VRAM. On ZeroGPU no GPU is attached at
    # import time, so the repo's real memory probe would fail or misreport.
    return 80 * 1024 * 1024 * 1024


print("🛠️ Applying ZeroGPU Monkey Patch...")
hyvideo.commons.get_gpu_memory = dummy_get_gpu_memory
hyvideo.pipelines.hunyuan_video_pipeline.get_gpu_memory = dummy_get_gpu_memory
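# Optional sanity check: both patched entry points should now report the dummy
# 80 GiB figure instead of probing a GPU that isn't attached yet.
assert hyvideo.commons.get_gpu_memory() == 80 * 1024 ** 3
assert hyvideo.pipelines.hunyuan_video_pipeline.get_gpu_memory() == 80 * 1024 ** 3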
# --- Part 3: Prompt Rewrite Logic (External API) ---
def encode_image_to_base64(pil_image):
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_str}"
def rewrite_prompt_external(user_prompt, pil_image):
    """Calls the HF Router API to rewrite the prompt with Qwen2.5-VL."""
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("⚠️ No HF_TOKEN found. Skipping rewrite.")
        return user_prompt

    print("🧠 Rewriting prompt via API...")
    API_URL = "https://router.huggingface.co/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}

    # Combine the official Hunyuan system prompt with the user input.
    # The system prompt string contains a {} placeholder for the user input.
    full_instruction = i2v_rewrite_system_prompt.format(user_prompt)
    base64_img = encode_image_to_base64(pil_image)

    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": full_instruction},
                    {"type": "image_url", "image_url": {"url": base64_img}},
                ],
            }
        ],
        "model": "Qwen/Qwen2.5-VL-7B-Instruct",
        "max_tokens": 512,
        "temperature": 0.7,
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        data = response.json()
        rewritten = data["choices"][0]["message"]["content"]
        print(f"✅ Rewritten: {rewritten[:50]}...")
        return rewritten
    except Exception as e:
        print(f"❌ Rewrite failed: {e}")
        return user_prompt
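# The router exposes an OpenAI-compatible chat-completions schema, so a
# successful response body looks roughly like:
#   {"choices": [{"message": {"role": "assistant", "content": "<rewritten prompt>"}}], ...}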
# --- Part 4: Model Initialization (CPU) ---
class ArgsNamespace:
    """Minimal stand-in for the repo's CLI argument namespace."""

    def __init__(self):
        self.sage_blocks_range = "0-53"
        self.no_cache_block_id = "0-0"
        self.use_sageattn = False
        self.enable_torch_compile = False
        self.enable_cache = False
        self.cache_type = "deepcache"
        self.cache_start_step = 11
        self.cache_end_step = 45
        self.total_steps = 50
        self.cache_step_interval = 4


initialize_infer_state(ArgsNamespace())
| print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION})...") | |
| try: | |
| pipe = HunyuanVideo_1_5_Pipeline.create_pipeline( | |
| pretrained_model_name_or_path=MODEL_DIR, | |
| transformer_version=TRANSFORMER_VERSION, | |
| enable_offloading=ENABLE_OFFLOADING, | |
| enable_group_offloading=ENABLE_OFFLOADING, | |
| transformer_dtype=DTYPE, | |
| device=torch.device('cpu') | |
| ) | |
| pipe.to('cuda') | |
| print("✅ Model loaded into CPU RAM.") | |
| except Exception as e: | |
| print(f"❌ Failed to load model: {e}") | |
| sys.exit(1) | |
def save_video_tensor(video_tensor, path, fps=24):
    if isinstance(video_tensor, list):
        video_tensor = video_tensor[0]
    if video_tensor.ndim == 5:  # (B, C, T, H, W) -> drop the batch dim
        video_tensor = video_tensor[0]
    # (C, T, H, W) floats in [0, 1] -> (T, H, W, C) uint8 frames
    vid = (video_tensor * 255).clamp(0, 255).to(torch.uint8)
    vid = vid.permute(1, 2, 3, 0).cpu().numpy()
    imageio.mimwrite(path, vid, fps=fps)
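# Minimal usage sketch (illustrative shapes, not the pipeline's real output):
#   save_video_tensor(torch.rand(3, 8, 64, 64), "smoke_test.mp4")
# writes an 8-frame, 64x64 clip from a (C, T, H, W) tensor in [0, 1].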
# --- Part 5: Inference ---
# @spaces.GPU allocates a ZeroGPU device for the duration of this call; without
# it, CUDA is unavailable inside the function on a ZeroGPU Space. The duration
# is a rough upper bound per call; adjust as needed.
@spaces.GPU(duration=120)
def generate(input_image, prompt, length, steps, shift, seed, guidance, do_rewrite, progress=gr.Progress(track_tqdm=True)):
    if pipe is None:
        raise gr.Error("Pipeline not initialized!")
    if input_image is None:
        raise gr.Error("Reference image required.")

    # Process the input image
    if isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image).convert("RGB")
    else:
        pil_image = input_image.convert("RGB")

    # 1. Prompt rewrite (if enabled)
    actual_prompt = prompt
    if do_rewrite:
        actual_prompt = rewrite_prompt_external(prompt, pil_image)

    # 2. Set up the generator
    if seed == -1:
        seed = torch.randint(0, 1_000_000, (1,)).item()
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    print(f"🚀 GPU Inference: {actual_prompt[:30]}... | Seed: {seed}")
    try:
        pipe.execution_device = torch.device("cuda")
        output = pipe(
            prompt=actual_prompt,
            height=480, width=854, aspect_ratio="16:9",
            video_length=int(length),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            flow_shift=float(shift),
            reference_image=pil_image,
            seed=int(seed),
            generator=generator,
            output_type="pt",
            enable_sr=False,
            return_dict=True,
        )
    except Exception as e:
        print(f"Error: {e}")
        raise gr.Error(f"Inference Failed: {e}")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    output_path = f"outputs/gen_{timestamp}.mp4"
    save_video_tensor(output.videos, output_path)
    return output_path, actual_prompt
# --- Part 6: UI ---
css = '''.gradio-container .app { max-width: 900px !important; margin: 0 auto; }
.dark .progress-text{color: white !important}'''


def create_ui():
    # css must be passed to gr.Blocks; launch() does not accept a css argument.
    with gr.Blocks(title="HunyuanVideo 1.5 I2V", css=css) as demo:
        gr.Markdown("# 🎬 HunyuanVideo 1.5 I2V 480p distilled demo")
        gr.Markdown(
            f"This is a demo of HunyuanVideo 1.5 I2V {TRANSFORMER_VERSION}, released together with "
            "a collection of 10 other checkpoints (text-to-video, 720p, upscalers). See the "
            "[HunyuanVideo-1.5 model page](https://huggingface.co/tencent/HunyuanVideo-1.5) for more details."
        )
        with gr.Row():
            with gr.Column():
                img = gr.Image(label="Reference", type="pil")
                prompt = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2)
                rewrite_chk = gr.Checkbox(label="Enable Prompt Rewrite (Strongly Recommended)", value=True)
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        steps = gr.Slider(2, 50, value=6, step=1, label="Steps")
                        guidance = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="Guidance")
                    with gr.Row():
                        shift = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Shift")
                        length = gr.Slider(1, 129, value=61, step=4, label="Length (frames)")
                    seed = gr.Number(value=-1, label="Seed", precision=0, info="-1 picks a random seed")
                btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                out = gr.Video(label="Result", autoplay=True)
                final_prompt_box = gr.Textbox(label="Actual Prompt Used", interactive=False)
        btn.click(
            generate,
            inputs=[img, prompt, length, steps, shift, seed, guidance, rewrite_chk],
            outputs=[out, final_prompt_box],
        )
    return demo
if __name__ == "__main__":
    ui = create_ui()
    ui.queue().launch(server_name="0.0.0.0", share=True)