| import gradio as gr |
| import torch |
| import io |
| from PIL import Image |
| import numpy as np |
| import spaces |
| import math |
| import re |
| from einops import rearrange |
| from mmengine.config import Config |
| from src.builder import BUILDER |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| from scripts.camera.cam_dataset import Cam_Generator |
| from scripts.camera.visualization.visualize_batch import make_perspective_figures |
|
|
| from huggingface_hub import snapshot_download |
| import os |
|
|
# Regex fragment for a single number: optional sign, then either digits with
# an optional fractional part or a bare ".5"-style fraction, then an optional
# exponent.
NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"

# Matches three comma-separated numbers that follow either a
# "camera parameters ...:" or a "roll ...:" lead-in. The three capture groups
# hold the numbers in order of appearance. Matching is case-insensitive, and
# DOTALL lets the lead-in span line breaks.
CAM_PATTERN = re.compile(
    r"(?:camera parameters.*?:|roll.*?:)\s*(" + NUM + r")\s*,\s*(" + NUM + r")\s*,\s*(" + NUM + r")",
    re.IGNORECASE | re.DOTALL,
)
|
|
def center_crop(image):
    """Return the largest centered square region of ``image``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, upper, right, lower))`` method (a PIL image in
            this file).

    Returns:
        The result of ``image.crop`` for the centered square whose side is
        ``min(width, height)``.
    """
    width, height = image.size
    side = min(width, height)
    # Floor division keeps the box on whole pixels; when the margin is odd the
    # extra pixel ends up on the right/bottom side.
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Model construction and checkpoint loading. Runs once, at import time.
# ---------------------------------------------------------------------------
config = "configs/pipelines/stage_2_base.py"
config = Config.fromfile(config)
model = BUILDER.build(config.model).eval()

# Fetch only the base checkpoint file into ./checkpoints.
_ = snapshot_download(
    repo_id="KangLiao/Puffin",
    repo_type="model",
    allow_patterns="Puffin-Base.pth",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)
# NOTE(review): torch.load unpickles the downloaded file. Without
# weights_only=True a tampered checkpoint could execute arbitrary code;
# confirm the checkpoints are plain state dicts and consider enabling it.
# strict=False means missing/unexpected keys are silently ignored here.
_ = model.load_state_dict(torch.load("checkpoints/Puffin-Base.pth", map_location='cpu'), strict=False)
# The weights now live in memory; the on-disk copy is deleted
# (presumably to save disk space on the host -- confirm).
os.remove("checkpoints/Puffin-Base.pth")

# No allow_patterns is given, so every file in this repository is downloaded
# even though only vae.pth is read below.
# NOTE(review): restrict the download if nothing else from the repo is needed.
_ = snapshot_download(
    repo_id="wusize/Puffin",
    repo_type="model",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)
# The VAE weights must match exactly (strict=True). Same unpickling caveat as
# for the base checkpoint above.
_ = model.vae.load_state_dict(torch.load('checkpoints/vae.pth', map_location='cpu'), strict=True)
os.remove('checkpoints/vae.pth')

# Use bfloat16 on GPU; stay in float32 when only a CPU is available.
if torch.cuda.is_available():
    model = model.to(torch.bfloat16).cuda()
else:
    model = model.to(torch.float32)
|
|
|
|
def fig_to_image(fig):
    """Rasterize a figure to an RGB PIL image via an in-memory PNG.

    Args:
        fig: Object with a Matplotlib-style ``savefig`` method.

    Returns:
        The figure rendered as a PIL image in ``RGB`` mode, saved with a
        tight bounding box and no padding.
    """
    # The context manager closes the buffer even if saving or decoding fails.
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        # convert() produces a new, fully decoded image, so the result stays
        # usable after the buffer is closed.
        return Image.open(buf).convert('RGB')
|
|
def extract_up_lat_figs(fig_dict):
    """Split a figure mapping into up-field, latitude-field and the rest.

    The first entry whose key contains ``"up_field"`` and the first entry
    whose key contains ``"latitude_field"`` are picked out; every other
    entry (including later duplicates of either kind) is collected unchanged.

    Args:
        fig_dict: Mapping from string keys to figures.

    Returns:
        A tuple ``(fig_up, fig_lat, others)``. ``fig_up`` / ``fig_lat`` are
        ``None`` when no matching key exists; ``others`` is a dict of the
        remaining entries in their original order.
    """
    up_figure = None
    lat_figure = None
    remaining = {}
    for key, figure in fig_dict.items():
        if up_figure is None and "up_field" in key:
            up_figure = figure
            continue
        if lat_figure is None and "latitude_field" in key:
            lat_figure = figure
            continue
        remaining[key] = figure
    return up_figure, lat_figure, remaining
|
|
|
|
def _render_camera_maps(image, text):
    """Render the up-field and latitude-field figures for a model response.

    Args:
        image: The 512x512 RGB PIL image that was fed to the model.
        text: Model response from which ``Cam_Generator`` derives the camera
            maps.

    Returns:
        ``(up_img, lat_img)`` as PIL images; either is ``None`` when
        ``make_perspective_figures`` returns no figure with a matching key.
    """
    gen = Cam_Generator(mode="base")
    cam = gen.get_cam(text)

    # Float RGB in [0, 1], laid out as a (1, C, H, W) tensor.
    rgb = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    # The first two channels of `cam` are used as the up field, the rest as
    # the latitude field.
    single_batch = {
        "image": image_tensor,
        "up_field": cam[:2].unsqueeze(0),
        "latitude_field": cam[2:].unsqueeze(0),
    }

    figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
    up_img = lat_img = None
    for key, fig in figs.items():
        if "up_field" in key:
            up_img = fig_to_image(fig)
        elif "latitude_field" in key:
            lat_img = fig_to_image(fig)
        # Close every figure, matched or not, so pyplot does not accumulate
        # open figures across requests.
        plt.close(fig)
    return up_img, lat_img


