# Hugging Face Space file; uploader avatar caption: "multimodalart's picture"
# Commit message: "Update app.py"
# Commit: 134049c (verified)
import os
import subprocess
import sys

# Install flash-attn at startup, before anything below tries to import it.
#
# `env=` REPLACES the child's whole environment, so the extra variable is
# merged into a copy of os.environ; passing only the one key would drop PATH,
# HOME, proxy settings, etc. for the pip process.
# FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells flash-attn's setup to skip
# compiling CUDA kernels locally — TODO confirm against flash-attn's setup.py.
#
# pip is invoked through the running interpreter so the package lands in the
# environment this script actually uses. The result is deliberately not
# checked: a failed install is best-effort and must not stop the app here.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "flash-attn==2.7.4.post1",
        "--no-build-isolation",
    ],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
import os
import sys
import torch
import datetime
import numpy as np
from PIL import Image
import imageio
import shutil
import requests
import base64
import io
import spaces
# --- Part 1: Auto-Setup (Clone Repo & Download Weights) ---
# Git repository holding the inference code; cloned by setup_environment().
REPO_URL = "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5.git"
# Local checkout path, resolved against the working directory at import time.
# setup_environment() also inserts it at the front of sys.path.
REPO_DIR = os.path.abspath("HunyuanVideo-1.5")
# Root directory that receives every downloaded checkpoint; later passed to
# the pipeline as pretrained_model_name_or_path.
MODEL_DIR = os.path.abspath("ckpts")
# Repositories
# Hub repo for the transformer / vae / scheduler / tokenizer folders.
HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
# Hub repo whose files are rearranged into text_encoder/Glyph-SDXL-v2.
HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
# Hub repo downloaded into text_encoder/llm.
HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
# Hub repo downloaded into vision_encoder/siglip.
HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev"
# Configuration
# Selects the transformer/<version> sub-folder to download and load.
TRANSFORMER_VERSION = "480p_i2v_distilled"
# Passed to the pipeline as transformer_dtype.
DTYPE = torch.bfloat16
# Passed to the pipeline as both enable_offloading and enable_group_offloading.
ENABLE_OFFLOADING = False
def _is_missing_or_empty(path):
    """Return True when *path* is not an existing, non-empty directory."""
    return not os.path.isdir(path) or not os.listdir(path)


def _download_snapshot(repo_id, local_dir, allow_patterns=None):
    """Download *repo_id* from the Hugging Face Hub into *local_dir*.

    ``allow_patterns`` limits the download to matching files; ``None`` (the
    library default) fetches the whole repository.
    """
    # Imported lazily so a missing huggingface_hub is reported through the
    # callers' error handling rather than at module import time.
    from huggingface_hub import snapshot_download

    # NOTE(review): local_dir_use_symlinks is kept exactly as before; newer
    # huggingface_hub releases may ignore or reject it — confirm against the
    # pinned version.
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        allow_patterns=allow_patterns,
        local_dir_use_symlinks=False,
    )


def _install_glyph_weights(glyph_root, glyph_ckpt_target):
    """Download the Glyph repo and rearrange it under *glyph_root*.

    Files from the repo's ``assets`` folder are copied to
    ``<glyph_root>/assets`` and the model file (``pytorch_model.bin``, else
    ``model.safetensors``) is moved to *glyph_ckpt_target*. The temporary
    download folder is removed whether or not the restructuring succeeds.
    """
    glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
    try:
        _download_snapshot(HF_GLYPH_REPO, glyph_temp)

        assets_dst = os.path.join(glyph_root, "assets")
        os.makedirs(assets_dst, exist_ok=True)
        os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)

        # Move Assets
        src_assets = os.path.join(glyph_temp, "assets")
        if os.path.exists(src_assets):
            for name in os.listdir(src_assets):
                shutil.copy(os.path.join(src_assets, name), os.path.join(assets_dst, name))

        # Move Model: first matching file wins, in the original priority order.
        for candidate in ("pytorch_model.bin", "model.safetensors"):
            src_model = os.path.join(glyph_temp, candidate)
            if os.path.exists(src_model):
                shutil.move(src_model, glyph_ckpt_target)
                break
        else:
            # Previously this case passed silently and only showed up later
            # as a missing-checkpoint failure.
            print(
                f"Warning: no model file found in {HF_GLYPH_REPO}; "
                f"{glyph_ckpt_target} was not created."
            )
    finally:
        # Always drop the scratch folder so a failed run does not leave a
        # partially consumed download behind.
        shutil.rmtree(glyph_temp, ignore_errors=True)


