import gradio as gr
import spaces
import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
import tempfile
import numpy as np
import torch
import trimesh
import imageio
from typing import List, Tuple
from PIL import Image
from easydict import EasyDict as edict

# gradio_client is used to call the remote MagicArticulate endpoint
from gradio_client import Client, handle_file
from gradio_client.exceptions import AppError

# TRELLIS imports
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils

MAGIC_ARTICULATE_URL = "https://f3fe9e3f800481d9bd.gradio.live"
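# NOTE: *.gradio.live share links are temporary (they expire after a few days), so this URL
# must be updated whenever the MagicArticulate Colab backend is restarted. Optional sketch,
# assuming MAGIC_ARTICULATE_URL is set in the Space's environment/secrets:
# MAGIC_ARTICULATE_URL = os.environ.get("MAGIC_ARTICULATE_URL", MAGIC_ARTICULATE_URL)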

# Configuration
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# The TRELLIS pipeline is created lazily on first use
pipeline = None


def init_pipeline():
    """Initialize the TRELLIS pipeline on first use."""
    global pipeline
    if pipeline is None:
        print("🔄 Loading TRELLIS pipeline...")
        pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
        pipeline.cuda()
        # Warm up the background-removal model (rembg) with a dummy image
        try:
            pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
        except Exception:
            pass
        print("✅ TRELLIS pipeline loaded!")


def start_session(req: gr.Request):
    """Create a per-session working directory."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    """Clean up the per-session working directory."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except OSError:
        pass


def preprocess_image(image: Image.Image) -> Image.Image:
    """Preprocess input image for 3D generation"""
    init_pipeline()
    processed_image = pipeline.preprocess_image(image)
    return processed_image


def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Pack Gaussian and mesh state for storage"""
    return {
        'gaussian': {
            **gs.init_params,
            '_xyz': gs._xyz.cpu().numpy(),
            '_features_dc': gs._features_dc.cpu().numpy(),
            '_scaling': gs._scaling.cpu().numpy(),
            '_rotation': gs._rotation.cpu().numpy(),
            '_opacity': gs._opacity.cpu().numpy(),
        },
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }


def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Unpack Gaussian and mesh state"""
    gs = Gaussian(
        aabb=state['gaussian']['aabb'],
        sh_degree=state['gaussian']['sh_degree'],
        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
        scaling_bias=state['gaussian']['scaling_bias'],
        opacity_bias=state['gaussian']['opacity_bias'],
        scaling_activation=state['gaussian']['scaling_activation'],
    )
    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )
    return gs, mesh
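
# pack_state/unpack_state round-trip the TRELLIS outputs through plain numpy arrays so the
# result can be stored in gr.State between events and moved back onto the GPU on demand.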


def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a random seed if randomization is enabled, otherwise the provided seed."""
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


def call_magic_articulate_api(obj_path: str, api_url: str = MAGIC_ARTICULATE_URL) -> Tuple[str, str, str]:
    """
    Call the MagicArticulate Colab API to generate rigging and a skeleton from an OBJ mesh.

    Args:
        obj_path: Path to the OBJ file.
        api_url: MagicArticulate Colab Gradio URL.

    Returns:
        Tuple of (rig_pred_path, skeleton_obj_path, info_text):
        - rig_pred_path: Path to the generated rig prediction TXT file.
        - skeleton_obj_path: Path to the generated skeleton OBJ file.
        - info_text: Summary of the rigging results.
    """
    try:
        print(f"🦴 Connecting to MagicArticulate API ({api_url})...")
        magic_client = Client(api_url)
        print("📤 Uploading OBJ to MagicArticulate...")
        result = magic_client.predict(
            input_mesh=handle_file(obj_path),
            api_name="/predict"
        )
        # MagicArticulate returns (rig_pred.txt, skeleton.obj, normalized_mesh.obj)
        rig_pred_file = result[0]
        skeleton_file = result[1]
        print("✅ MagicArticulate generation successful!")
        # Summarize the skeleton
        info_text = "Skeleton generated with hierarchical bone ordering"
        if skeleton_file and os.path.exists(skeleton_file):
            skeleton_mesh = trimesh.load(skeleton_file, force='mesh')
            num_vertices = len(skeleton_mesh.vertices)
            info_text = f"Joints: {num_vertices // 2}, hierarchical structure"
        return rig_pred_file, skeleton_file, info_text
    except AppError as e:
        error_msg = str(e)
        print(f"⚠️ MagicArticulate error: {error_msg}")
        raise
    except Exception as e:
        print(f"⚠️ MagicArticulate API error: {str(e)}")
        raise
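
# Minimal usage sketch for the helper above (the file name is just an example; it assumes
# the Colab share link is currently live):
#
#   rig_txt, skeleton_obj, info = call_magic_articulate_api("model.obj")
#   print(info)          # e.g. joint count / hierarchy summary
#   # rig_txt      -> joint positions & bone hierarchy (plain text)
#   # skeleton_obj -> skeleton geometry for visual inspection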


@spaces.GPU(duration=120)  # ZeroGPU: request a GPU for the duration of the full generation call
def generate_3d_with_rigging(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[dict, str, str, str, str, str, str, str]:
| """ | |
| Complete pipeline: Image -> 3D Model (TRELLIS) -> OBJ -> Rigging (MagicArticulate) | |
| """ | |
| try: | |
| user_dir = os.path.join(TMP_DIR, str(req.session_hash)) | |
| # ============ STEP 1: TRELLIS 3D GENERATION ============ | |
| print("🎨 Generating 3D model with TRELLIS...") | |
| init_pipeline() | |
| outputs = pipeline.run( | |
| image, | |
| seed=seed, | |
| formats=["gaussian", "mesh"], | |
| preprocess_image=False, | |
| sparse_structure_sampler_params={ | |
| "steps": ss_sampling_steps, | |
| "cfg_strength": ss_guidance_strength, | |
| }, | |
| slat_sampler_params={ | |
| "steps": slat_sampling_steps, | |
| "cfg_strength": slat_guidance_strength, | |
| }, | |
| ) | |
| # Extract Gaussian and Mesh | |
| gs = outputs['gaussian'][0] | |
| mesh = outputs['mesh'][0] | |
| # ============ STEP 2: RENDER VIDEO ============ | |
| print("📹 Rendering 360° preview video...") | |
| video = render_utils.render_video(gs, num_frames=120)['color'] | |
| video_geo = render_utils.render_video(mesh, num_frames=120)['normal'] | |
| video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))] | |
| video_path = os.path.join(user_dir, 'sample.mp4') | |
| imageio.mimsave(video_path, video, fps=15) | |
| # ============ STEP 3: EXTRACT GLB ============ | |
| print("🎁 Extracting GLB with textures...") | |
| glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False) | |
| glb_path = os.path.join(user_dir, 'sample.glb') | |
| glb.export(glb_path) | |
| # ============ STEP 4: CONVERT GLB TO OBJ ============ | |
| print("🔄 Converting GLB to OBJ format...") | |
| obj_path = os.path.join(user_dir, "model.obj") | |
| mesh_trimesh = trimesh.load(glb_path, force='mesh') | |
| original_vertices = len(mesh_trimesh.vertices) | |
| original_faces = len(mesh_trimesh.faces) | |
| mesh_trimesh.export(obj_path) | |
| mesh_info = f""" | |
| 📊 Mesh Statistics: | |
| • Vertices: {original_vertices:,} | |
| • Faces: {original_faces:,} | |
| • Texture Size: {texture_size}px | |
| • Status: ✓ Ready for rigging | |
| """ | |

        # ============ STEP 5: MAGIC ARTICULATE RIGGING ============
        print("🦴 Calling MagicArticulate API for automatic skeleton generation...")
        rig_info = ""
        rig_file = None
        skeleton_file = None
        try:
            # Call the MagicArticulate Colab API
            rig_result_path, skeleton_result_path, rig_info_text = call_magic_articulate_api(
                obj_path=obj_path,
                api_url=MAGIC_ARTICULATE_URL
            )
            if rig_result_path and os.path.exists(rig_result_path):
                # Copy the rig prediction file to the user directory
                rig_file = os.path.join(user_dir, 'rig_pred.txt')
                shutil.copy(rig_result_path, rig_file)
            if skeleton_result_path and os.path.exists(skeleton_result_path):
                # Copy the skeleton file to the user directory
                skeleton_file = os.path.join(user_dir, 'skeleton.obj')
                shutil.copy(skeleton_result_path, skeleton_file)
            rig_info = f"""✅ MagicArticulate Skeleton Generated:
{rig_info_text}
📥 Downloads:
• rig_pred.txt - Joint positions & bone hierarchy
• skeleton.obj - 3D skeleton visualization
🔧 Import into Blender/Maya for animation
"""
        except Exception as e:
            print(f"⚠️ MagicArticulate API error: {str(e)}")
            # Write an error file with instructions instead of the rig prediction
            rig_file = os.path.join(user_dir, 'rig_pred.txt')
            with open(rig_file, 'w') as f:
                f.write(f"MagicArticulate Error: {str(e)}\n\n")
                f.write("Workaround: Download OBJ and rig manually in Blender.")
            rig_info = f"⚠️ MagicArticulate API unavailable: {str(e)}\n\n**Solution:** Download OBJ and use Blender Rigify add-on"
            skeleton_file = None

        # ============ STEP 6: PACK RESULTS ============
        print("📦 Packaging results...")
        state = pack_state(gs, mesh)
        torch.cuda.empty_cache()
        combined_info = f"""
🎨 TRELLIS Generation:
• Seed: {seed}
• SS Guidance: {ss_guidance_strength}
• SS Steps: {ss_sampling_steps}
• SLAT Guidance: {slat_guidance_strength}
• SLAT Steps: {slat_sampling_steps}
{mesh_info}
{rig_info}
📥 Downloads Available:
✓ Video preview (360° rotation)
✓ GLB file (textured 3D model)
✓ OBJ file (standard 3D format)
✓ Rig prediction (TXT)
✓ Skeleton (OBJ)
🔧 Next Steps:
1. Download OBJ + Skeleton files
2. Import into Blender/Maya/C4D
3. Apply rigging from rig_pred.txt
4. Animate your model
💡 Pro Tips:
• Skeleton shows joint hierarchy visually
• Rig prediction contains exact joint coordinates
• Model is optimized for animation workflow
"""
        print("✅ All processing complete!")
        # glb_path is returned twice: once for the 3D viewer and once for the GLB download button
        return state, video_path, glb_path, glb_path, obj_path, rig_file, skeleton_file, combined_info
    except Exception as e:
        import traceback
        error_detail = traceback.format_exc()
        print(f"❌ Error: {str(e)}")
        print(error_detail)
        raise gr.Error(f"❌ Pipeline failed: {str(e)}\n\nDetails:\n{error_detail}")


@spaces.GPU  # ZeroGPU: request a GPU for the PLY export
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """Extract Gaussian splatting file from generated model"""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path


# ============ GRADIO UI ============
with gr.Blocks(title="Image to Rigged 3D Model (MagicArticulate)", delete_cache=(600, 600)) as demo:
    gr.Markdown("""
# 🎭 Image → 3D → Rigging (MagicArticulate Pipeline)

**Automated 3D generation with hierarchical skeleton rigging!**

This unified pipeline combines:
- **TRELLIS** (Image-to-3D, Microsoft Research)
- **MagicArticulate** (Auto-skeleton generation, CVPR 2025)

### 🚀 Workflow:
1. 📤 Upload image of object/character
2. 🎨 TRELLIS generates high-quality 3D mesh (GPU)
3. 🔄 Convert to OBJ format
4. 🦴 MagicArticulate generates hierarchical skeleton
5. 💾 Download mesh + rigging + skeleton for animation

### ✨ Benefits:
- ✅ Hierarchical bone ordering for better animation
- ✅ Automatic joint placement and bone connections
- ✅ Production-ready output for Blender/Maya
- ✅ Visual skeleton + rig data included

⏱️ **Estimated time:** 2-5 minutes
""")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input")
            input_image = gr.Image(
                label="Upload Image",
                format="png",
                image_mode="RGBA",
                type="pil",
                height=300
            )
            with gr.Accordion("⚙️ TRELLIS Parameters", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("**Stage 1: Sparse Structure**")
                ss_guidance = gr.Slider(
                    0.0, 10.0,
                    label="Guidance Strength",
                    value=7.5,
                    step=0.1
                )
                ss_steps = gr.Slider(
                    1, 50,
                    label="Sampling Steps",
                    value=12,
                    step=1
                )
                gr.Markdown("**Stage 2: Structured Latent**")
                slat_guidance = gr.Slider(
                    0.0, 10.0,
                    label="Guidance Strength",
                    value=3.0,
                    step=0.1
                )
                slat_steps = gr.Slider(
                    1, 50,
                    label="Sampling Steps",
                    value=12,
                    step=1
                )
            with gr.Accordion("⚙️ Output Settings", open=False):
                mesh_simplify = gr.Slider(
                    0.9, 0.98,
                    label="Mesh Simplification",
                    value=0.95,
                    step=0.01
                )
                texture_size = gr.Slider(
                    512, 2048,
                    label="Texture Size",
                    value=1024,
                    step=512
                )
            generate_btn = gr.Button(
                "🚀 Generate Rigged Model",
                variant="primary",
                size="lg"
            )
            extract_gs_btn = gr.Button(
                "📥 Extract Gaussian (PLY)",
                interactive=False
            )
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Outputs")
            with gr.Tabs():
                with gr.Tab("📹 Preview"):
                    video_output = gr.Video(
                        label="360° Preview",
                        autoplay=True,
                        loop=True,
                        height=300
                    )
                with gr.Tab("🎨 3D Viewer"):
                    model_output = gr.Model3D(
                        label="GLB Viewer",
                        height=400
                    )
                with gr.Tab("📦 Files"):
                    glb_download = gr.DownloadButton(
                        label="📥 Download GLB",
                        interactive=False
                    )
                    obj_download = gr.DownloadButton(
                        label="📥 Download OBJ",
                        interactive=False
                    )
                    rig_download = gr.DownloadButton(
                        label="🦴 Download Rig Prediction (TXT)",
                        interactive=False
                    )
                    skeleton_download = gr.DownloadButton(
                        label="🦴 Download Skeleton (OBJ)",
                        interactive=False
                    )
                    gs_download = gr.DownloadButton(
                        label="✨ Download Gaussian (PLY)",
                        interactive=False
                    )
                with gr.Tab("ℹ️ Info"):
                    info_output = gr.Textbox(
                        label="Pipeline Information",
                        lines=20,
                        max_lines=30
                    )

    # State management
    output_buf = gr.State()

    # Event handlers
    demo.load(start_session)
    demo.unload(end_session)

    input_image.upload(
        preprocess_image,
        inputs=[input_image],
        outputs=[input_image],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        generate_3d_with_rigging,
        inputs=[
            input_image, seed,
            ss_guidance, ss_steps,
            slat_guidance, slat_steps,
            mesh_simplify, texture_size
        ],
        # glb_download is listed here so the generated GLB file is actually bound to its download button
        outputs=[output_buf, video_output, model_output, glb_download, obj_download, rig_download, skeleton_download, info_output],
    ).then(
        lambda: (
            gr.Button(interactive=True),
            gr.DownloadButton(interactive=True),
            gr.DownloadButton(interactive=True),
            gr.DownloadButton(interactive=True),
            gr.DownloadButton(interactive=True),
        ),
        outputs=[extract_gs_btn, glb_download, obj_download, rig_download, skeleton_download],
    )

    video_output.clear(
        lambda: (
            gr.Button(interactive=False),
            gr.DownloadButton(interactive=False),
            gr.DownloadButton(interactive=False),
            gr.DownloadButton(interactive=False),
            gr.DownloadButton(interactive=False),
        ),
        outputs=[extract_gs_btn, glb_download, obj_download, rig_download, skeleton_download],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, gs_download],
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[gs_download],
    )

if __name__ == "__main__":
    init_pipeline()
    demo.launch()
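
# For quick smoke-testing, the running app can itself be driven with gradio_client.
# The endpoint names are auto-generated (the .click() handlers above do not set api_name),
# so this is only a sketch; call Client(...).view_api() to list the actual endpoints.
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")  # default local URL from demo.launch()
#   print(client.view_api())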