@torch.inference_mode()
@spaces.GPU(duration=120)
def camera_understanding(image_src, question=None, seed=42, progress=gr.Progress(track_tqdm=True)):
    """Describe an image and estimate its camera parameters.

    The model is always queried with a fixed prompt asking for a description
    plus roll, pitch and field-of-view.

    Args:
        image_src: Image array accepted by ``Image.fromarray`` (what a
            ``gr.Image`` input delivers).
        question: Currently unused; the prompt is fixed. Defaults to ``None``
            so the handler also works when the UI wires fewer inputs than
            the signature lists.
        seed: Currently unused. Has a default for the same reason.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        The first response string returned by ``model.understand``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_src is None:
        # Without this guard a click with no image failed inside
        # Image.fromarray with an opaque traceback.
        raise gr.Error("Please upload an image first.")

    torch.cuda.empty_cache()

    print(torch.cuda.is_available())

    prompt = ("Describe the image in detail. Then reason its spatial distribution and estimate its camera parameters (roll, pitch, and field-of-view).")

    # Center-crop to a square, resize to 512x512, then map pixel values from
    # [0, 255] to [-1, 1] in channel-first layout.
    image = Image.fromarray(image_src).convert('RGB')
    image = center_crop(image)
    image = image.resize((512, 512))
    x = torch.from_numpy(np.array(image)).float()
    x = x / 255.0
    x = 2 * x - 1
    x = rearrange(x, 'h w c -> c h w')

    with torch.no_grad():
        outputs = model.understand(prompt=[prompt], pixel_values=[x], progress_bar=False)

    text = outputs[0]

    # NOTE(review): the rendered overlays are not part of the return value, so
    # this work is currently discarded. Either expose the images in the UI or
    # drop the call.
    _render_camera_maps(image, text)

    return text
|
|
|
|
@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt_scene,
                   seed=42,
                   roll=0.1,
                   pitch=0.1,
                   fov=1.0,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate a batch of images for a scene prompt and camera setting.

    Args:
        prompt_scene: Text description of the scene. ``None`` is treated as
            an empty string.
        seed: Seed for torch, CUDA and NumPy RNGs. ``None`` (an emptied seed
            field) draws a random seed; other values are reduced modulo
            ``2**32`` because ``np.random.seed`` rejects anything outside
            ``[0, 2**32)``.
        roll: Camera roll, formatted into the prompt with four decimals.
        pitch: Camera pitch, formatted the same way.
        fov: Camera field of view, formatted the same way.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        A list of four 512x512 PIL images.
    """
    torch.cuda.empty_cache()

    if seed is None:
        # The UI labels the seed as optional, so an empty field must not crash.
        seed = int.from_bytes(os.urandom(4), "big")
    seed = int(seed) % (2 ** 32)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    print(torch.cuda.is_available())

    generator = torch.Generator().manual_seed(seed)
    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    # Scale the camera map by 1 / (pi / 2) before conditioning the model.
    cam_map = cam_map / (math.pi / 2)

    prompt = (prompt_scene or "") + " " + prompt_camera
    print("prompt:", prompt)

    bsz = 4
    with torch.no_grad():
        images, output_reasoning = model.generate(
            prompt=[prompt] * bsz,
            cfg_prompt=[""] * bsz,
            pixel_values_init=None,
            cfg_scale=4.5,
            num_steps=50,
            cam_values=[[cam_map]] * bsz,
            progress_bar=False,
            reasoning=False,
            prompt_reasoning=[""] * bsz,
            generator=generator,
            height=512,
            width=512,
        )

    images = rearrange(images, 'b c h w -> b h w c')
    # Do the [-1, 1] -> [0, 255] scaling in float32: the model may run in
    # bfloat16, whose short mantissa would quantize the result before the
    # uint8 cast.
    images = torch.clamp(127.5 * images.float() + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    ret_images = [Image.fromarray(image) for image in images]
    return ret_images
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: one tab for camera-controlled generation, one for camera
# understanding.
# NOTE(review): the nesting below was reconstructed from text that had lost
# its indentation; confirm the layout (in particular where the seed field
# sits) against the intended design.
# ---------------------------------------------------------------------------
css = '''
.gradio-container {max-width: 960px !important}
'''
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")

        prompt_input = gr.Textbox(label="Prompt.")

        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll value")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch value")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov value")
            seed_input = gr.Number(label="Seed (Optional)", precision=0, value=42)

        generation_button = gr.Button("Generate Images")

        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)

        examples_t2i = gr.Examples(
            label="Prompt examples.",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn, light-yellow walls room with herringbone terracotta floors and three large arched windows framed in pink trim and white panes, showcasing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Camera Understanding"):
        gr.Markdown(value="## Camera Understanding")
        image_input = gr.Image()

        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")

        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)

        # camera_understanding takes (image_src, question, seed) but the UI
        # has no question widget. A constant None state fills that slot so
        # the seed value binds to `seed` rather than `question`.
        und_question_state = gr.State(value=None)

        examples_inpainting = gr.Examples(
            label="Camera Understanding examples",
            examples=[
                "assets/1.jpg",
                "assets/2.jpg",
                "assets/3.jpg",
                "assets/4.jpg",
                "assets/5.jpg",
                "assets/6.jpg",
            ],
            inputs=image_input,
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output,
    )

    understanding_button.click(
        camera_understanding,
        inputs=[image_input, und_question_state, und_seed_input],
        outputs=[understanding_output],
    )


demo.launch(share=True)