def setup_environment():
    """Prepare code and weights needed by the app.

    Steps, each skipped when its target already exists:
      1. ``git clone`` REPO_URL into REPO_DIR.
      2. Put REPO_DIR at the front of ``sys.path``.
      3. Download transformer / vae / scheduler / tokenizer files into
         MODEL_DIR.
      4. Download the LLM text encoder into ``text_encoder/llm``.
      5. Download the vision encoder into ``vision_encoder/siglip``.
      6. Download and restructure the Glyph weights.

    A failure in step 3 terminates the process with exit status 1. Failures
    in steps 4-6 are printed and otherwise ignored (best-effort).

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
    """
    print("=" * 50)
    print("Checking Environment & Dependencies...")

    # 1. Clone Code Repository
    if not os.path.exists(REPO_DIR):
        print(f"Cloning repository to {REPO_DIR}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # 2. Add Repo to Python Path
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)

    # 3. Download Main Weights (Transformer, VAE, Scheduler)
    os.makedirs(MODEL_DIR, exist_ok=True)
    target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
    # An existing-but-empty folder counts as missing, matching the check used
    # for the encoders below.
    if _is_missing_or_empty(target_transformer):
        print(f"Downloading Main Weights from {HF_MAIN_REPO}...")
        try:
            _download_snapshot(
                HF_MAIN_REPO,
                MODEL_DIR,
                allow_patterns=[
                    f"transformer/{TRANSFORMER_VERSION}/*",
                    "vae/*",
                    "scheduler/*",
                    "tokenizer/*",
                ],
            )
        except Exception as e:
            print(f"Error downloading main weights: {e}")
            sys.exit(1)

    # 4. Download LLM Text Encoder (Qwen)
    # 5. Download Vision Encoder (SigLIP)
    # Each entry: (label for the progress line, label for the error line,
    # hub repo, destination folder).
    optional_encoders = (
        ("LLM Text Encoder", "LLM", HF_LLM_REPO,
         os.path.join(MODEL_DIR, "text_encoder", "llm")),
        ("Vision Encoder", "Vision Encoder", HF_VISION_REPO,
         os.path.join(MODEL_DIR, "vision_encoder", "siglip")),
    )
    for label, error_label, repo_id, target in optional_encoders:
        if not _is_missing_or_empty(target):
            continue
        print(f"Downloading {label} from {repo_id}...")
        try:
            _download_snapshot(repo_id, target)
        except Exception as e:
            print(f"Error downloading {error_label}: {e}")

    # 6. Download & Restructure Glyph Weights
    glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
    glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
    if not os.path.exists(glyph_ckpt_target):
        print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
        try:
            _install_glyph_weights(glyph_root, glyph_ckpt_target)
        except Exception as e:
            print(f"Error setting up Glyph weights: {e}")

    print("Environment Ready.")
    print("=" * 50)
# Must run before the imports below: the hyvideo package only becomes
# importable once the repo has been cloned and REPO_DIR is on sys.path.
setup_environment()

# --- Part 2: Imports & Patching ---
try:
    # Bound as module objects because get_gpu_memory is replaced on both of
    # them further down.
    import hyvideo.commons
    import hyvideo.pipelines.hunyuan_video_pipeline
    from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
    from hyvideo.commons.infer_state import initialize_infer_state
    # The I2V rewrite system prompt that ships with the repo.
    from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
except ImportError as import_error:
    # Without the repo code nothing else can work, so stop immediately.
    print(f"CRITICAL ERROR: {import_error}")
    sys.exit(1)

import gradio as gr
def dummy_get_gpu_memory(device=None):
    """Return a fixed memory size of 80 GiB in bytes; *device* is ignored."""
    bytes_per_gib = 1024 ** 3
    return 80 * bytes_per_gib
