import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time
import timm
import cv2
from datetime import datetime

print(f"Timm version: {timm.__version__}")

# Fetch the pretrained DepthPro checkpoint if it is not already present.
subprocess.run(["bash", "get_pretrained_models.sh"])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the DepthPro model and its preprocessing transform, then move the model
# to the selected device for inference.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()


def resize_image(image_path, max_size=1024):
    """
    Resize the input image so its longest side equals max_size.
    Maintains the aspect ratio and saves the resized image as a temporary PNG file.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.
    """
    with Image.open(image_path) as img:
        # Scale both sides by the same ratio so the aspect ratio is preserved.
        ratio = max_size / max(img.size)
        new_size = tuple(int(x * ratio) for x in img.size)
        img = img.resize(new_size, Image.LANCZOS)

        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name
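
# Note: because the scale ratio is applied unconditionally, inputs smaller than
# max_size are upscaled as well; the caller is responsible for deleting the
# temporary file (see predict_depth).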


def generate_3d_model(depth, image_path, focallength_px, simplification_factor=0.8, smoothing_iterations=1, thin_threshold=0.01):
    """
    Generate a textured 3D mesh from the depth map and the original image.
    """
    image = np.array(Image.open(image_path))

    if isinstance(depth, torch.Tensor):
        depth = depth.cpu().numpy()

    # Resize the depth map to match the image resolution if they differ.
    if depth.shape != image.shape[:2]:
        depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)

    height, width = depth.shape

    print(f"3D model generation - Depth shape: {depth.shape}")
    print(f"3D model generation - Image shape: {image.shape}")

    # Camera intrinsics: a single focal length for both axes, principal point
    # at the image center.
    fx = fy = float(focallength_px)
    cx, cy = width / 2, height / 2
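
    # Back-project every pixel (u, v) with depth Z to a 3D point via the pinhole
    # model: X = (u - cx) * Z / fx and Y = (v - cy) * Z / fy. This assumes square
    # pixels, no lens distortion, and depth measured along the camera's Z axis.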
    u = np.arange(0, width)
    v = np.arange(0, height)
    uu, vv = np.meshgrid(u, v)

    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy

    vertices = np.vstack((X, Y, Z)).T

    # Per-vertex colors come straight from the RGB image, normalized to [0, 1].
    colors = image.reshape(-1, 3) / 255.0
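
    # Triangulate the pixel grid: each 2x2 block of adjacent pixels yields two
    # triangles, producing a connected height-field mesh over the whole image.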
    faces = []
    for i in range(height - 1):
        for j in range(width - 1):
            idx = i * width + j
            faces.append([idx, idx + width, idx + 1])
            faces.append([idx + 1, idx + width, idx + width + 1])
    faces = np.array(faces)

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

    print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # Decimate the mesh down to a fraction of its original face count.
    target_faces = int(len(mesh.faces) * simplification_factor)
    mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
    print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # Keep only the largest connected component to drop floating fragments.
    components = mesh.split(only_watertight=False)
    if len(components) > 1:
        areas = np.array([c.area for c in components])
        mesh = components[np.argmax(areas)]
    print("After removing small components - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    for _ in range(smoothing_iterations):
        mesh = mesh.smoothed()
    print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
    print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # Export two copies: one for the in-browser viewer, one for download.
    timestamp = int(time.time())
    view_model_path = f'view_model_{timestamp}.obj'
    download_model_path = f'download_model_{timestamp}.obj'
    mesh.export(view_model_path)
    mesh.export(download_model_path)
    return view_model_path, download_model_path
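
# Illustrative only: given an (H, W) depth array in meters, an RGB image on disk,
# and a focal length in pixels, a call looks like
#   view_path, dl_path = generate_3d_model(depth, "input.png", focallength_px=1500.0)
# (the path and focal length here are made-up values).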


def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Remove thin features from the mesh by collapsing very short edges.
    """
    edges = mesh.edges_unique
    edge_points = mesh.vertices[edges]
    edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)

    # Edges shorter than the threshold mark thin features.
    short_edges = edges[edge_lengths < thickness_threshold]

    for edge in short_edges:
        try:
            mesh.collapse_edge(edge)
        except Exception:
            # Best-effort: skip edges that can no longer be collapsed.
            pass

    mesh.remove_degenerate_faces()
    return mesh
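
# Caveat: Trimesh.collapse_edge is not guaranteed to exist in every trimesh
# release; where it is missing, the except branch above swallows the
# AttributeError and this function only removes degenerate faces.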


def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
    # Reload the raw depth map saved by predict_depth and rebuild the mesh with
    # the user-selected parameters.
    depth = np.loadtxt(depth_csv, delimiter=',')
    view_model_path, download_model_path = generate_3d_model(
        depth, image_path, focallength_px,
        simplification_factor, smoothing_iterations, thin_threshold
    )
    return view_model_path, download_model_path
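
# The CSV consumed here is the one predict_depth writes with np.savetxt, so the
# 2D shape and values of the depth map survive the round trip.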


@spaces.GPU(duration=20)
def predict_depth(input_image):
    temp_file = None
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")

        # Downscale the input before inference. The temp file is kept on success
        # because its path is handed back as state for "Regenerate 3D Model".
        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")

        result = depth_pro.load_rgb(temp_file)
        if len(result) < 2:
            raise ValueError(f"Unexpected result from load_rgb: {result}")

        image = result[0]
        f_px = result[-1]
        print(f"Extracted focal length: {f_px}")

        image = transform(image).to(device)

        # Run DepthPro; the prediction contains the metric depth map and the
        # focal length used by the model.
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]
        focallength_px = prediction["focallength_px"]

        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()
        if depth.ndim != 2:
            depth = depth.squeeze()
        print(f"Depth map shape: {depth.shape}")

        # Render the depth map as a color-mapped PNG for display.
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
        plt.axis('off')
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        # Save the raw depth values so they can be downloaded and later reloaded
        # by regenerate_3d_model.
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')

        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path, temp_file, focallength_px
    except Exception as e:
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)
        # Clean up the temp file only on failure; on success it is still needed
        # by the regeneration path.
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
        return None, error_message, None, None, None, None, None


def get_last_commit_timestamp():
    try:
        timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip()
        return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        print(f"Failed to read last commit timestamp: {e}")
        return str(e)


last_updated = get_last_commit_timestamp()


with gr.Blocks() as iface:
    gr.Markdown("# DepthPro Demo with 3D Visualization")
    gr.Markdown(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
        "**Instructions:**\n"
        "1. Upload an image.\n"
        "2. The app will predict the depth map, display it, and provide the focal length.\n"
        "3. Download the raw depth data as a CSV file.\n"
        "4. View the generated 3D model textured with the original image.\n"
        "5. Adjust parameters and click 'Regenerate 3D Model' to update the model.\n"
        "6. Download the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    )

    with gr.Row():
        input_image = gr.Image(type="filepath", label="Input Image")
        depth_map = gr.Image(type="filepath", label="Depth Map")

    focal_length = gr.Textbox(label="Focal Length")
    raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")

    with gr.Row():
        view_3d_model = gr.Model3D(label="View 3D Model")
        download_3d_model = gr.File(label="Download 3D Model (OBJ)")

    with gr.Row():
        simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Simplification Factor")
        smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=1, step=1, label="Smoothing Iterations")
        thin_threshold = gr.Slider(minimum=0.001, maximum=0.1, value=0.01, step=0.001, label="Thin Feature Threshold")

    regenerate_button = gr.Button("Regenerate 3D Model")

    # Hidden state carries the resized image path and focal length from the
    # initial prediction to later mesh regeneration.
    hidden_depth_csv = gr.State()
    hidden_image_path = gr.State()
    hidden_focal_length = gr.State()

    input_image.change(
        predict_depth,
        inputs=[input_image],
        outputs=[depth_map, focal_length, raw_depth_csv, view_3d_model, download_3d_model, hidden_image_path, hidden_focal_length]
    )

    regenerate_button.click(
        regenerate_3d_model,
        inputs=[raw_depth_csv, hidden_image_path, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
        outputs=[view_3d_model, download_3d_model]
    )


iface.launch(share=True)