import gradio as gr
import spaces
import os
import shutil
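# Must be set before the TRELLIS/spconv imports below: 'native' skips spconv's
# startup benchmarking (the default 'auto' benchmarks kernels on first run).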
os.environ['SPCONV_ALGO'] = 'native'
import tempfile
import numpy as np
import torch
import trimesh
import imageio
from typing import List, Tuple
from PIL import Image
from easydict import EasyDict as edict
# gradio_client is used to call the remote MagicArticulate API
from gradio_client import Client, handle_file
from gradio_client.exceptions import AppError
# TRELLIS imports
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
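# Temporary Gradio share link to the MagicArticulate backend (running in Colab).
# gradio.live URLs expire, so update this whenever the backend session restarts.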
MAGIC_ARTICULATE_URL = "https://f3fe9e3f800481d9bd.gradio.live"
# Configuration
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
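# Each browser session gets its own subdirectory under TMP_DIR
# (created in start_session, removed in end_session).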
# TRELLIS pipeline (loaded lazily on first use; see init_pipeline)
pipeline = None
def init_pipeline():
"""Initialize TRELLIS pipeline on first load"""
global pipeline
if pipeline is None:
print("🔄 Loading TRELLIS pipeline...")
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()
# Preload rembg
try:
pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
        except Exception:
            # Warm-up is best-effort; failures here are non-fatal.
            pass
print("✅ TRELLIS pipeline loaded!")
def start_session(req: gr.Request):
"""Create session directory"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request):
"""Clean up session directory"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
try:
shutil.rmtree(user_dir)
    except Exception:
        # The session directory may already be gone.
        pass
def preprocess_image(image: Image.Image) -> Image.Image:
"""Preprocess input image for 3D generation"""
init_pipeline()
processed_image = pipeline.preprocess_image(image)
return processed_image
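# pack_state/unpack_state move the generated tensors to CPU numpy arrays so the
# intermediate result can be held in gr.State between GPU calls, then restore
# them to CUDA when needed (e.g. for Gaussian PLY extraction).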
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
"""Pack Gaussian and mesh state for storage"""
return {
'gaussian': {
**gs.init_params,
'_xyz': gs._xyz.cpu().numpy(),
'_features_dc': gs._features_dc.cpu().numpy(),
'_scaling': gs._scaling.cpu().numpy(),
'_rotation': gs._rotation.cpu().numpy(),
'_opacity': gs._opacity.cpu().numpy(),
},
'mesh': {
'vertices': mesh.vertices.cpu().numpy(),
'faces': mesh.faces.cpu().numpy(),
},
}
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
"""Unpack Gaussian and mesh state"""
gs = Gaussian(
aabb=state['gaussian']['aabb'],
sh_degree=state['gaussian']['sh_degree'],
mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
scaling_bias=state['gaussian']['scaling_bias'],
opacity_bias=state['gaussian']['opacity_bias'],
scaling_activation=state['gaussian']['scaling_activation'],
)
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
mesh = edict(
vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
faces=torch.tensor(state['mesh']['faces'], device='cuda'),
)
return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
"""Get random seed for generation"""
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
def call_magic_articulate_api(obj_path: str, api_url: str = MAGIC_ARTICULATE_URL) -> Tuple[str, str, str]:
"""
Call MagicArticulate Colab API to generate rigging and skeleton from OBJ mesh
Args:
obj_path: Path to OBJ file
api_url: MagicArticulate Colab gradio URL
Returns:
Tuple of (rig_pred_path, skeleton_obj_path, info_text)
- rig_pred_path: Path to generated rig prediction TXT file
- skeleton_obj_path: Path to generated skeleton OBJ file
- info_text: Info about rigging results
"""
try:
print(f"🦴 Connecting to MagicArticulate API ({api_url})...")
magic_client = Client(api_url)
print("📤 Uploading OBJ to MagicArticulate...")
result = magic_client.predict(
input_mesh=handle_file(obj_path),
api_name="/predict"
)
# MagicArticulate returns (rig_pred.txt, skeleton.obj, normalized_mesh.obj)
rig_pred_file = result[0]
skeleton_file = result[1]
print("✅ MagicArticulate generation successful!")
# Read skeleton info
info_text = "Skeleton generated with hierarchical bone ordering"
if skeleton_file and os.path.exists(skeleton_file):
skeleton_mesh = trimesh.load(skeleton_file, force='mesh')
num_vertices = len(skeleton_mesh.vertices)
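            # Rough heuristic: assumes the skeleton OBJ encodes each bone as a two-vertex line segment.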
info_text = f"Joints: {num_vertices // 2}, Hierarchical structure"
return rig_pred_file, skeleton_file, info_text
except AppError as e:
error_msg = str(e)
print(f"⚠️ MagicArticulate error: {error_msg}")
raise
except Exception as e:
print(f"⚠️ MagicArticulate API error: {str(e)}")
raise
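# Optional reference helper (not wired into the UI): a minimal sketch for reading
# the returned rig_pred.txt, assuming it follows the RigNet-style rig text format
# ("joints <name> <x> <y> <z>", "root <name>", "hier <parent> <child>" lines).
# The format is an assumption; verify it against the actual MagicArticulate output.
def parse_rig_txt(rig_path: str) -> dict:
    """Parse joint positions and the bone hierarchy from a RigNet-style rig file."""
    joints, bones, root = {}, [], None
    with open(rig_path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == 'joints' and len(parts) >= 5:
                # "joints <name> <x> <y> <z>"
                joints[parts[1]] = [float(v) for v in parts[2:5]]
            elif parts[0] == 'root' and len(parts) >= 2:
                root = parts[1]
            elif parts[0] == 'hier' and len(parts) >= 3:
                # "hier <parent> <child>"
                bones.append((parts[1], parts[2]))
    return {'joints': joints, 'root': root, 'bones': bones}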
@spaces.GPU(duration=180)
def generate_3d_with_rigging(
image: Image.Image,
seed: int,
ss_guidance_strength: float,
ss_sampling_steps: int,
slat_guidance_strength: float,
slat_sampling_steps: int,
mesh_simplify: float,
texture_size: int,
req: gr.Request,
) -> Tuple[dict, str, str, str, str, str, str, str]:
"""
Complete pipeline: Image -> 3D Model (TRELLIS) -> OBJ -> Rigging (MagicArticulate)
"""
try:
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
# ============ STEP 1: TRELLIS 3D GENERATION ============
print("🎨 Generating 3D model with TRELLIS...")
init_pipeline()
outputs = pipeline.run(
image,
seed=seed,
formats=["gaussian", "mesh"],
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"cfg_strength": ss_guidance_strength,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_guidance_strength,
},
)
# Extract Gaussian and Mesh
gs = outputs['gaussian'][0]
mesh = outputs['mesh'][0]
# ============ STEP 2: RENDER VIDEO ============
print("📹 Rendering 360° preview video...")
video = render_utils.render_video(gs, num_frames=120)['color']
video_geo = render_utils.render_video(mesh, num_frames=120)['normal']
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
video_path = os.path.join(user_dir, 'sample.mp4')
imageio.mimsave(video_path, video, fps=15)
# ============ STEP 3: EXTRACT GLB ============
print("🎁 Extracting GLB with textures...")
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
glb_path = os.path.join(user_dir, 'sample.glb')
glb.export(glb_path)
# ============ STEP 4: CONVERT GLB TO OBJ ============
print("🔄 Converting GLB to OBJ format...")
obj_path = os.path.join(user_dir, "model.obj")
mesh_trimesh = trimesh.load(glb_path, force='mesh')
original_vertices = len(mesh_trimesh.vertices)
original_faces = len(mesh_trimesh.faces)
mesh_trimesh.export(obj_path)
mesh_info = f"""
📊 Mesh Statistics:
• Vertices: {original_vertices:,}
• Faces: {original_faces:,}
• Texture Size: {texture_size}px
• Status: ✓ Ready for rigging
"""
# ============ STEP 5: MAGIC ARTICULATE RIGGING ============
print("🦴 Calling MagicArticulate API for automatic skeleton generation...")
rig_info = ""
rig_file = None
skeleton_file = None
try:
# Call MagicArticulate Colab API
rig_result_path, skeleton_result_path, rig_info_text = call_magic_articulate_api(
obj_path=obj_path,
api_url=MAGIC_ARTICULATE_URL
)
if rig_result_path and os.path.exists(rig_result_path):
# Copy rig prediction file to user directory
rig_file = os.path.join(user_dir, 'rig_pred.txt')
shutil.copy(rig_result_path, rig_file)
if skeleton_result_path and os.path.exists(skeleton_result_path):
# Copy skeleton file to user directory
skeleton_file = os.path.join(user_dir, 'skeleton.obj')
shutil.copy(skeleton_result_path, skeleton_file)
rig_info = f"""✅ MagicArticulate Skeleton Generated:
{rig_info_text}
📥 Downloads:
• rig_pred.txt - Joint positions & bone hierarchy
• skeleton.obj - 3D skeleton visualization
🔧 Import into Blender/Maya for animation
"""
except Exception as e:
print(f"⚠️ MagicArticulate API error: {str(e)}")
# Create error file with instructions
rig_file = os.path.join(user_dir, 'rig_pred.txt')
with open(rig_file, 'w') as f:
f.write(f"MagicArticulate Error: {str(e)}\n\n")
f.write("Workaround: Download OBJ and rig manually in Blender.")
rig_info = f"⚠️ MagicArticulate API unavailable: {str(e)}\n\n**Solution:** Download OBJ and use Blender Rigify add-on"
skeleton_file = None
# ============ STEP 6: PACK RESULTS ============
print("📦 Packaging results...")
state = pack_state(gs, mesh)
torch.cuda.empty_cache()
combined_info = f"""
🎨 TRELLIS Generation:
• Seed: {seed}
• SS Guidance: {ss_guidance_strength}
• SS Steps: {ss_sampling_steps}
• SLAT Guidance: {slat_guidance_strength}
• SLAT Steps: {slat_sampling_steps}
{mesh_info}
{rig_info}
📥 Downloads Available:
✓ Video preview (360° rotation)
✓ GLB file (textured 3D model)
✓ OBJ file (standard 3D format)
✓ Rig prediction (TXT)
✓ Skeleton (OBJ)
🔧 Next Steps:
1. Download OBJ + Skeleton files
2. Import into Blender/Maya/C4D
3. Apply rigging from rig_pred.txt
4. Animate your model
💡 Pro Tips:
• Skeleton shows joint hierarchy visually
• Rig prediction contains exact joint coordinates
• Model is optimized for animation workflow
"""
print("✅ All processing complete!")
        # glb_path is returned twice: once for the 3D viewer and once for the GLB download button.
        return state, video_path, glb_path, glb_path, obj_path, rig_file, skeleton_file, combined_info
except Exception as e:
import traceback
error_detail = traceback.format_exc()
print(f"❌ Error: {str(e)}")
print(error_detail)
raise gr.Error(f"❌ Pipeline failed: {str(e)}\n\nDetails:\n{error_detail}")
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
"""Extract Gaussian splatting file from generated model"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
gs, _ = unpack_state(state)
gaussian_path = os.path.join(user_dir, 'sample.ply')
gs.save_ply(gaussian_path)
torch.cuda.empty_cache()
return gaussian_path, gaussian_path
# ============ GRADIO UI ============
with gr.Blocks(title="Image to Rigged 3D Model (MagicArticulate)", delete_cache=(600, 600)) as demo:
gr.Markdown("""
# 🎭 Image → 3D → Rigging (MagicArticulate Pipeline)
**Automated 3D generation with hierarchical skeleton rigging!**
This unified pipeline combines:
- **TRELLIS** (Image-to-3D, Microsoft Research)
- **MagicArticulate** (Auto-skeleton generation, CVPR 2025)
### 🚀 Workflow:
1. 📤 Upload image of object/character
2. 🎨 TRELLIS generates high-quality 3D mesh (GPU)
3. 🔄 Convert to OBJ format
4. 🦴 MagicArticulate generates hierarchical skeleton
5. 💾 Download mesh + rigging + skeleton for animation
### ✨ Benefits:
- ✅ Hierarchical bone ordering for better animation
- ✅ Automatic joint placement and bone connections
- ✅ Production-ready output for Blender/Maya
- ✅ Visual skeleton + rig data included
⏱️ **Estimated time:** 2-5 minutes
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 📥 Input")
input_image = gr.Image(
label="Upload Image",
format="png",
image_mode="RGBA",
type="pil",
height=300
)
with gr.Accordion("⚙️ TRELLIS Parameters", open=False):
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
gr.Markdown("**Stage 1: Sparse Structure**")
ss_guidance = gr.Slider(
0.0, 10.0,
label="Guidance Strength",
value=7.5,
step=0.1
)
ss_steps = gr.Slider(
1, 50,
label="Sampling Steps",
value=12,
step=1
)
gr.Markdown("**Stage 2: Structured Latent**")
slat_guidance = gr.Slider(
0.0, 10.0,
label="Guidance Strength",
value=3.0,
step=0.1
)
slat_steps = gr.Slider(
1, 50,
label="Sampling Steps",
value=12,
step=1
)
with gr.Accordion("⚙️ Output Settings", open=False):
mesh_simplify = gr.Slider(
0.9, 0.98,
label="Mesh Simplification",
value=0.95,
step=0.01
)
texture_size = gr.Slider(
512, 2048,
label="Texture Size",
value=1024,
step=512
)
generate_btn = gr.Button(
"🚀 Generate Rigged Model",
variant="primary",
size="lg"
)
extract_gs_btn = gr.Button(
"📥 Extract Gaussian (PLY)",
interactive=False
)
with gr.Column(scale=1):
gr.Markdown("### 📤 Outputs")
with gr.Tabs():
with gr.Tab("📹 Preview"):
video_output = gr.Video(
label="360° Preview",
autoplay=True,
loop=True,
height=300
)
with gr.Tab("🎨 3D Viewer"):
model_output = gr.Model3D(
label="GLB Viewer",
height=400
)
with gr.Tab("📦 Files"):
glb_download = gr.DownloadButton(
label="📥 Download GLB",
interactive=False
)
obj_download = gr.DownloadButton(
label="📥 Download OBJ",
interactive=False
)
rig_download = gr.DownloadButton(
label="🦴 Download Rig Prediction (TXT)",
interactive=False
)
skeleton_download = gr.DownloadButton(
label="🦴 Download Skeleton (OBJ)",
interactive=False
)
gs_download = gr.DownloadButton(
label="✨ Download Gaussian (PLY)",
interactive=False
)
with gr.Tab("ℹ️ Info"):
info_output = gr.Textbox(
label="Pipeline Information",
lines=20,
max_lines=30
)
# State management
output_buf = gr.State()
# Event handlers
demo.load(start_session)
demo.unload(end_session)
input_image.upload(
preprocess_image,
inputs=[input_image],
outputs=[input_image],
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
generate_3d_with_rigging,
inputs=[
input_image, seed,
ss_guidance, ss_steps,
slat_guidance, slat_steps,
mesh_simplify, texture_size
],
        outputs=[output_buf, video_output, model_output, glb_download, obj_download, rig_download, skeleton_download, info_output],
).then(
lambda: (
gr.Button(interactive=True),
gr.DownloadButton(interactive=True),
gr.DownloadButton(interactive=True),
gr.DownloadButton(interactive=True),
gr.DownloadButton(interactive=True),
),
outputs=[extract_gs_btn, glb_download, obj_download, rig_download, skeleton_download],
)
video_output.clear(
lambda: (
gr.Button(interactive=False),
gr.DownloadButton(interactive=False),
gr.DownloadButton(interactive=False),
gr.DownloadButton(interactive=False),
gr.DownloadButton(interactive=False),
),
outputs=[extract_gs_btn, glb_download, obj_download, rig_download, skeleton_download],
)
extract_gs_btn.click(
extract_gaussian,
inputs=[output_buf],
outputs=[model_output, gs_download],
).then(
lambda: gr.DownloadButton(interactive=True),
outputs=[gs_download],
)
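# Usage sketch (hypothetical): once deployed, this Space can itself be driven with
# gradio_client, mirroring how call_magic_articulate_api talks to the Colab backend.
# The Space URL and api_name below are placeholders, not guaranteed endpoints:
#
#     from gradio_client import Client, handle_file
#     client = Client("<this-space-url>")
#     result = client.predict(handle_file("photo.png"), api_name="/generate_3d_with_rigging")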
if __name__ == "__main__":
init_pipeline()
demo.launch()