# Swap the repo's get_gpu_memory for dummy_get_gpu_memory, which always
# reports 80 GiB, so the library does not have to query a real device here.
print("🛠️ Applying ZeroGPU Monkey Patch...")
# Patch the attribute on the commons package itself...
hyvideo.commons.get_gpu_memory = dummy_get_gpu_memory
# ...and on the pipeline module, which presumably holds its own reference
# from a `from ... import get_gpu_memory` — confirm in the repo.
hyvideo.pipelines.hunyuan_video_pipeline.get_gpu_memory = dummy_get_gpu_memory
# --- Part 3: Prompt Rewrite Logic (External API) ---
def encode_image_to_base64(pil_image):
    """Encode an image as a base64 JPEG ``data:`` URL.

    Args:
        pil_image: A PIL-style image exposing ``mode``, ``convert`` and
            ``save``.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    # JPEG cannot store alpha or palette data, and saving such an image
    # raises. Convert anything that is not already RGB or greyscale first.
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")

    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_str}"
def rewrite_prompt_external(user_prompt, pil_image):
    """Rewrite *user_prompt* with Qwen2.5-VL through the HF Router API.

    This is a best-effort helper: it never raises. The original
    *user_prompt* is returned unchanged when ``HF_TOKEN`` is not set, when
    building or sending the request fails, or when the reply does not contain
    a non-empty text answer.

    Args:
        user_prompt: The prompt typed by the user.
        pil_image: Reference image sent alongside the instruction.

    Returns:
        The rewritten prompt (surrounding whitespace stripped), or
        *user_prompt* on any failure.
    """
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("⚠️ No HF_TOKEN found. Skipping rewrite.")
        return user_prompt

    print("🧠 Rewriting prompt via API...")
    api_url = "https://router.huggingface.co/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}

    try:
        # Combine the official Hunyuan system prompt with the user input.
        # The system prompt string contains a {} placeholder for it.
        # Formatting and image encoding live inside the try so that a
        # failure there falls back to the original prompt as well.
        full_instruction = i2v_rewrite_system_prompt.format(user_prompt)
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": full_instruction},
                        {
                            "type": "image_url",
                            "image_url": {"url": encode_image_to_base64(pil_image)},
                        },
                    ],
                }
            ],
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",
            "max_tokens": 512,
            "temperature": 0.7,
        }
        response = requests.post(api_url, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        rewritten = response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Broad on purpose: the rewrite is optional and must not break
        # generation.
        print(f"❌ Rewrite failed: {e}")
        return user_prompt

    # An empty or non-text answer would otherwise be used as the prompt.
    if not isinstance(rewritten, str) or not rewritten.strip():
        print("❌ Rewrite failed: empty response")
        return user_prompt

    rewritten = rewritten.strip()
    print(f"✅ Rewritten: {rewritten[:50]}...")
    return rewritten
# --- Part 4: Model Initialization (CPU) ---
class ArgsNamespace:
    """Plain attribute holder with the settings given to initialize_infer_state.

    Every value is stored as an instance attribute, in the order listed.
    """

    def __init__(self):
        defaults = (
            ("sage_blocks_range", "0-53"),
            ("no_cache_block_id", "0-0"),
            ("use_sageattn", False),
            ("enable_torch_compile", False),
            ("enable_cache", False),
            ("cache_type", "deepcache"),
            ("cache_start_step", 11),
            ("cache_end_step", 45),
            ("total_steps", 50),
            ("cache_step_interval", 4),
        )
        for attr_name, attr_value in defaults:
            setattr(self, attr_name, attr_value)
# Register the inference settings with the repo before building the pipeline.
initialize_infer_state(ArgsNamespace())

print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION})...")
try:
    # The pipeline is constructed on the CPU device first...
    pipe = HunyuanVideo_1_5_Pipeline.create_pipeline(
        pretrained_model_name_or_path=MODEL_DIR,
        transformer_version=TRANSFORMER_VERSION,
        enable_offloading=ENABLE_OFFLOADING,
        enable_group_offloading=ENABLE_OFFLOADING,
        transformer_dtype=DTYPE,
        device=torch.device("cpu"),
    )
    # ...and then moved to CUDA.
    # NOTE(review): the message below still says "CPU RAM"; presumably the
    # ZeroGPU runtime defers the real transfer — confirm.
    pipe.to("cuda")
    print("✅ Model loaded into CPU RAM.")
except Exception as load_error:
    # The app is useless without a pipeline, so stop the process.
    print(f"❌ Failed to load model: {load_error}")
    sys.exit(1)
def save_video_tensor(video_tensor, path, fps=24):
    """Write a video tensor to *path* with ``imageio.mimwrite``.

    A list input contributes only its first element, and a 5-D tensor only
    its first entry along dim 0. The remaining tensor is multiplied by 255,
    clamped to [0, 255], cast to uint8 and permuted with ``(1, 2, 3, 0)``
    before being written at *fps* frames per second.
    """
    clip = video_tensor[0] if isinstance(video_tensor, list) else video_tensor
    if clip.ndim == 5:
        clip = clip[0]

    as_bytes = (clip * 255).clamp(0, 255).to(torch.uint8)
    # presumably (C, T, H, W) -> (T, H, W, C) — TODO confirm against the
    # pipeline's output layout
    frames = as_bytes.permute(1, 2, 3, 0).cpu().numpy()
    imageio.mimwrite(path, frames, fps=fps)
# --- Part 5: Inference ---
@spaces.GPU(duration=120)
def generate(input_image, prompt, length, steps, shift, seed, guidance, do_rewrite, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from a reference image and a prompt.

    Args:
        input_image: Reference image, as a PIL image or a NumPy array.
        prompt: Text prompt; ``None`` is treated as an empty string.
        length: Passed to the pipeline as ``video_length`` (cast to int).
        steps: Passed as ``num_inference_steps`` (cast to int).
        shift: Passed as ``flow_shift`` (cast to float).
        seed: Integer seed; ``-1`` or ``None`` selects a random one.
        guidance: Passed as ``guidance_scale`` (cast to float).
        do_rewrite: When truthy, the prompt is first sent through
            ``rewrite_prompt_external``.
        progress: Not referenced in the body; declared with
            ``track_tqdm=True`` in the signature.

    Returns:
        ``(output_path, actual_prompt)``: the saved video path and the prompt
        that was actually given to the pipeline.

    Raises:
        gr.Error: if the pipeline or the reference image is missing, or if
            inference fails.
    """
    if pipe is None:
        raise gr.Error("Pipeline not initialized!")
    if input_image is None:
        raise gr.Error("Reference image required.")

    # Process Input Image
    if isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image).convert("RGB")
    else:
        pil_image = input_image.convert("RGB")

    # 1. Prompt Rewrite (if enabled)
    # Normalise a missing prompt so the slicing in the log line below is safe.
    prompt = prompt or ""
    actual_prompt = prompt
    if do_rewrite:
        actual_prompt = rewrite_prompt_external(prompt, pil_image)

    # 2. Setup Generator
    # A cleared seed field arrives as None; treat it like -1 (random).
    if seed is None or int(seed) == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device="cpu").manual_seed(seed)

    print(f"🚀 GPU Inference: {actual_prompt[:30]}... | Seed: {seed}")
    try:
        pipe.execution_device = torch.device("cuda")
        output = pipe(
            prompt=actual_prompt,
            height=480, width=854, aspect_ratio="16:9",
            video_length=int(length),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            flow_shift=float(shift),
            reference_image=pil_image,
            seed=seed,
            generator=generator,
            output_type="pt",
            enable_sr=False,
            return_dict=True,
        )
    except Exception as e:
        print(f"Error: {e}")
        # Keep the original exception attached for the server-side traceback.
        raise gr.Error(f"Inference Failed: {e}") from e

    # Microseconds in the name keep two requests finishing within the same
    # second from overwriting each other's file.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    os.makedirs("outputs", exist_ok=True)
    output_path = f"outputs/gen_{timestamp}.mp4"
    save_video_tensor(output.videos, output_path)
    return output_path, actual_prompt
