# Page banner captured together with the source: "Spaces: Running"
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
class DynamicsVisualizer:
    """Gradio demo that browses precomputed dynamics-prediction results.

    The intended click order is: pick an object (loads four camera images and
    the original splat file), pick an action (loads the action images, the
    four predicted videos and the predicted splat file), then press Run to
    show the prediction.  Everything is read from ``data/<example>/`` and
    ``data/<example>/action-<k>/``.

    NOTE(review): all selection state lives on this one instance and its
    bound methods are registered directly as callbacks, so simultaneous
    visitors would appear to share a single selection -- confirm whether
    per-session ``gr.State`` is wanted.
    """

    _IMAGE_LABEL = 'Initial state and actions'
    _VIDEO_LABEL = 'Predicted video'
    _GS_ORIG_LABEL = 'Original Gaussian Splats'
    _GS_PRED_LABEL = 'Predicted Gaussian Splats'

    # Number of camera views stored per example/action (img_0..3, video_0..3).
    _NUM_CAMERAS = 4

    # Camera index sent by the buttons labelled 'View 0'..'View 3', in that
    # order.  The shift by one is kept exactly as originally wired; it makes
    # 'View 0' the default camera (1).  Presumably a deliberate reordering of
    # the physical cameras -- TODO confirm.
    _VIEW_BUTTON_CAMERAS = (1, 2, 3, 0)

    def __init__(self):
        self.device = torch.device("cpu")  # stored only; unused in this class
        self.width = 640
        self.height = 480
        self._reset_state()

    # ------------------------------------------------------------------ state

    def _clear_prediction_state(self):
        """Forget everything that was loaded for a previously chosen action."""
        self.actions = None
        self.videos = None
        self.gs_pred = None
        self.action_name = None
        self.form_video_is_set = False
        self.form_3dgs_pred_is_set = False

    def _reset_state(self):
        """Put every piece of selection state back to its start-up value."""
        self.vis_cam_id = 1
        # 0: black, 1: white.  NOTE: the button that toggled this value is
        # currently disabled, so it stays 0 unless set from outside.
        self.bg_id = 0
        self.imgs = None
        self.gs_orig = None
        self.example_name = None
        self.form_image_is_set = False
        self.form_3dgs_orig_is_set = False
        self._clear_prediction_state()

    # ------------------------------------------------------------ components

    def _image_form(self, value=None):
        """Build the image component showing ``value`` (``None`` = empty)."""
        return gr.Image(label=self._IMAGE_LABEL, value=value, width=self.width, height=self.height)

    def _video_form(self, value=None):
        """Build the video component showing ``value`` (``None`` = empty)."""
        return gr.Video(label=self._VIDEO_LABEL, value=value, width=self.width, height=self.height)

    def _splat_form(self, label, value=None):
        """Build a 3D viewer; the background colour is only passed with a model."""
        if value is None:
            return gr.Model3D(label=label, value=None)
        background = [self.bg_id, self.bg_id, self.bg_id, 0]
        return gr.Model3D(label=label, value=value, clear_color=background)

    # ---------------------------------------------------------------- loading

    def _open_camera_images(self, directory):
        """Open ``img_0.png`` .. ``img_3.png`` from ``directory``."""
        return [
            Image.open(os.path.join(directory, f'img_{i}.png'))
            for i in range(self._NUM_CAMERAS)
        ]

    def load_example(self):
        """Load the images and original splat path of ``self.example_name``."""
        example_path = os.path.join('data', self.example_name)
        self.imgs = self._open_camera_images(example_path)
        self.gs_orig = os.path.join(example_path, 'gs_orig.splat')

    def load_action(self):
        """Load images, video paths and predicted splat path of ``self.action_name``."""
        action_path = os.path.join('data', self.action_name)
        self.imgs = self._open_camera_images(action_path)
        self.videos = [
            os.path.join(action_path, f'video_{i}.mp4')
            for i in range(self._NUM_CAMERAS)
        ]
        self.gs_pred = os.path.join(action_path, 'gs_pred.splat')

    # --------------------------------------------------------------- handlers

    def reset(self):
        """Clear all state and return four empty output components.

        Returns:
            ``(image, video, original splats, predicted splats)``.
        """
        self._reset_state()
        return (
            self._image_form(),
            self._video_form(),
            self._splat_form(self._GS_ORIG_LABEL),
            self._splat_form(self._GS_PRED_LABEL),
        )

    def on_click_set_example(self, state):
        """Select the object ``state['example_id']`` and show its initial state.

        Returns:
            ``(image, video, original splats, predicted splats)``; the video
            and predicted-splat components are emptied.
        """
        self.example_name = f"{int(state['example_id'])}"
        # Drop the prediction of any previously selected object; otherwise a
        # later Run would display videos/splats that belong to the old object.
        self._clear_prediction_state()
        self.load_example()
        self.form_image_is_set = True
        self.form_3dgs_orig_is_set = True
        return (
            self._image_form(self.imgs[self.vis_cam_id]),
            self._video_form(),
            self._splat_form(self._GS_ORIG_LABEL, self.gs_orig),
            self._splat_form(self._GS_PRED_LABEL),
        )

    def on_click_set_action(self, state):
        """Select action ``state['action_id']`` of the current object.

        Returns:
            The image component showing the action for the current camera.

        Raises:
            gr.Error: if no object has been selected yet.
        """
        if self.example_name is None:
            # Without this the path would become 'data/None/action-k'.
            raise gr.Error('Please select an object (Step 1) before choosing an action.')
        self.action_name = f"{self.example_name}/action-{int(state['action_id'])}"
        self.load_action()
        self.form_image_is_set = True
        return self._image_form(self.imgs[self.vis_cam_id])

    def on_click_run(self):
        """Show the predicted video and predicted splats of the chosen action.

        Returns:
            ``(video, predicted splats)``.

        Raises:
            gr.Error: if no action has been selected yet.
        """
        if self.videos is None or self.gs_pred is None:
            raise gr.Error('Please select an action (Step 2) before clicking Run.')
        self.form_video_is_set = True
        self.form_3dgs_pred_is_set = True
        return (
            self._video_form(self.videos[self.vis_cam_id]),
            self._splat_form(self._GS_PRED_LABEL, self.gs_pred),
        )

    def on_click_change_view(self, state):
        """Switch to camera ``state['view_id']`` and refresh image and video.

        The camera choice is remembered even when nothing is loaded yet; in
        that case both returned components are empty.

        Returns:
            ``(image, video)``.
        """
        self.vis_cam_id = int(state['view_id'])
        image_value = None if self.imgs is None else self.imgs[self.vis_cam_id]
        # The video is only shown once Run has been pressed.
        show_video = self.form_video_is_set and self.videos is not None
        video_value = self.videos[self.vis_cam_id] if show_video else None
        return self._image_form(image_value), self._video_form(video_value)

    # ----------------------------------------------------------------- layout

    @staticmethod
    def _spacer_rows(count):
        """Add ``count`` rows holding an empty Markdown block."""
        for _ in range(count):
            with gr.Row():
                gr.Markdown()

    @staticmethod
    def _step_row(instruction, button_labels):
        """Add a row with an instruction column followed by one column per button.

        Returns:
            The created buttons, in the order of ``button_labels``.
        """
        buttons = []
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(instruction)
            for label in button_labels:
                with gr.Column(scale=1):
                    buttons.append(gr.Button(label))
        return buttons

    def launch(self, share=False):
        """Build the page, wire up the buttons and start the Gradio server.

        Args:
            share: forwarded unchanged to ``gr.Blocks.launch``.
        """
        with gr.Blocks() as app:
            with gr.Row():
                gr.Markdown("# Dynamic 3D Gaussian Tracking for Graph-Based Neural Dynamics Modeling")
            with gr.Row():
                gr.Markdown('Project page: [https://gs-dynamics.github.io/](https://gs-dynamics.github.io/)')
            self._spacer_rows(2)

            (run_reset,) = self._step_row(
                "**Step 0**: click **Clear All** to clear all window and reset the visualizer.",
                ['Clear All'],
            )
            example_buttons = self._step_row(
                "**Step 1**: select the object.",
                ['Rope', 'Rope - Long', 'Toy Animal'],
            )
            action_buttons = self._step_row(
                "**Step 2**: select the action.",
                ['Action 1', 'Action 2', 'Action 3'],
            )
            (run_run,) = self._step_row(
                "**Step 3**: click **Run** to visualize the predicted video and Splats.",
                ['Run'],
            )

            with gr.Row():
                # Narrow left column: view selector.
                with gr.Column(scale=1, min_width=20):
                    self._spacer_rows(4)
                    with gr.Row():
                        gr.Markdown("Our model uses only 4 cameras for reconstructing the Gaussian Splats. Click the buttons below to change the view.")
                    view_buttons = []
                    for index in range(len(self._VIEW_BUTTON_CAMERAS)):
                        with gr.Row():
                            view_buttons.append(gr.Button(f'View {index}'))
                # Wide right column: 2x2 grid of outputs.
                with gr.Column(scale=4):
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_image = self._image_form()
                        with gr.Column(scale=2):
                            form_video = self._video_form()
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_3dgs_orig = self._splat_form(self._GS_ORIG_LABEL)
                        with gr.Column(scale=2):
                            form_3dgs_pred = self._splat_form(self._GS_PRED_LABEL)

            with gr.Row():
                gr.Markdown("## Notes:")
            with gr.Row():
                gr.Markdown("- Due to the computation constraints of Hugging Face Space, all results are precomputed. ")
            with gr.Row():
                gr.Markdown("- Training a GS for an object takes around 30 seconds. Prediction typically takes only 1-2 seconds for each push!")
            with gr.Row():
                gr.Markdown("- More examples may be added in the future. Stay tuned!")

            # Set up callbacks.  Each button passes its fixed id to the shared
            # handler through a constant gr.State input.
            all_forms = [form_image, form_video, form_3dgs_orig, form_3dgs_pred]
            run_reset.click(self.reset, inputs=[], outputs=all_forms)
            for example_id, button in enumerate(example_buttons):
                button.click(
                    self.on_click_set_example,
                    inputs=[gr.State({'example_id': example_id})],
                    outputs=all_forms,
                )
            for action_id, button in enumerate(action_buttons):
                button.click(
                    self.on_click_set_action,
                    inputs=[gr.State({'action_id': action_id})],
                    outputs=[form_image],
                )
            run_run.click(self.on_click_run, inputs=[], outputs=[form_video, form_3dgs_pred])
            for view_id, button in zip(self._VIEW_BUTTON_CAMERAS, view_buttons):
                button.click(
                    self.on_click_change_view,
                    inputs=[gr.State({'view_id': view_id})],
                    outputs=[form_image, form_video],
                )
        app.launch(share=share)
if __name__ == '__main__':
    # Script entry point: build the visualizer and serve it with share=True.
    DynamicsVisualizer().launch(share=True)