# --- Part 6: UI ---
# Stylesheet passed to launch(): caps the app container at 900px and centres
# it, and forces white progress text under the .dark class.
css = '''.gradio-container .app { max-width: 900px !important; margin: 0 auto; }
.dark .progress-text{color: white !important}'''
def create_ui():
    """Build the Gradio Blocks interface and return it (not yet launched).

    The left column holds the reference image, prompt, rewrite toggle,
    advanced sliders and the Generate button. The right column shows the
    resulting video and the prompt that was actually used. Clicking the
    button calls ``generate``.
    """
    with gr.Blocks(title="HunyuanVideo 1.5 I2V") as demo:
        gr.Markdown("# 🎬 HunyuanVideo 1.5 I2V 480p distilled demo")
        gr.Markdown(f"This is a demo for HunyuanVideo 1.5 I2v {TRANSFORMER_VERSION}, released together with a collection of 10 other checkpoints (text-to-video, 720p, upscalers). Check out the [HunyuanVideo-1.5 model page](https://huggingface.co/tencent/HunyuanVideo-1.5) for more")

        with gr.Row():
            # Input side
            with gr.Column():
                reference_input = gr.Image(label="Reference", type="pil")
                prompt_input = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2)
                rewrite_toggle = gr.Checkbox(label="Enable Prompt Rewrite (Strongly Recommended)", value=True)

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        steps_slider = gr.Slider(minimum=2, maximum=50, value=6, step=1, label="Steps")
                        guidance_slider = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, step=0.1, label="Guidance")
                    with gr.Row():
                        shift_slider = gr.Slider(minimum=1.0, maximum=20.0, value=5.0, step=0.5, label="Shift")
                        length_slider = gr.Slider(minimum=1, maximum=129, value=61, step=4, label="Length")
                    # NOTE(review): placed inside the accordion; the original
                    # nesting of this field could not be read from the source.
                    seed_input = gr.Number(value=-1, label="Seed", precision=0, info="-1 is a random seed")

                generate_button = gr.Button("Generate", variant="primary")

            # Output side
            with gr.Column():
                video_output = gr.Video(label="Result", autoplay=True)
                used_prompt_output = gr.Textbox(label="Actual Prompt Used", interactive=False)

        # Order must match generate()'s positional parameters.
        generate_inputs = [
            reference_input,
            prompt_input,
            length_slider,
            steps_slider,
            shift_slider,
            seed_input,
            guidance_slider,
            rewrite_toggle,
        ]
        generate_button.click(
            generate,
            inputs=generate_inputs,
            outputs=[video_output, used_prompt_output],
        )
    return demo
if __name__ == "__main__":
    # Build the UI, enable request queueing, then serve on all interfaces
    # with the custom stylesheet.
    app = create_ui()
    queued_app = app.queue()
    queued_app.launch(server_name="0.0.0.0", share=True, css=css)