Spaces: Running on Zero
xinjie.wang committed
Commit · b800513
1 Parent(s): 0533bc0
update
Browse files
- app.py +15 -20
- app_style.py +1 -1
- common.py +59 -2
- embodied_gen/data/asset_converter.py +492 -186
- embodied_gen/data/backproject.py +1 -22
- embodied_gen/data/backproject_v2.py +175 -25
- embodied_gen/data/backproject_v3.py +557 -0
- embodied_gen/data/convex_decomposer.py +77 -2
- embodied_gen/data/differentiable_render.py +66 -31
- embodied_gen/data/mesh_operator.py +0 -4
- embodied_gen/data/utils.py +34 -3
- embodied_gen/envs/pick_embodiedgen.py +195 -0
- embodied_gen/models/delight_model.py +1 -1
- embodied_gen/models/gs_model.py +21 -2
- embodied_gen/models/image_comm_model.py +138 -0
- embodied_gen/models/layout.py +82 -0
- embodied_gen/models/segment_model.py +137 -8
- embodied_gen/models/sr_model.py +62 -1
- embodied_gen/models/text_model.py +59 -0
- embodied_gen/models/texture_model.py +50 -0
- embodied_gen/scripts/render_gs.py +1 -16
- embodied_gen/trainer/pono2mesh_trainer.py +141 -8
- embodied_gen/utils/enum.py +137 -0
- embodied_gen/utils/geometry.py +133 -23
- embodied_gen/utils/gpt_clients.py +41 -11
- embodied_gen/utils/process_media.py +160 -11
- embodied_gen/utils/simulation.py +141 -24
- embodied_gen/utils/tags.py +1 -1
- embodied_gen/validators/aesthetic_predictor.py +14 -6
- embodied_gen/validators/quality_checkers.py +70 -1
- embodied_gen/validators/urdf_convertor.py +73 -5
- thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py +1 -1
app.py
CHANGED
|
@@ -27,7 +27,7 @@ from common import (
|
|
| 27 |
VERSION,
|
| 28 |
active_btn_by_content,
|
| 29 |
end_session,
|
| 30 |
-
|
| 31 |
extract_urdf,
|
| 32 |
get_seed,
|
| 33 |
image_to_3d,
|
|
@@ -179,17 +179,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 179 |
)
|
| 180 |
|
| 181 |
generate_btn = gr.Button(
|
| 182 |
-
"🚀 1. Generate(~
|
| 183 |
variant="primary",
|
| 184 |
interactive=False,
|
| 185 |
)
|
| 186 |
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
| 187 |
-
with gr.Row():
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
with gr.Accordion(
|
| 194 |
label="Enter Asset Attributes(optional)", open=False
|
| 195 |
):
|
|
@@ -207,7 +207,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 207 |
)
|
| 208 |
with gr.Row():
|
| 209 |
extract_urdf_btn = gr.Button(
|
| 210 |
-
"🧩
|
| 211 |
variant="primary",
|
| 212 |
interactive=False,
|
| 213 |
)
|
|
@@ -230,7 +230,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 230 |
)
|
| 231 |
with gr.Row():
|
| 232 |
download_urdf = gr.DownloadButton(
|
| 233 |
-
label="⬇️
|
| 234 |
variant="primary",
|
| 235 |
interactive=False,
|
| 236 |
)
|
|
@@ -326,7 +326,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 326 |
image_prompt.change(
|
| 327 |
lambda: tuple(
|
| 328 |
[
|
| 329 |
-
gr.Button(interactive=False),
|
| 330 |
gr.Button(interactive=False),
|
| 331 |
gr.Button(interactive=False),
|
| 332 |
None,
|
|
@@ -344,7 +344,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 344 |
]
|
| 345 |
),
|
| 346 |
outputs=[
|
| 347 |
-
extract_rep3d_btn,
|
| 348 |
extract_urdf_btn,
|
| 349 |
download_urdf,
|
| 350 |
model_output_gs,
|
|
@@ -375,7 +375,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 375 |
image_prompt_sam.change(
|
| 376 |
lambda: tuple(
|
| 377 |
[
|
| 378 |
-
gr.Button(interactive=False),
|
| 379 |
gr.Button(interactive=False),
|
| 380 |
gr.Button(interactive=False),
|
| 381 |
None,
|
|
@@ -394,7 +394,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 394 |
]
|
| 395 |
),
|
| 396 |
outputs=[
|
| 397 |
-
extract_rep3d_btn,
|
| 398 |
extract_urdf_btn,
|
| 399 |
download_urdf,
|
| 400 |
model_output_gs,
|
|
@@ -447,12 +447,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 447 |
],
|
| 448 |
outputs=[output_buf, video_output],
|
| 449 |
).success(
|
| 450 |
-
|
| 451 |
-
outputs=[extract_rep3d_btn],
|
| 452 |
-
)
|
| 453 |
-
|
| 454 |
-
extract_rep3d_btn.click(
|
| 455 |
-
extract_3d_representations_v2,
|
| 456 |
inputs=[
|
| 457 |
output_buf,
|
| 458 |
project_delight,
|
|
|
|
| 27 |
VERSION,
|
| 28 |
active_btn_by_content,
|
| 29 |
end_session,
|
| 30 |
+
extract_3d_representations_v3,
|
| 31 |
extract_urdf,
|
| 32 |
get_seed,
|
| 33 |
image_to_3d,
|
|
|
|
| 179 |
)
|
| 180 |
|
| 181 |
generate_btn = gr.Button(
|
| 182 |
+
"🚀 1. Generate(~2 mins)",
|
| 183 |
variant="primary",
|
| 184 |
interactive=False,
|
| 185 |
)
|
| 186 |
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
| 187 |
+
# with gr.Row():
|
| 188 |
+
# extract_rep3d_btn = gr.Button(
|
| 189 |
+
# "🔍 2. Extract 3D Representation(~2 mins)",
|
| 190 |
+
# variant="primary",
|
| 191 |
+
# interactive=False,
|
| 192 |
+
# )
|
| 193 |
with gr.Accordion(
|
| 194 |
label="Enter Asset Attributes(optional)", open=False
|
| 195 |
):
|
|
|
|
| 207 |
)
|
| 208 |
with gr.Row():
|
| 209 |
extract_urdf_btn = gr.Button(
|
| 210 |
+
"🧩 2. Extract URDF with physics(~1 mins)",
|
| 211 |
variant="primary",
|
| 212 |
interactive=False,
|
| 213 |
)
|
|
|
|
| 230 |
)
|
| 231 |
with gr.Row():
|
| 232 |
download_urdf = gr.DownloadButton(
|
| 233 |
+
label="⬇️ 3. Download URDF",
|
| 234 |
variant="primary",
|
| 235 |
interactive=False,
|
| 236 |
)
|
|
|
|
| 326 |
image_prompt.change(
|
| 327 |
lambda: tuple(
|
| 328 |
[
|
| 329 |
+
# gr.Button(interactive=False),
|
| 330 |
gr.Button(interactive=False),
|
| 331 |
gr.Button(interactive=False),
|
| 332 |
None,
|
|
|
|
| 344 |
]
|
| 345 |
),
|
| 346 |
outputs=[
|
| 347 |
+
# extract_rep3d_btn,
|
| 348 |
extract_urdf_btn,
|
| 349 |
download_urdf,
|
| 350 |
model_output_gs,
|
|
|
|
| 375 |
image_prompt_sam.change(
|
| 376 |
lambda: tuple(
|
| 377 |
[
|
| 378 |
+
# gr.Button(interactive=False),
|
| 379 |
gr.Button(interactive=False),
|
| 380 |
gr.Button(interactive=False),
|
| 381 |
None,
|
|
|
|
| 394 |
]
|
| 395 |
),
|
| 396 |
outputs=[
|
| 397 |
+
# extract_rep3d_btn,
|
| 398 |
extract_urdf_btn,
|
| 399 |
download_urdf,
|
| 400 |
model_output_gs,
|
|
|
|
| 447 |
],
|
| 448 |
outputs=[output_buf, video_output],
|
| 449 |
).success(
|
| 450 |
+
extract_3d_representations_v3,
|
| 451 |
inputs=[
|
| 452 |
output_buf,
|
| 453 |
project_delight,
|
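The app.py hunks above drop the separate "Extract 3D Representation" button and instead chain `extract_3d_representations_v3` onto the generate event via `.success(...)`. Below is a minimal, hypothetical sketch of that Gradio chaining pattern (stand-in functions and components, not the Space's real ones):

```python
# Minimal sketch of the .click(...).success(...) chaining used in app.py.
# The generate step runs first; extraction runs only if it succeeds,
# so no intermediate "extract" button is needed.
import gradio as gr

def generate(prompt: str) -> dict:
    # stand-in for the Space's image_to_3d step
    return {"prompt": prompt}

def extract_representation(state: dict) -> str:
    # stand-in for extract_3d_representations_v3
    return f"extracted 3D representation for: {state['prompt']}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    result = gr.Textbox(label="result")
    state = gr.State()
    btn = gr.Button("🚀 Generate")
    btn.click(generate, inputs=[prompt], outputs=[state]).success(
        extract_representation, inputs=[state], outputs=[result]
    )
```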
app_style.py
CHANGED
|
@@ -4,7 +4,7 @@ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
|
|
| 4 |
lighting_css = """
|
| 5 |
<style>
|
| 6 |
#lighter_mesh canvas {
|
| 7 |
-
filter: brightness(
|
| 8 |
}
|
| 9 |
</style>
|
| 10 |
"""
|
|
|
|
| 4 |
lighting_css = """
|
| 5 |
<style>
|
| 6 |
#lighter_mesh canvas {
|
| 7 |
+
filter: brightness(2.0) !important;
|
| 8 |
}
|
| 9 |
</style>
|
| 10 |
"""
|
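The app_style.py hunk raises the `#lighter_mesh` canvas brightness filter to 2.0. As a hedged illustration, one way such a `<style>` snippet can be attached in a Gradio app is via `gr.HTML` next to a component carrying the matching `elem_id` (the real app may wire the CSS differently):

```python
# Sketch (assumption: CSS injected through gr.HTML; the Space may attach it elsewhere).
import gradio as gr

lighting_css = """
<style>
#lighter_mesh canvas {
    filter: brightness(2.0) !important;
}
</style>
"""

with gr.Blocks() as demo:
    gr.HTML(lighting_css)  # the <style> block applies page-wide
    # Any component rendered with elem_id="lighter_mesh" gets the brighter canvas.
    viewer = gr.Model3D(elem_id="lighter_mesh", label="lighter_mesh")
```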
common.py
CHANGED
|
@@ -32,6 +32,7 @@ import trimesh
|
|
| 32 |
from easydict import EasyDict as edict
|
| 33 |
from PIL import Image
|
| 34 |
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
|
|
|
| 35 |
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
| 36 |
from embodied_gen.data.utils import trellis_preprocess, zip_files
|
| 37 |
from embodied_gen.models.delight_model import DelightingModel
|
|
@@ -131,8 +132,8 @@ def patched_setup_functions(self):
|
|
| 131 |
Gaussian.setup_functions = patched_setup_functions
|
| 132 |
|
| 133 |
|
| 134 |
-
DELIGHT = DelightingModel()
|
| 135 |
-
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
| 136 |
# IMAGESR_MODEL = ImageStableSR()
|
| 137 |
if os.getenv("GRADIO_APP") == "imageto3d":
|
| 138 |
RBG_REMOVER = RembgRemover()
|
|
@@ -169,6 +170,8 @@ elif os.getenv("GRADIO_APP") == "textto3d":
|
|
| 169 |
)
|
| 170 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 171 |
elif os.getenv("GRADIO_APP") == "texture_edit":
|
|
|
|
|
|
|
| 172 |
PIPELINE_IP = build_texture_gen_pipe(
|
| 173 |
base_ckpt_dir="./weights",
|
| 174 |
ip_adapt_scale=0.7,
|
|
@@ -512,6 +515,60 @@ def extract_3d_representations_v2(
|
|
| 512 |
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
| 513 |
|
| 514 |
|
|
|
|
|
|
|
|
| 515 |
def extract_urdf(
|
| 516 |
gs_path: str,
|
| 517 |
mesh_obj_path: str,
|
|
|
|
| 32 |
from easydict import EasyDict as edict
|
| 33 |
from PIL import Image
|
| 34 |
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
| 35 |
+
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
|
| 36 |
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
| 37 |
from embodied_gen.data.utils import trellis_preprocess, zip_files
|
| 38 |
from embodied_gen.models.delight_model import DelightingModel
|
|
|
|
| 132 |
Gaussian.setup_functions = patched_setup_functions
|
| 133 |
|
| 134 |
|
| 135 |
+
# DELIGHT = DelightingModel()
|
| 136 |
+
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
| 137 |
# IMAGESR_MODEL = ImageStableSR()
|
| 138 |
if os.getenv("GRADIO_APP") == "imageto3d":
|
| 139 |
RBG_REMOVER = RembgRemover()
|
|
|
|
| 170 |
)
|
| 171 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 172 |
elif os.getenv("GRADIO_APP") == "texture_edit":
|
| 173 |
+
DELIGHT = DelightingModel()
|
| 174 |
+
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
| 175 |
PIPELINE_IP = build_texture_gen_pipe(
|
| 176 |
base_ckpt_dir="./weights",
|
| 177 |
ip_adapt_scale=0.7,
|
|
|
|
| 515 |
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
| 516 |
|
| 517 |
|
| 518 |
+
def extract_3d_representations_v3(
|
| 519 |
+
state: dict,
|
| 520 |
+
enable_delight: bool,
|
| 521 |
+
texture_size: int,
|
| 522 |
+
req: gr.Request,
|
| 523 |
+
):
|
| 524 |
+
output_root = TMP_DIR
|
| 525 |
+
user_dir = os.path.join(output_root, str(req.session_hash))
|
| 526 |
+
gs_model, mesh_model = unpack_state(state, device="cpu")
|
| 527 |
+
|
| 528 |
+
filename = "sample"
|
| 529 |
+
gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
|
| 530 |
+
gs_model.save_ply(gs_path)
|
| 531 |
+
|
| 532 |
+
# Rotate mesh and GS by 90 degrees around Z-axis.
|
| 533 |
+
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
|
| 534 |
+
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
|
| 535 |
+
mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
| 536 |
+
|
| 537 |
+
# Additional rotation for GS to align mesh.
|
| 538 |
+
gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
|
| 539 |
+
pose = GaussianOperator.trans_to_quatpose(gs_rot)
|
| 540 |
+
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
|
| 541 |
+
GaussianOperator.resave_ply(
|
| 542 |
+
in_ply=gs_path,
|
| 543 |
+
out_ply=aligned_gs_path,
|
| 544 |
+
instance_pose=pose,
|
| 545 |
+
device="cpu",
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
mesh = trimesh.Trimesh(
|
| 549 |
+
vertices=mesh_model.vertices.cpu().numpy(),
|
| 550 |
+
faces=mesh_model.faces.cpu().numpy(),
|
| 551 |
+
)
|
| 552 |
+
mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
|
| 553 |
+
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
|
| 554 |
+
|
| 555 |
+
mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
|
| 556 |
+
mesh.export(mesh_obj_path)
|
| 557 |
+
|
| 558 |
+
mesh = backproject_api_v3(
|
| 559 |
+
gs_path=aligned_gs_path,
|
| 560 |
+
mesh_path=mesh_obj_path,
|
| 561 |
+
output_path=mesh_obj_path,
|
| 562 |
+
skip_fix_mesh=False,
|
| 563 |
+
texture_size=texture_size,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
|
| 567 |
+
mesh.export(mesh_glb_path)
|
| 568 |
+
|
| 569 |
+
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
| 570 |
+
|
| 571 |
+
|
| 572 |
def extract_urdf(
|
| 573 |
gs_path: str,
|
| 574 |
mesh_obj_path: str,
|
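The new `extract_3d_representations_v3` in common.py composes a base rotation with extra per-representation alignment rotations before re-saving the Gaussians and exporting the mesh. The matrices below are copied from the diff; the snippet only uses numpy/scipy to inspect how they compose (the project's `GaussianOperator` helpers are not reproduced here):

```python
# Inspect the rotation composition used by extract_3d_representations_v3.
import numpy as np
from scipy.spatial.transform import Rotation

rot_matrix = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])    # base rotation
gs_add_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])   # extra GS alignment
mesh_add_rot = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])  # extra mesh alignment

# Rotation applied to the Gaussian splats before resave_ply().
gs_rot = gs_add_rot @ rot_matrix
print(Rotation.from_matrix(gs_rot).as_euler("xyz", degrees=True))

# Mesh vertices are right-multiplied in the diff: v' = (v @ mesh_add_rot) @ rot_matrix.
toy_vertices = np.eye(3)
print(toy_vertices @ mesh_add_rot @ rot_matrix)
```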
embodied_gen/data/asset_converter.py
CHANGED
|
@@ -4,12 +4,12 @@ import logging
|
|
| 4 |
import os
|
| 5 |
import xml.etree.ElementTree as ET
|
| 6 |
from abc import ABC, abstractmethod
|
| 7 |
-
from dataclasses import dataclass
|
| 8 |
from glob import glob
|
| 9 |
-
from shutil import copy
|
| 10 |
|
| 11 |
import trimesh
|
| 12 |
from scipy.spatial.transform import Rotation
|
|
|
|
| 13 |
|
| 14 |
logging.basicConfig(level=logging.INFO)
|
| 15 |
logger = logging.getLogger(__name__)
|
|
@@ -17,54 +17,157 @@ logger = logging.getLogger(__name__)
|
|
| 17 |
|
| 18 |
__all__ = [
|
| 19 |
"AssetConverterFactory",
|
| 20 |
-
"AssetType",
|
| 21 |
"MeshtoMJCFConverter",
|
| 22 |
"MeshtoUSDConverter",
|
| 23 |
"URDFtoUSDConverter",
|
|
|
|
|
|
|
| 24 |
]
|
| 25 |
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
USD = "usd"
|
| 33 |
-
URDF = "urdf"
|
| 34 |
-
MESH = "mesh"
|
| 35 |
|
| 36 |
|
| 37 |
class AssetConverterBase(ABC):
|
| 38 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
@abstractmethod
|
| 41 |
def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
pass
|
| 43 |
|
| 44 |
def transform_mesh(
|
| 45 |
self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
|
| 46 |
) -> None:
|
| 47 |
-
"""Apply transform to
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
|
| 50 |
rotation = Rotation.from_euler("xyz", rpy, degrees=False)
|
| 51 |
offset = list(map(float, mesh_origin.get("xyz").split(" ")))
|
| 52 |
-
mesh.vertices = (mesh.vertices @ rotation.as_matrix().T) + offset
|
| 53 |
-
|
| 54 |
os.makedirs(os.path.dirname(output_mesh), exist_ok=True)
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
return
|
| 58 |
|
| 59 |
def __enter__(self):
|
|
|
|
| 60 |
return self
|
| 61 |
|
| 62 |
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
|
| 63 |
return False
|
| 64 |
|
| 65 |
|
| 66 |
class MeshtoMJCFConverter(AssetConverterBase):
|
| 67 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def __init__(
|
| 70 |
self,
|
|
@@ -73,6 +176,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|
| 73 |
self.kwargs = kwargs
|
| 74 |
|
| 75 |
def _copy_asset_file(self, src: str, dst: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
if os.path.exists(dst):
|
| 77 |
return
|
| 78 |
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
|
@@ -90,37 +199,66 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|
| 90 |
material: ET.Element | None = None,
|
| 91 |
is_collision: bool = False,
|
| 92 |
) -> None:
|
| 93 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
element = link.find(tag)
|
| 95 |
geometry = element.find("geometry")
|
| 96 |
mesh = geometry.find("mesh")
|
| 97 |
filename = mesh.get("filename")
|
| 98 |
scale = mesh.get("scale", "1.0 1.0 1.0")
|
| 99 |
-
|
| 100 |
-
mesh_asset = ET.SubElement(
|
| 101 |
-
mujoco_element, "mesh", name=mesh_name, file=filename, scale=scale
|
| 102 |
-
)
|
| 103 |
-
geom = ET.SubElement(body, "geom", type="mesh", mesh=mesh_name)
|
| 104 |
-
|
| 105 |
-
self._copy_asset_file(
|
| 106 |
-
f"{input_dir}/{filename}",
|
| 107 |
-
f"{output_dir}/{filename}",
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
# Preprocess the mesh by applying rotation.
|
| 111 |
input_mesh = f"{input_dir}/{filename}"
|
| 112 |
output_mesh = f"{output_dir}/{filename}"
|
|
|
|
|
|
|
| 113 |
mesh_origin = element.find("origin")
|
| 114 |
if mesh_origin is not None:
|
| 115 |
self.transform_mesh(input_mesh, output_mesh, mesh_origin)
|
| 116 |
|
| 117 |
-
if material is not None:
|
| 118 |
-
geom.set("material", material.get("name"))
|
| 119 |
-
|
| 120 |
if is_collision:
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
def add_materials(
|
| 126 |
self,
|
|
@@ -132,31 +270,52 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|
| 132 |
name: str,
|
| 133 |
reflectance: float = 0.2,
|
| 134 |
) -> ET.Element:
|
| 135 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
element = link.find(tag)
|
| 137 |
geometry = element.find("geometry")
|
| 138 |
mesh = geometry.find("mesh")
|
| 139 |
filename = mesh.get("filename")
|
| 140 |
dirname = os.path.dirname(filename)
|
| 141 |
-
|
| 142 |
-
material = ET.SubElement(
|
| 143 |
-
mujoco_element,
|
| 144 |
-
"material",
|
| 145 |
-
name=f"material_{name}",
|
| 146 |
-
texture=f"texture_{name}",
|
| 147 |
-
reflectance=str(reflectance),
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
for path in glob(f"{input_dir}/{dirname}/*.png"):
|
| 151 |
file_name = os.path.basename(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
self._copy_asset_file(
|
| 153 |
path,
|
| 154 |
f"{output_dir}/{dirname}/{file_name}",
|
| 155 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
ET.SubElement(
|
| 157 |
mujoco_element,
|
| 158 |
"texture",
|
| 159 |
-
name=
|
| 160 |
type="2d",
|
| 161 |
file=f"{dirname}/{file_name}",
|
| 162 |
)
|
|
@@ -164,7 +323,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|
| 164 |
return material
|
| 165 |
|
| 166 |
def convert(self, urdf_path: str, mjcf_path: str):
|
| 167 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
tree = ET.parse(urdf_path)
|
| 169 |
root = tree.getroot()
|
| 170 |
|
|
@@ -188,6 +352,7 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|
| 188 |
output_dir,
|
| 189 |
name=str(idx),
|
| 190 |
)
|
|
|
|
| 191 |
self.add_geometry(
|
| 192 |
mujoco_asset,
|
| 193 |
link,
|
|
@@ -217,58 +382,22 @@ class MeshtoMJCFConverter(AssetConverterBase):
|
|
| 217 |
|
| 218 |
|
| 219 |
class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
| 220 |
-
"""
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
mujoco_element: ET.Element,
|
| 225 |
-
link: ET.Element,
|
| 226 |
-
tag: str,
|
| 227 |
-
input_dir: str,
|
| 228 |
-
output_dir: str,
|
| 229 |
-
name: str,
|
| 230 |
-
reflectance: float = 0.2,
|
| 231 |
-
) -> ET.Element:
|
| 232 |
-
"""Add materials to the MJCF asset from the URDF link."""
|
| 233 |
-
element = link.find(tag)
|
| 234 |
-
geometry = element.find("geometry")
|
| 235 |
-
mesh = geometry.find("mesh")
|
| 236 |
-
filename = mesh.get("filename")
|
| 237 |
-
dirname = os.path.dirname(filename)
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
file_name = os.path.basename(path)
|
| 242 |
-
self._copy_asset_file(
|
| 243 |
-
path,
|
| 244 |
-
f"{output_dir}/{dirname}/{file_name}",
|
| 245 |
-
)
|
| 246 |
-
texture_name = f"texture_{name}_{os.path.splitext(file_name)[0]}"
|
| 247 |
-
ET.SubElement(
|
| 248 |
-
mujoco_element,
|
| 249 |
-
"texture",
|
| 250 |
-
name=texture_name,
|
| 251 |
-
type="2d",
|
| 252 |
-
file=f"{dirname}/{file_name}",
|
| 253 |
-
)
|
| 254 |
-
if "diffuse" in file_name.lower():
|
| 255 |
-
diffuse_texture = texture_name
|
| 256 |
-
|
| 257 |
-
if diffuse_texture is None:
|
| 258 |
-
return None
|
| 259 |
-
|
| 260 |
-
material = ET.SubElement(
|
| 261 |
-
mujoco_element,
|
| 262 |
-
"material",
|
| 263 |
-
name=f"material_{name}",
|
| 264 |
-
texture=diffuse_texture,
|
| 265 |
-
reflectance=str(reflectance),
|
| 266 |
-
)
|
| 267 |
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
-
|
| 271 |
-
|
|
|
|
| 272 |
tree = ET.parse(urdf_path)
|
| 273 |
root = tree.getroot()
|
| 274 |
|
|
@@ -281,18 +410,12 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
|
| 281 |
output_dir = os.path.dirname(mjcf_path)
|
| 282 |
os.makedirs(output_dir, exist_ok=True)
|
| 283 |
|
| 284 |
-
# Create a dictionary to store body elements for each link
|
| 285 |
body_dict = {}
|
| 286 |
-
|
| 287 |
-
# Process all links first
|
| 288 |
for idx, link in enumerate(root.findall("link")):
|
| 289 |
link_name = link.get("name", f"unnamed_link_{idx}")
|
| 290 |
body = ET.SubElement(mujoco_worldbody, "body", name=link_name)
|
| 291 |
body_dict[link_name] = body
|
| 292 |
-
|
| 293 |
-
# Add materials and geometry
|
| 294 |
-
visual_element = link.find("visual")
|
| 295 |
-
if visual_element is not None:
|
| 296 |
material = self.add_materials(
|
| 297 |
mujoco_asset,
|
| 298 |
link,
|
|
@@ -311,9 +434,7 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
|
| 311 |
f"visual_mesh_{idx}",
|
| 312 |
material,
|
| 313 |
)
|
| 314 |
-
|
| 315 |
-
collision_element = link.find("collision")
|
| 316 |
-
if collision_element is not None:
|
| 317 |
self.add_geometry(
|
| 318 |
mujoco_asset,
|
| 319 |
link,
|
|
@@ -329,41 +450,27 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
|
| 329 |
for joint in root.findall("joint"):
|
| 330 |
joint_type = joint.get("type")
|
| 331 |
if joint_type != "fixed":
|
| 332 |
-
logger.warning(
|
| 333 |
-
f"Skipping non-fixed joint: {joint.get('name')}"
|
| 334 |
-
)
|
| 335 |
continue
|
| 336 |
|
| 337 |
parent_link = joint.find("parent").get("link")
|
| 338 |
child_link = joint.find("child").get("link")
|
| 339 |
origin = joint.find("origin")
|
| 340 |
-
|
| 341 |
if parent_link not in body_dict or child_link not in body_dict:
|
| 342 |
logger.warning(
|
| 343 |
f"Parent or child link not found for joint: {joint.get('name')}"
|
| 344 |
)
|
| 345 |
continue
|
| 346 |
|
| 347 |
-
# Move child body under parent body in MJCF hierarchy
|
| 348 |
child_body = body_dict[child_link]
|
| 349 |
mujoco_worldbody.remove(child_body)
|
| 350 |
parent_body = body_dict[parent_link]
|
| 351 |
parent_body.append(child_body)
|
| 352 |
-
|
| 353 |
-
# Apply joint origin transformation to child body
|
| 354 |
if origin is not None:
|
| 355 |
xyz = origin.get("xyz", "0 0 0")
|
| 356 |
rpy = origin.get("rpy", "0 0 0")
|
| 357 |
child_body.set("pos", xyz)
|
| 358 |
-
|
| 359 |
-
rpy_floats = list(map(float, rpy.split()))
|
| 360 |
-
rotation = Rotation.from_euler(
|
| 361 |
-
"xyz", rpy_floats, degrees=False
|
| 362 |
-
)
|
| 363 |
-
euler_deg = rotation.as_euler("xyz", degrees=True)
|
| 364 |
-
child_body.set(
|
| 365 |
-
"euler", f"{euler_deg[0]} {euler_deg[1]} {euler_deg[2]}"
|
| 366 |
-
)
|
| 367 |
|
| 368 |
tree = ET.ElementTree(mujoco_struct)
|
| 369 |
ET.indent(tree, space=" ", level=0)
|
|
@@ -374,11 +481,15 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
|
| 374 |
|
| 375 |
|
| 376 |
class MeshtoUSDConverter(AssetConverterBase):
|
| 377 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
DEFAULT_BIND_APIS = [
|
| 380 |
"MaterialBindingAPI",
|
| 381 |
"PhysicsMeshCollisionAPI",
|
|
|
|
| 382 |
"PhysicsCollisionAPI",
|
| 383 |
"PhysxCollisionAPI",
|
| 384 |
"PhysicsMassAPI",
|
|
@@ -393,41 +504,65 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|
| 393 |
simulation_app=None,
|
| 394 |
**kwargs,
|
| 395 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
self.usd_parms = dict(
|
| 397 |
force_usd_conversion=force_usd_conversion,
|
| 398 |
make_instanceable=make_instanceable,
|
| 399 |
**kwargs,
|
| 400 |
)
|
| 401 |
-
if simulation_app is not None:
|
| 402 |
-
self.simulation_app = simulation_app
|
| 403 |
|
| 404 |
def __enter__(self):
|
|
|
|
| 405 |
from isaaclab.app import AppLauncher
|
| 406 |
|
| 407 |
if not hasattr(self, "simulation_app"):
|
| 408 |
-
launch_args
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
|
|
|
|
|
|
|
|
|
| 414 |
self.app_launcher = AppLauncher(launch_args)
|
| 415 |
self.simulation_app = self.app_launcher.app
|
| 416 |
|
| 417 |
return self
|
| 418 |
|
| 419 |
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
|
| 420 |
# Close the simulation app if it was created here
|
| 421 |
-
if hasattr(self, "app_launcher"):
|
| 422 |
-
self.simulation_app.close()
|
| 423 |
-
|
| 424 |
if exc_val is not None:
|
| 425 |
logger.error(f"Exception occurred: {exc_val}.")
|
| 426 |
|
|
|
|
|
|
|
|
|
|
| 427 |
return False
|
| 428 |
|
| 429 |
def convert(self, urdf_path: str, output_file: str):
|
| 430 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
|
| 432 |
from pxr import PhysxSchema, Sdf, Usd, UsdShade
|
| 433 |
|
|
@@ -449,10 +584,13 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|
| 449 |
)
|
| 450 |
urdf_converter = MeshConverter(cfg)
|
| 451 |
usd_path = urdf_converter.usd_path
|
|
|
|
| 452 |
|
| 453 |
stage = Usd.Stage.Open(usd_path)
|
| 454 |
layer = stage.GetRootLayer()
|
| 455 |
with Usd.EditContext(stage, layer):
|
|
|
|
|
|
|
| 456 |
for prim in stage.Traverse():
|
| 457 |
# Change texture path to relative path.
|
| 458 |
if prim.GetName() == "material_0":
|
|
@@ -465,11 +603,9 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|
| 465 |
|
| 466 |
# Add convex decomposition collision and set ShrinkWrap.
|
| 467 |
elif prim.GetName() == "mesh":
|
| 468 |
-
approx_attr = prim.
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
"physics:approximation", Sdf.ValueTypeNames.Token
|
| 472 |
-
)
|
| 473 |
approx_attr.Set("convexDecomposition")
|
| 474 |
|
| 475 |
physx_conv_api = (
|
|
@@ -477,6 +613,15 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|
| 477 |
prim
|
| 478 |
)
|
| 479 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
| 481 |
|
| 482 |
api_schemas = prim.GetMetadata("apiSchemas")
|
|
@@ -495,15 +640,105 @@ class MeshtoUSDConverter(AssetConverterBase):
|
|
| 495 |
logger.info(f"Successfully converted {urdf_path} → {usd_path}")
|
| 496 |
|
| 497 |
|
|
|
|
|
|
| 498 |
class URDFtoUSDConverter(MeshtoUSDConverter):
|
| 499 |
-
"""
|
| 500 |
|
| 501 |
Args:
|
| 502 |
-
fix_base (bool):
|
| 503 |
-
merge_fixed_joints (bool):
|
| 504 |
-
make_instanceable (bool):
|
| 505 |
-
force_usd_conversion (bool): Force conversion to USD.
|
| 506 |
-
collision_from_visuals (bool): Generate collisions from visuals
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
"""
|
| 508 |
|
| 509 |
def __init__(
|
|
@@ -518,6 +753,19 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|
| 518 |
simulation_app=None,
|
| 519 |
**kwargs,
|
| 520 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
self.usd_parms = dict(
|
| 522 |
fix_base=fix_base,
|
| 523 |
merge_fixed_joints=merge_fixed_joints,
|
|
@@ -532,7 +780,12 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|
| 532 |
self.simulation_app = simulation_app
|
| 533 |
|
| 534 |
def convert(self, urdf_path: str, output_file: str):
|
| 535 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
|
| 537 |
from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
|
| 538 |
|
|
@@ -551,11 +804,9 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|
| 551 |
with Usd.EditContext(stage, layer):
|
| 552 |
for prim in stage.Traverse():
|
| 553 |
if prim.GetName() == "collisions":
|
| 554 |
-
approx_attr = prim.
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
"physics:approximation", Sdf.ValueTypeNames.Token
|
| 558 |
-
)
|
| 559 |
approx_attr.Set("convexDecomposition")
|
| 560 |
|
| 561 |
physx_conv_api = (
|
|
@@ -563,6 +814,9 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|
| 563 |
prim
|
| 564 |
)
|
| 565 |
)
|
|
|
|
|
|
|
|
|
|
| 566 |
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
| 567 |
|
| 568 |
api_schemas = prim.GetMetadata("apiSchemas")
|
|
@@ -593,19 +847,44 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
|
|
| 593 |
|
| 594 |
|
| 595 |
class AssetConverterFactory:
|
| 596 |
-
"""Factory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
|
| 598 |
@staticmethod
|
| 599 |
def create(
|
| 600 |
target_type: AssetType, source_type: AssetType = "urdf", **kwargs
|
| 601 |
) -> AssetConverterBase:
|
| 602 |
-
"""
|
| 603 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
converter = MeshtoMJCFConverter(**kwargs)
|
| 605 |
-
elif target_type == AssetType.
|
| 606 |
-
converter =
|
| 607 |
elif target_type == AssetType.USD and source_type == AssetType.MESH:
|
| 608 |
converter = MeshtoUSDConverter(**kwargs)
|
|
|
|
|
|
|
| 609 |
else:
|
| 610 |
raise ValueError(
|
| 611 |
f"Unsupported converter type: {source_type} -> {target_type}."
|
|
@@ -615,34 +894,48 @@ class AssetConverterFactory:
|
|
| 615 |
|
| 616 |
|
| 617 |
if __name__ == "__main__":
|
| 618 |
-
|
| 619 |
# target_asset_type = AssetType.USD
|
| 620 |
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
# source_type=AssetType.URDF,
|
| 632 |
-
# )
|
| 633 |
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
|
| 647 |
# urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
|
| 648 |
# output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
|
|
@@ -656,8 +949,21 @@ if __name__ == "__main__":
|
|
| 656 |
# with asset_converter:
|
| 657 |
# asset_converter.convert(urdf_path, output_file)
|
| 658 |
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import os
|
| 5 |
import xml.etree.ElementTree as ET
|
| 6 |
from abc import ABC, abstractmethod
|
|
|
|
| 7 |
from glob import glob
|
| 8 |
+
from shutil import copy, copytree, rmtree
|
| 9 |
|
| 10 |
import trimesh
|
| 11 |
from scipy.spatial.transform import Rotation
|
| 12 |
+
from embodied_gen.utils.enum import AssetType
|
| 13 |
|
| 14 |
logging.basicConfig(level=logging.INFO)
|
| 15 |
logger = logging.getLogger(__name__)
|
|
|
|
| 17 |
|
| 18 |
__all__ = [
|
| 19 |
"AssetConverterFactory",
|
|
|
|
| 20 |
"MeshtoMJCFConverter",
|
| 21 |
"MeshtoUSDConverter",
|
| 22 |
"URDFtoUSDConverter",
|
| 23 |
+
"cvt_embodiedgen_asset_to_anysim",
|
| 24 |
+
"PhysicsUSDAdder",
|
| 25 |
]
|
| 26 |
|
| 27 |
|
| 28 |
+
def cvt_embodiedgen_asset_to_anysim(
|
| 29 |
+
urdf_files: list[str],
|
| 30 |
+
target_dirs: list[str],
|
| 31 |
+
target_type: AssetType,
|
| 32 |
+
source_type: AssetType,
|
| 33 |
+
overwrite: bool = False,
|
| 34 |
+
**kwargs,
|
| 35 |
+
) -> dict[str, str]:
|
| 36 |
+
"""Convert URDF files generated by EmbodiedGen into formats required by simulators.
|
| 37 |
+
|
| 38 |
+
Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet.
|
| 39 |
+
Converting to the `USD` format requires `isaacsim` to be installed.
|
| 40 |
+
|
| 41 |
+
Example:
|
| 42 |
+
```py
|
| 43 |
+
from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
|
| 44 |
+
from embodied_gen.utils.enum import AssetType
|
| 45 |
+
|
| 46 |
+
dst_asset_path = cvt_embodiedgen_asset_to_anysim(
|
| 47 |
+
urdf_files=[
|
| 48 |
+
"path1_to_embodiedgen_asset/asset.urdf",
|
| 49 |
+
"path2_to_embodiedgen_asset/asset.urdf",
|
| 50 |
+
],
|
| 51 |
+
target_dirs=[
|
| 52 |
+
"path1_to_target_dir/asset.usd",
|
| 53 |
+
"path2_to_target_dir/asset.usd",
|
| 54 |
+
],
|
| 55 |
+
target_type=AssetType.USD,
|
| 56 |
+
source_type=AssetType.MESH,
|
| 57 |
+
)
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
urdf_files (list[str]): List of URDF file paths.
|
| 62 |
+
target_dirs (list[str]): List of target directories.
|
| 63 |
+
target_type (AssetType): Target asset type.
|
| 64 |
+
source_type (AssetType): Source asset type.
|
| 65 |
+
overwrite (bool, optional): Overwrite existing files.
|
| 66 |
+
**kwargs: Additional converter arguments.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
dict[str, str]: Mapping from URDF file to converted asset file.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
if isinstance(urdf_files, str):
|
| 73 |
+
urdf_files = [urdf_files]
|
| 74 |
+
if isinstance(target_dirs, str):
|
| 75 |
+
urdf_files = [target_dirs]
|
| 76 |
+
|
| 77 |
+
# If the target type is URDF, no conversion is needed.
|
| 78 |
+
if target_type == AssetType.URDF:
|
| 79 |
+
return {key: key for key in urdf_files}
|
| 80 |
+
|
| 81 |
+
asset_converter = AssetConverterFactory.create(
|
| 82 |
+
target_type=target_type,
|
| 83 |
+
source_type=source_type,
|
| 84 |
+
**kwargs,
|
| 85 |
+
)
|
| 86 |
+
asset_paths = dict()
|
| 87 |
+
|
| 88 |
+
with asset_converter:
|
| 89 |
+
for urdf_file, target_dir in zip(urdf_files, target_dirs):
|
| 90 |
+
filename = os.path.basename(urdf_file).replace(".urdf", "")
|
| 91 |
+
if target_type == AssetType.MJCF:
|
| 92 |
+
target_file = f"{target_dir}/{filename}.xml"
|
| 93 |
+
elif target_type == AssetType.USD:
|
| 94 |
+
target_file = f"{target_dir}/{filename}.usd"
|
| 95 |
+
else:
|
| 96 |
+
raise NotImplementedError(
|
| 97 |
+
f"Target type {target_type} not supported."
|
| 98 |
+
)
|
| 99 |
+
if not os.path.exists(target_file) or overwrite:
|
| 100 |
+
asset_converter.convert(urdf_file, target_file)
|
| 101 |
+
|
| 102 |
+
asset_paths[urdf_file] = target_file
|
| 103 |
|
| 104 |
+
return asset_paths
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
class AssetConverterBase(ABC):
|
| 108 |
+
"""Abstract base class for asset converters.
|
| 109 |
+
|
| 110 |
+
Provides context management and mesh transformation utilities.
|
| 111 |
+
"""
|
| 112 |
|
| 113 |
@abstractmethod
|
| 114 |
def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
|
| 115 |
+
"""Convert an asset file.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
urdf_path (str): Path to input URDF file.
|
| 119 |
+
output_path (str): Path to output file.
|
| 120 |
+
**kwargs: Additional arguments.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
str: Path to converted asset.
|
| 124 |
+
"""
|
| 125 |
pass
|
| 126 |
|
| 127 |
def transform_mesh(
|
| 128 |
self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
|
| 129 |
) -> None:
|
| 130 |
+
"""Apply transform to mesh based on URDF origin element.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
input_mesh (str): Path to input mesh.
|
| 134 |
+
output_mesh (str): Path to output mesh.
|
| 135 |
+
mesh_origin (ET.Element): Origin element from URDF.
|
| 136 |
+
"""
|
| 137 |
+
mesh = trimesh.load(input_mesh, group_material=False)
|
| 138 |
rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
|
| 139 |
rotation = Rotation.from_euler("xyz", rpy, degrees=False)
|
| 140 |
offset = list(map(float, mesh_origin.get("xyz").split(" ")))
|
|
|
|
|
|
|
| 141 |
os.makedirs(os.path.dirname(output_mesh), exist_ok=True)
|
| 142 |
+
|
| 143 |
+
if isinstance(mesh, trimesh.Scene):
|
| 144 |
+
combined = trimesh.Scene()
|
| 145 |
+
for mesh_part in mesh.geometry.values():
|
| 146 |
+
mesh_part.vertices = (
|
| 147 |
+
mesh_part.vertices @ rotation.as_matrix().T
|
| 148 |
+
) + offset
|
| 149 |
+
combined.add_geometry(mesh_part)
|
| 150 |
+
_ = combined.export(output_mesh)
|
| 151 |
+
else:
|
| 152 |
+
mesh.vertices = (mesh.vertices @ rotation.as_matrix().T) + offset
|
| 153 |
+
_ = mesh.export(output_mesh)
|
| 154 |
|
| 155 |
return
|
| 156 |
|
| 157 |
def __enter__(self):
|
| 158 |
+
"""Context manager entry."""
|
| 159 |
return self
|
| 160 |
|
| 161 |
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 162 |
+
"""Context manager exit."""
|
| 163 |
return False
|
| 164 |
|
| 165 |
|
| 166 |
class MeshtoMJCFConverter(AssetConverterBase):
|
| 167 |
+
"""Converts mesh-based URDF files to MJCF format.
|
| 168 |
+
|
| 169 |
+
Handles geometry, materials, and asset copying.
|
| 170 |
+
"""
|
| 171 |
|
| 172 |
def __init__(
|
| 173 |
self,
|
|
|
|
| 176 |
self.kwargs = kwargs
|
| 177 |
|
| 178 |
def _copy_asset_file(self, src: str, dst: str) -> None:
|
| 179 |
+
"""Copies asset file if not already present.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
src (str): Source file path.
|
| 183 |
+
dst (str): Destination file path.
|
| 184 |
+
"""
|
| 185 |
if os.path.exists(dst):
|
| 186 |
return
|
| 187 |
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
|
|
|
| 199 |
material: ET.Element | None = None,
|
| 200 |
is_collision: bool = False,
|
| 201 |
) -> None:
|
| 202 |
+
"""Adds geometry to MJCF body from URDF link.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
mujoco_element (ET.Element): MJCF asset element.
|
| 206 |
+
link (ET.Element): URDF link element.
|
| 207 |
+
body (ET.Element): MJCF body element.
|
| 208 |
+
tag (str): Tag name ("visual" or "collision").
|
| 209 |
+
input_dir (str): Input directory.
|
| 210 |
+
output_dir (str): Output directory.
|
| 211 |
+
mesh_name (str): Mesh name.
|
| 212 |
+
material (ET.Element, optional): Material element.
|
| 213 |
+
is_collision (bool, optional): If True, treat as collision geometry.
|
| 214 |
+
"""
|
| 215 |
element = link.find(tag)
|
| 216 |
geometry = element.find("geometry")
|
| 217 |
mesh = geometry.find("mesh")
|
| 218 |
filename = mesh.get("filename")
|
| 219 |
scale = mesh.get("scale", "1.0 1.0 1.0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
input_mesh = f"{input_dir}/{filename}"
|
| 221 |
output_mesh = f"{output_dir}/{filename}"
|
| 222 |
+
self._copy_asset_file(input_mesh, output_mesh)
|
| 223 |
+
|
| 224 |
mesh_origin = element.find("origin")
|
| 225 |
if mesh_origin is not None:
|
| 226 |
self.transform_mesh(input_mesh, output_mesh, mesh_origin)
|
| 227 |
|
|
|
|
|
|
|
|
|
|
| 228 |
if is_collision:
|
| 229 |
+
mesh_parts = trimesh.load(
|
| 230 |
+
output_mesh, group_material=False, force="scene"
|
| 231 |
+
)
|
| 232 |
+
mesh_parts = mesh_parts.geometry.values()
|
| 233 |
+
else:
|
| 234 |
+
mesh_parts = [trimesh.load(output_mesh, force="mesh")]
|
| 235 |
+
for idx, mesh_part in enumerate(mesh_parts):
|
| 236 |
+
if is_collision:
|
| 237 |
+
idx_mesh_name = f"{mesh_name}_{idx}"
|
| 238 |
+
base, ext = os.path.splitext(filename)
|
| 239 |
+
idx_filename = f"{base}_{idx}{ext}"
|
| 240 |
+
base_outdir = os.path.dirname(output_mesh)
|
| 241 |
+
mesh_part.export(os.path.join(base_outdir, '..', idx_filename))
|
| 242 |
+
geom_attrs = {
|
| 243 |
+
"contype": "1",
|
| 244 |
+
"conaffinity": "1",
|
| 245 |
+
"rgba": "1 1 1 0",
|
| 246 |
+
}
|
| 247 |
+
else:
|
| 248 |
+
idx_mesh_name, idx_filename = mesh_name, filename
|
| 249 |
+
geom_attrs = {"contype": "0", "conaffinity": "0"}
|
| 250 |
+
|
| 251 |
+
ET.SubElement(
|
| 252 |
+
mujoco_element,
|
| 253 |
+
"mesh",
|
| 254 |
+
name=idx_mesh_name,
|
| 255 |
+
file=idx_filename,
|
| 256 |
+
scale=scale,
|
| 257 |
+
)
|
| 258 |
+
geom = ET.SubElement(body, "geom", type="mesh", mesh=idx_mesh_name)
|
| 259 |
+
geom.attrib.update(geom_attrs)
|
| 260 |
+
if material is not None:
|
| 261 |
+
geom.set("material", material.get("name"))
|
| 262 |
|
| 263 |
def add_materials(
|
| 264 |
self,
|
|
|
|
| 270 |
name: str,
|
| 271 |
reflectance: float = 0.2,
|
| 272 |
) -> ET.Element:
|
| 273 |
+
"""Adds materials to MJCF asset from URDF link.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
mujoco_element (ET.Element): MJCF asset element.
|
| 277 |
+
link (ET.Element): URDF link element.
|
| 278 |
+
tag (str): Tag name.
|
| 279 |
+
input_dir (str): Input directory.
|
| 280 |
+
output_dir (str): Output directory.
|
| 281 |
+
name (str): Material name.
|
| 282 |
+
reflectance (float, optional): Reflectance value.
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
ET.Element: Material element.
|
| 286 |
+
"""
|
| 287 |
element = link.find(tag)
|
| 288 |
geometry = element.find("geometry")
|
| 289 |
mesh = geometry.find("mesh")
|
| 290 |
filename = mesh.get("filename")
|
| 291 |
dirname = os.path.dirname(filename)
|
| 292 |
+
material = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
for path in glob(f"{input_dir}/{dirname}/*.png"):
|
| 294 |
file_name = os.path.basename(path)
|
| 295 |
+
if "keep_materials" in self.kwargs:
|
| 296 |
+
find_flag = False
|
| 297 |
+
for keep_key in self.kwargs["keep_materials"]:
|
| 298 |
+
if keep_key in file_name.lower():
|
| 299 |
+
find_flag = True
|
| 300 |
+
if find_flag is False:
|
| 301 |
+
continue
|
| 302 |
+
|
| 303 |
self._copy_asset_file(
|
| 304 |
path,
|
| 305 |
f"{output_dir}/{dirname}/{file_name}",
|
| 306 |
)
|
| 307 |
+
texture_name = f"texture_{name}_{os.path.splitext(file_name)[0]}"
|
| 308 |
+
material = ET.SubElement(
|
| 309 |
+
mujoco_element,
|
| 310 |
+
"material",
|
| 311 |
+
name=f"material_{name}",
|
| 312 |
+
texture=texture_name,
|
| 313 |
+
reflectance=str(reflectance),
|
| 314 |
+
)
|
| 315 |
ET.SubElement(
|
| 316 |
mujoco_element,
|
| 317 |
"texture",
|
| 318 |
+
name=texture_name,
|
| 319 |
type="2d",
|
| 320 |
file=f"{dirname}/{file_name}",
|
| 321 |
)
|
|
|
|
| 323 |
return material
|
| 324 |
|
| 325 |
def convert(self, urdf_path: str, mjcf_path: str):
|
| 326 |
+
"""Converts a URDF file to MJCF format.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
urdf_path (str): Path to URDF file.
|
| 330 |
+
mjcf_path (str): Path to output MJCF file.
|
| 331 |
+
"""
|
| 332 |
tree = ET.parse(urdf_path)
|
| 333 |
root = tree.getroot()
|
| 334 |
|
|
|
|
| 352 |
output_dir,
|
| 353 |
name=str(idx),
|
| 354 |
)
|
| 355 |
+
joint = ET.SubElement(body, "joint", attrib={"type": "free"})
|
| 356 |
self.add_geometry(
|
| 357 |
mujoco_asset,
|
| 358 |
link,
|
|
|
|
| 382 |
|
| 383 |
|
| 384 |
class URDFtoMJCFConverter(MeshtoMJCFConverter):
|
| 385 |
+
"""Converts URDF files with joints to MJCF format, handling joint transformations.
|
| 386 |
|
| 387 |
+
Handles fixed joints and hierarchical body structure.
|
| 388 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
+
def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
|
| 391 |
+
"""Converts a URDF file with joints to MJCF format.
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
+
Args:
|
| 394 |
+
urdf_path (str): Path to URDF file.
|
| 395 |
+
mjcf_path (str): Path to output MJCF file.
|
| 396 |
+
**kwargs: Additional arguments.
|
| 397 |
|
| 398 |
+
Returns:
|
| 399 |
+
str: Path to converted MJCF file.
|
| 400 |
+
"""
|
| 401 |
tree = ET.parse(urdf_path)
|
| 402 |
root = tree.getroot()
|
| 403 |
|
|
|
|
| 410 |
output_dir = os.path.dirname(mjcf_path)
|
| 411 |
os.makedirs(output_dir, exist_ok=True)
|
| 412 |
|
|
|
|
| 413 |
body_dict = {}
|
|
|
|
|
|
|
| 414 |
for idx, link in enumerate(root.findall("link")):
|
| 415 |
link_name = link.get("name", f"unnamed_link_{idx}")
|
| 416 |
body = ET.SubElement(mujoco_worldbody, "body", name=link_name)
|
| 417 |
body_dict[link_name] = body
|
| 418 |
+
if link.find("visual") is not None:
|
|
|
|
|
|
|
|
|
|
| 419 |
material = self.add_materials(
|
| 420 |
mujoco_asset,
|
| 421 |
link,
|
|
|
|
| 434 |
f"visual_mesh_{idx}",
|
| 435 |
material,
|
| 436 |
)
|
| 437 |
+
if link.find("collision") is not None:
|
|
|
|
|
|
|
| 438 |
self.add_geometry(
|
| 439 |
mujoco_asset,
|
| 440 |
link,
|
|
|
|
| 450 |
for joint in root.findall("joint"):
|
| 451 |
joint_type = joint.get("type")
|
| 452 |
if joint_type != "fixed":
|
| 453 |
+
logger.warning("Only support fixed joints in conversion now.")
|
|
|
|
|
|
|
| 454 |
continue
|
| 455 |
|
| 456 |
parent_link = joint.find("parent").get("link")
|
| 457 |
child_link = joint.find("child").get("link")
|
| 458 |
origin = joint.find("origin")
|
|
|
|
| 459 |
if parent_link not in body_dict or child_link not in body_dict:
|
| 460 |
logger.warning(
|
| 461 |
f"Parent or child link not found for joint: {joint.get('name')}"
|
| 462 |
)
|
| 463 |
continue
|
| 464 |
|
|
|
|
| 465 |
child_body = body_dict[child_link]
|
| 466 |
mujoco_worldbody.remove(child_body)
|
| 467 |
parent_body = body_dict[parent_link]
|
| 468 |
parent_body.append(child_body)
|
|
|
|
|
|
|
| 469 |
if origin is not None:
|
| 470 |
xyz = origin.get("xyz", "0 0 0")
|
| 471 |
rpy = origin.get("rpy", "0 0 0")
|
| 472 |
child_body.set("pos", xyz)
|
| 473 |
+
child_body.set("euler", rpy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
tree = ET.ElementTree(mujoco_struct)
|
| 476 |
ET.indent(tree, space=" ", level=0)
|
|
|
|
| 481 |
|
| 482 |
|
| 483 |
class MeshtoUSDConverter(AssetConverterBase):
|
| 484 |
+
"""Converts mesh-based URDF files to USD format.
|
| 485 |
+
|
| 486 |
+
Adds physics APIs and post-processes collision meshes.
|
| 487 |
+
"""
|
| 488 |
|
| 489 |
DEFAULT_BIND_APIS = [
|
| 490 |
"MaterialBindingAPI",
|
| 491 |
"PhysicsMeshCollisionAPI",
|
| 492 |
+
"PhysxConvexDecompositionCollisionAPI",
|
| 493 |
"PhysicsCollisionAPI",
|
| 494 |
"PhysxCollisionAPI",
|
| 495 |
"PhysicsMassAPI",
|
|
|
|
| 504 |
simulation_app=None,
|
| 505 |
**kwargs,
|
| 506 |
):
|
| 507 |
+
"""Initializes the converter.
|
| 508 |
+
|
| 509 |
+
Args:
|
| 510 |
+
force_usd_conversion (bool, optional): Force USD conversion.
|
| 511 |
+
make_instanceable (bool, optional): Make prims instanceable.
|
| 512 |
+
simulation_app (optional): Simulation app instance.
|
| 513 |
+
**kwargs: Additional arguments.
|
| 514 |
+
"""
|
| 515 |
+
if simulation_app is not None:
|
| 516 |
+
self.simulation_app = simulation_app
|
| 517 |
+
|
| 518 |
+
self.exit_close = kwargs.pop("exit_close", True)
|
| 519 |
+
self.physx_max_convex_hulls = kwargs.pop("physx_max_convex_hulls", 32)
|
| 520 |
+
self.physx_max_vertices = kwargs.pop("physx_max_vertices", 16)
|
| 521 |
+
self.physx_max_voxel_res = kwargs.pop("physx_max_voxel_res", 10000)
|
| 522 |
+
|
| 523 |
self.usd_parms = dict(
|
| 524 |
force_usd_conversion=force_usd_conversion,
|
| 525 |
make_instanceable=make_instanceable,
|
| 526 |
**kwargs,
|
| 527 |
)
|
|
|
|
|
|
|
| 528 |
|
| 529 |
def __enter__(self):
|
| 530 |
+
"""Context manager entry, launches simulation app if needed."""
|
| 531 |
from isaaclab.app import AppLauncher
|
| 532 |
|
| 533 |
if not hasattr(self, "simulation_app"):
|
| 534 |
+
if "launch_args" not in self.usd_parms:
|
| 535 |
+
launch_args = dict(
|
| 536 |
+
headless=True,
|
| 537 |
+
no_splash=True,
|
| 538 |
+
fast_shutdown=True,
|
| 539 |
+
disable_gpu=True,
|
| 540 |
+
)
|
| 541 |
+
else:
|
| 542 |
+
launch_args = self.usd_parms.pop("launch_args")
|
| 543 |
self.app_launcher = AppLauncher(launch_args)
|
| 544 |
self.simulation_app = self.app_launcher.app
|
| 545 |
|
| 546 |
return self
|
| 547 |
|
| 548 |
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 549 |
+
"""Context manager exit, closes simulation app if created."""
|
| 550 |
# Close the simulation app if it was created here
|
|
|
|
|
|
|
|
|
|
| 551 |
if exc_val is not None:
|
| 552 |
logger.error(f"Exception occurred: {exc_val}.")
|
| 553 |
|
| 554 |
+
if hasattr(self, "app_launcher") and self.exit_close:
|
| 555 |
+
self.simulation_app.close()
|
| 556 |
+
|
| 557 |
return False
|
| 558 |
|
| 559 |
def convert(self, urdf_path: str, output_file: str):
|
| 560 |
+
"""Converts a URDF file to USD and post-processes collision meshes.
|
| 561 |
+
|
| 562 |
+
Args:
|
| 563 |
+
urdf_path (str): Path to URDF file.
|
| 564 |
+
output_file (str): Path to output USD file.
|
| 565 |
+
"""
|
| 566 |
from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
|
| 567 |
from pxr import PhysxSchema, Sdf, Usd, UsdShade
|
| 568 |
|
|
|
|
| 584 |
)
|
| 585 |
urdf_converter = MeshConverter(cfg)
|
| 586 |
usd_path = urdf_converter.usd_path
|
| 587 |
+
rmtree(os.path.dirname(output_mesh))
|
| 588 |
|
| 589 |
stage = Usd.Stage.Open(usd_path)
|
| 590 |
layer = stage.GetRootLayer()
|
| 591 |
with Usd.EditContext(stage, layer):
|
| 592 |
+
base_prim = stage.GetPseudoRoot().GetChildren()[0]
|
| 593 |
+
base_prim.SetMetadata("kind", "component")
|
| 594 |
for prim in stage.Traverse():
|
| 595 |
# Change texture path to relative path.
|
| 596 |
if prim.GetName() == "material_0":
|
|
|
|
| 603 |
|
| 604 |
# Add convex decomposition collision and set ShrinkWrap.
|
| 605 |
elif prim.GetName() == "mesh":
|
| 606 |
+
approx_attr = prim.CreateAttribute(
|
| 607 |
+
"physics:approximation", Sdf.ValueTypeNames.Token
|
| 608 |
+
)
|
|
|
|
|
|
|
| 609 |
approx_attr.Set("convexDecomposition")
|
| 610 |
|
| 611 |
physx_conv_api = (
|
|
|
|
| 613 |
prim
|
| 614 |
)
|
| 615 |
)
|
| 616 |
+
physx_conv_api.GetMaxConvexHullsAttr().Set(
|
| 617 |
+
self.physx_max_convex_hulls
|
| 618 |
+
)
|
| 619 |
+
physx_conv_api.GetHullVertexLimitAttr().Set(
|
| 620 |
+
self.physx_max_vertices
|
| 621 |
+
)
|
| 622 |
+
physx_conv_api.GetVoxelResolutionAttr().Set(
|
| 623 |
+
self.physx_max_voxel_res
|
| 624 |
+
)
|
| 625 |
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
| 626 |
|
| 627 |
api_schemas = prim.GetMetadata("apiSchemas")
|
|
|
|
| 640 |
logger.info(f"Successfully converted {urdf_path} → {usd_path}")
|
| 641 |
|
| 642 |
|
| 643 |
+
class PhysicsUSDAdder(MeshtoUSDConverter):
|
| 644 |
+
"""Adds physics APIs and collision properties to USD assets.
|
| 645 |
+
|
| 646 |
+
Useful for post-processing USD files for simulation.
|
| 647 |
+
"""
|
| 648 |
+
|
| 649 |
+
DEFAULT_BIND_APIS = [
|
| 650 |
+
"MaterialBindingAPI",
|
| 651 |
+
"PhysicsMeshCollisionAPI",
|
| 652 |
+
"PhysxConvexDecompositionCollisionAPI",
|
| 653 |
+
"PhysicsCollisionAPI",
|
| 654 |
+
"PhysxCollisionAPI",
|
| 655 |
+
"PhysicsRigidBodyAPI",
|
| 656 |
+
]
|
| 657 |
+
|
| 658 |
+
def convert(self, usd_path: str, output_file: str = None):
|
| 659 |
+
"""Adds physics APIs and collision properties to a USD file.
|
| 660 |
+
|
| 661 |
+
Args:
|
| 662 |
+
usd_path (str): Path to input USD file.
|
| 663 |
+
output_file (str, optional): Path to output USD file.
|
| 664 |
+
"""
|
| 665 |
+
from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics
|
| 666 |
+
|
| 667 |
+
if output_file is None:
|
| 668 |
+
output_file = usd_path
|
| 669 |
+
else:
|
| 670 |
+
dst_dir = os.path.dirname(output_file)
|
| 671 |
+
src_dir = os.path.dirname(usd_path)
|
| 672 |
+
copytree(src_dir, dst_dir, dirs_exist_ok=True)
|
| 673 |
+
|
| 674 |
+
stage = Usd.Stage.Open(output_file)
|
| 675 |
+
layer = stage.GetRootLayer()
|
| 676 |
+
with Usd.EditContext(stage, layer):
|
| 677 |
+
for prim in stage.Traverse():
|
| 678 |
+
if prim.IsA(UsdGeom.Xform):
|
| 679 |
+
for child in prim.GetChildren():
|
| 680 |
+
if not child.IsA(UsdGeom.Mesh):
|
| 681 |
+
continue
|
| 682 |
+
|
| 683 |
+
# Skip the lightfactory in Infinigen
|
| 684 |
+
if "lightfactory" in prim.GetName().lower():
|
| 685 |
+
continue
|
| 686 |
+
|
| 687 |
+
approx_attr = prim.CreateAttribute(
|
| 688 |
+
"physics:approximation", Sdf.ValueTypeNames.Token
|
| 689 |
+
)
|
| 690 |
+
approx_attr.Set("convexDecomposition")
|
| 691 |
+
|
| 692 |
+
physx_conv_api = PhysxSchema.PhysxConvexDecompositionCollisionAPI.Apply(
|
| 693 |
+
prim
|
| 694 |
+
)
|
| 695 |
+
physx_conv_api.GetMaxConvexHullsAttr().Set(
|
| 696 |
+
self.physx_max_convex_hulls
|
| 697 |
+
)
|
| 698 |
+
physx_conv_api.GetHullVertexLimitAttr().Set(
|
| 699 |
+
self.physx_max_vertices
|
| 700 |
+
)
|
| 701 |
+
physx_conv_api.GetVoxelResolutionAttr().Set(
|
| 702 |
+
self.physx_max_voxel_res
|
| 703 |
+
)
|
| 704 |
+
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
| 705 |
+
|
| 706 |
+
rigid_body_api = UsdPhysics.RigidBodyAPI.Apply(prim)
|
| 707 |
+
rigid_body_api.CreateKinematicEnabledAttr().Set(True)
|
| 708 |
+
if prim.GetAttribute("physics:mass"):
|
| 709 |
+
prim.RemoveProperty("physics:mass")
|
| 710 |
+
if prim.GetAttribute("physics:velocity"):
|
| 711 |
+
prim.RemoveProperty("physics:velocity")
|
| 712 |
+
|
| 713 |
+
api_schemas = prim.GetMetadata("apiSchemas")
|
| 714 |
+
if api_schemas is None:
|
| 715 |
+
api_schemas = Sdf.TokenListOp()
|
| 716 |
+
|
| 717 |
+
api_list = list(api_schemas.GetAddedOrExplicitItems())
|
| 718 |
+
for api in self.DEFAULT_BIND_APIS:
|
| 719 |
+
if api not in api_list:
|
| 720 |
+
api_list.append(api)
|
| 721 |
+
|
| 722 |
+
api_schemas.appendedItems = api_list
|
| 723 |
+
prim.SetMetadata("apiSchemas", api_schemas)
|
| 724 |
+
|
| 725 |
+
layer.Save()
|
| 726 |
+
logger.info(f"Successfully converted {usd_path} to {output_file}")
|
| 727 |
+
|
| 728 |
+
|
| 729 |
class URDFtoUSDConverter(MeshtoUSDConverter):
|
| 730 |
+
"""Converts URDF files to USD format.
|
| 731 |
|
| 732 |
Args:
|
| 733 |
+
fix_base (bool, optional): Fix the base link.
|
| 734 |
+
merge_fixed_joints (bool, optional): Merge fixed joints.
|
| 735 |
+
make_instanceable (bool, optional): Make prims instanceable.
|
| 736 |
+
force_usd_conversion (bool, optional): Force conversion to USD.
|
| 737 |
+
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
| 738 |
+
joint_drive (optional): Joint drive configuration.
|
| 739 |
+
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
| 740 |
+
simulation_app (optional): Simulation app instance.
|
| 741 |
+
**kwargs: Additional arguments.
|
| 742 |
"""
|
| 743 |
|
| 744 |
def __init__(
|
|
|
|
| 753 |
simulation_app=None,
|
| 754 |
**kwargs,
|
| 755 |
):
|
| 756 |
+
"""Initializes the converter.
|
| 757 |
+
|
| 758 |
+
Args:
|
| 759 |
+
fix_base (bool, optional): Fix the base link.
|
| 760 |
+
merge_fixed_joints (bool, optional): Merge fixed joints.
|
| 761 |
+
make_instanceable (bool, optional): Make prims instanceable.
|
| 762 |
+
force_usd_conversion (bool, optional): Force conversion to USD.
|
| 763 |
+
collision_from_visuals (bool, optional): Generate collisions from visuals.
|
| 764 |
+
joint_drive (optional): Joint drive configuration.
|
| 765 |
+
rotate_wxyz (tuple[float], optional): Quaternion for rotation.
|
| 766 |
+
simulation_app (optional): Simulation app instance.
|
| 767 |
+
**kwargs: Additional arguments.
|
| 768 |
+
"""
|
| 769 |
self.usd_parms = dict(
|
| 770 |
fix_base=fix_base,
|
| 771 |
merge_fixed_joints=merge_fixed_joints,
|
|
|
|
| 780 |
self.simulation_app = simulation_app
|
| 781 |
|
| 782 |
def convert(self, urdf_path: str, output_file: str):
|
| 783 |
+
"""Converts a URDF file to USD and post-processes collision meshes.
|
| 784 |
+
|
| 785 |
+
Args:
|
| 786 |
+
urdf_path (str): Path to URDF file.
|
| 787 |
+
output_file (str): Path to output USD file.
|
| 788 |
+
"""
|
| 789 |
from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
|
| 790 |
from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
|
| 791 |
|
|
|
|
| 804 |
with Usd.EditContext(stage, layer):
|
| 805 |
for prim in stage.Traverse():
|
| 806 |
if prim.GetName() == "collisions":
|
| 807 |
+
approx_attr = prim.CreateAttribute(
|
| 808 |
+
"physics:approximation", Sdf.ValueTypeNames.Token
|
| 809 |
+
)
|
|
|
|
|
|
|
| 810 |
approx_attr.Set("convexDecomposition")
|
| 811 |
|
| 812 |
physx_conv_api = (
|
|
|
|
| 814 |
prim
|
| 815 |
)
|
| 816 |
)
|
| 817 |
+
physx_conv_api.GetMaxConvexHullsAttr().Set(32)
|
| 818 |
+
physx_conv_api.GetHullVertexLimitAttr().Set(16)
|
| 819 |
+
physx_conv_api.GetVoxelResolutionAttr().Set(10000)
|
| 820 |
physx_conv_api.GetShrinkWrapAttr().Set(True)
|
| 821 |
|
| 822 |
api_schemas = prim.GetMetadata("apiSchemas")
|
|
|
|
| 847 |
|
| 848 |
|
| 849 |
class AssetConverterFactory:
|
| 850 |
+
"""Factory for creating asset converters based on target and source types.
|
| 851 |
+
|
| 852 |
+
Example:
|
| 853 |
+
```py
|
| 854 |
+
from embodied_gen.data.asset_converter import AssetConverterFactory
|
| 855 |
+
from embodied_gen.utils.enum import AssetType
|
| 856 |
+
|
| 857 |
+
converter = AssetConverterFactory.create(
|
| 858 |
+
target_type=AssetType.USD, source_type=AssetType.MESH
|
| 859 |
+
)
|
| 860 |
+
with converter:
|
| 861 |
+
for urdf_path, output_file in zip(urdf_paths, output_files):
|
| 862 |
+
converter.convert(urdf_path, output_file)
|
| 863 |
+
```
|
| 864 |
+
"""
|
| 865 |
|
| 866 |
@staticmethod
|
| 867 |
def create(
|
| 868 |
target_type: AssetType, source_type: AssetType = "urdf", **kwargs
|
| 869 |
) -> AssetConverterBase:
|
| 870 |
+
"""Creates an asset converter instance.
|
| 871 |
+
|
| 872 |
+
Args:
|
| 873 |
+
target_type (AssetType): Target asset type.
|
| 874 |
+
source_type (AssetType, optional): Source asset type.
|
| 875 |
+
**kwargs: Additional arguments.
|
| 876 |
+
|
| 877 |
+
Returns:
|
| 878 |
+
AssetConverterBase: Converter instance.
|
| 879 |
+
"""
|
| 880 |
+
if target_type == AssetType.MJCF and source_type == AssetType.MESH:
|
| 881 |
converter = MeshtoMJCFConverter(**kwargs)
|
| 882 |
+
elif target_type == AssetType.MJCF and source_type == AssetType.URDF:
|
| 883 |
+
converter = URDFtoMJCFConverter(**kwargs)
|
| 884 |
elif target_type == AssetType.USD and source_type == AssetType.MESH:
|
| 885 |
converter = MeshtoUSDConverter(**kwargs)
|
| 886 |
+
elif target_type == AssetType.USD and source_type == AssetType.URDF:
|
| 887 |
+
converter = URDFtoUSDConverter(**kwargs)
|
| 888 |
else:
|
| 889 |
raise ValueError(
|
| 890 |
f"Unsupported converter type: {source_type} -> {target_type}."
|
|
|
|
| 894 |
|
| 895 |
|
| 896 |
if __name__ == "__main__":
|
| 897 |
+
target_asset_type = AssetType.MJCF
|
| 898 |
# target_asset_type = AssetType.USD
|
| 899 |
|
| 900 |
+
urdf_paths = [
|
| 901 |
+
'outputs/EmbodiedGenData/demo_assets/banana/result/banana.urdf',
|
| 902 |
+
'outputs/EmbodiedGenData/demo_assets/book/result/book.urdf',
|
| 903 |
+
'outputs/EmbodiedGenData/demo_assets/lamp/result/lamp.urdf',
|
| 904 |
+
'outputs/EmbodiedGenData/demo_assets/mug/result/mug.urdf',
|
| 905 |
+
'outputs/EmbodiedGenData/demo_assets/remote_control/result/remote_control.urdf',
|
| 906 |
+
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/result/rubik's_cube.urdf",
|
| 907 |
+
'outputs/EmbodiedGenData/demo_assets/table/result/table.urdf',
|
| 908 |
+
'outputs/EmbodiedGenData/demo_assets/vase/result/vase.urdf',
|
| 909 |
+
]
|
|
|
|
|
|
|
| 910 |
|
| 911 |
+
if target_asset_type == AssetType.MJCF:
|
| 912 |
+
output_files = [
|
| 913 |
+
"outputs/embodiedgen_assets/demo_assets/demo_assets/remote_control/mjcf/remote_control.xml",
|
| 914 |
+
]
|
| 915 |
+
asset_converter = AssetConverterFactory.create(
|
| 916 |
+
target_type=AssetType.MJCF,
|
| 917 |
+
source_type=AssetType.MESH,
|
| 918 |
+
)
|
| 919 |
|
| 920 |
+
elif target_asset_type == AssetType.USD:
|
| 921 |
+
output_files = [
|
| 922 |
+
'outputs/EmbodiedGenData/demo_assets/banana/usd/banana.usd',
|
| 923 |
+
'outputs/EmbodiedGenData/demo_assets/book/usd/book.usd',
|
| 924 |
+
'outputs/EmbodiedGenData/demo_assets/lamp/usd/lamp.usd',
|
| 925 |
+
'outputs/EmbodiedGenData/demo_assets/mug/usd/mug.usd',
|
| 926 |
+
'outputs/EmbodiedGenData/demo_assets/remote_control/usd/remote_control.usd',
|
| 927 |
+
"outputs/EmbodiedGenData/demo_assets/rubik's_cube/usd/rubik's_cube.usd",
|
| 928 |
+
'outputs/EmbodiedGenData/demo_assets/table/usd/table.usd',
|
| 929 |
+
'outputs/EmbodiedGenData/demo_assets/vase/usd/vase.usd',
|
| 930 |
+
]
|
| 931 |
+
asset_converter = AssetConverterFactory.create(
|
| 932 |
+
target_type=AssetType.USD,
|
| 933 |
+
source_type=AssetType.MESH,
|
| 934 |
+
)
|
| 935 |
+
|
| 936 |
+
with asset_converter:
|
| 937 |
+
for urdf_path, output_file in zip(urdf_paths, output_files):
|
| 938 |
+
asset_converter.convert(urdf_path, output_file)
|
| 939 |
|
| 940 |
# urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
|
| 941 |
# output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
|
|
|
|
| 949 |
# with asset_converter:
|
| 950 |
# asset_converter.convert(urdf_path, output_file)
|
| 951 |
|
| 952 |
+
# # Convert infinigen urdf to mjcf
|
| 953 |
+
# urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf"
|
| 954 |
+
# output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml"
|
| 955 |
+
# asset_converter = AssetConverterFactory.create(
|
| 956 |
+
# target_type=AssetType.MJCF,
|
| 957 |
+
# source_type=AssetType.URDF,
|
| 958 |
+
# keep_materials=["diffuse"],
|
| 959 |
+
# )
|
| 960 |
+
# with asset_converter:
|
| 961 |
+
# asset_converter.convert(urdf_path, output_file)
|
| 962 |
+
|
| 963 |
+
# # Convert infinigen usdc to physics usdc
|
| 964 |
+
# converter = PhysicsUSDAdder()
|
| 965 |
+
# with converter:
|
| 966 |
+
# converter.convert(
|
| 967 |
+
# usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc",
|
| 968 |
+
# output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc",
|
| 969 |
+
# )
|
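The `__main__` demo above hardcodes the `output_files` lists for each target format. A small helper can derive them from `urdf_paths` instead; this is a hedged sketch, not part of the commit, and the `mjcf/`-or-`usd/`-next-to-`result/` directory layout is assumed from the demo paths.

```py
# Hedged sketch (not part of the diff): derive one output file per input URDF
# instead of hardcoding the list, mirroring the batch loop shown above.
import os

def derive_output_files(urdf_paths: list[str], fmt: str = "mjcf") -> list[str]:
    """Map .../<asset>/result/<asset>.urdf to .../<asset>/<fmt>/<asset>.<ext>."""
    suffix = ".xml" if fmt == "mjcf" else ".usd"
    outputs = []
    for urdf_path in urdf_paths:
        asset_dir = os.path.dirname(os.path.dirname(urdf_path))  # drop "result/"
        name = os.path.splitext(os.path.basename(urdf_path))[0]
        outputs.append(os.path.join(asset_dir, fmt, name + suffix))
    return outputs
```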
embodied_gen/data/backproject.py
CHANGED
|
@@ -34,6 +34,7 @@ from embodied_gen.data.utils import (
|
|
| 34 |
CameraSetting,
|
| 35 |
get_images_from_grid,
|
| 36 |
init_kal_camera,
|
|
|
|
| 37 |
normalize_vertices_array,
|
| 38 |
post_process_texture,
|
| 39 |
save_mesh_with_mtl,
|
|
@@ -306,28 +307,6 @@ class TextureBaker(object):
|
|
| 306 |
raise ValueError(f"Unknown mode: {mode}")
|
| 307 |
|
| 308 |
|
| 309 |
-
def kaolin_to_opencv_view(raw_matrix):
|
| 310 |
-
R_orig = raw_matrix[:, :3, :3]
|
| 311 |
-
t_orig = raw_matrix[:, :3, 3]
|
| 312 |
-
|
| 313 |
-
R_target = torch.zeros_like(R_orig)
|
| 314 |
-
R_target[:, :, 0] = R_orig[:, :, 2]
|
| 315 |
-
R_target[:, :, 1] = R_orig[:, :, 0]
|
| 316 |
-
R_target[:, :, 2] = R_orig[:, :, 1]
|
| 317 |
-
|
| 318 |
-
t_target = t_orig
|
| 319 |
-
|
| 320 |
-
target_matrix = (
|
| 321 |
-
torch.eye(4, device=raw_matrix.device)
|
| 322 |
-
.unsqueeze(0)
|
| 323 |
-
.repeat(raw_matrix.size(0), 1, 1)
|
| 324 |
-
)
|
| 325 |
-
target_matrix[:, :3, :3] = R_target
|
| 326 |
-
target_matrix[:, :3, 3] = t_target
|
| 327 |
-
|
| 328 |
-
return target_matrix
|
| 329 |
-
|
| 330 |
-
|
| 331 |
def parse_args():
|
| 332 |
parser = argparse.ArgumentParser(description="Render settings")
|
| 333 |
|
|
|
|
| 34 |
CameraSetting,
|
| 35 |
get_images_from_grid,
|
| 36 |
init_kal_camera,
|
| 37 |
+
kaolin_to_opencv_view,
|
| 38 |
normalize_vertices_array,
|
| 39 |
post_process_texture,
|
| 40 |
save_mesh_with_mtl,
|
|
|
|
| 307 |
raise ValueError(f"Unknown mode: {mode}")
|
| 308 |
|
| 309 |
|
| 310 |
def parse_args():
|
| 311 |
parser = argparse.ArgumentParser(description="Render settings")
|
| 312 |
|
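The `kaolin_to_opencv_view` helper removed here is now imported from `embodied_gen.data.utils`. For readers of the diff, the sketch below restates the conversion it performs (rotation columns reordered, translation kept unchanged), matching the deleted body above.

```py
import torch

def kaolin_to_opencv_view(raw_matrix: torch.Tensor) -> torch.Tensor:
    """Reorder rotation columns of (N, 4, 4) Kaolin view matrices to OpenCV order."""
    R_orig = raw_matrix[:, :3, :3]
    R_target = torch.zeros_like(R_orig)
    R_target[:, :, 0] = R_orig[:, :, 2]
    R_target[:, :, 1] = R_orig[:, :, 0]
    R_target[:, :, 2] = R_orig[:, :, 1]
    target = (
        torch.eye(4, device=raw_matrix.device)
        .unsqueeze(0)
        .repeat(raw_matrix.size(0), 1, 1)
    )
    target[:, :3, :3] = R_target
    target[:, :3, 3] = raw_matrix[:, :3, 3]  # translation is kept as-is
    return target
```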
embodied_gen/data/backproject_v2.py
CHANGED
|
@@ -58,7 +58,16 @@ __all__ = [
|
|
| 58 |
def _transform_vertices(
|
| 59 |
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
|
| 60 |
) -> torch.Tensor:
|
| 61 |
-
"""
|
| 62 |
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
|
| 63 |
if pos.size(-1) == 3:
|
| 64 |
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
|
|
@@ -71,7 +80,17 @@ def _transform_vertices(
|
|
| 71 |
def _bilinear_interpolation_scattering(
|
| 72 |
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
|
| 73 |
) -> torch.Tensor:
|
| 74 |
-
"""
|
| 75 |
device = values.device
|
| 76 |
dtype = values.dtype
|
| 77 |
C = values.shape[-1]
|
|
@@ -135,7 +154,18 @@ def _texture_inpaint_smooth(
|
|
| 135 |
faces: np.ndarray,
|
| 136 |
uv_map: np.ndarray,
|
| 137 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 138 |
-
"""
|
| 139 |
image_h, image_w, C = texture.shape
|
| 140 |
N = vertices.shape[0]
|
| 141 |
|
|
@@ -231,29 +261,41 @@ def _texture_inpaint_smooth(
|
|
| 231 |
class TextureBacker:
|
| 232 |
"""Texture baking pipeline for multi-view projection and fusion.
|
| 233 |
|
| 234 |
-
This class
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
back-projection, confidence-weighted texture fusion, and inpainting
|
| 238 |
-
of missing texture regions.
|
| 239 |
|
| 240 |
Args:
|
| 241 |
-
camera_params (CameraSetting): Camera intrinsics and extrinsics
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
| 257 |
"""
|
| 258 |
|
| 259 |
def __init__(
|
|
@@ -283,6 +325,12 @@ class TextureBacker:
|
|
| 283 |
)
|
| 284 |
|
| 285 |
def _lazy_init_render(self, camera_params, mask_thresh):
|
|
| 286 |
if self.renderer is None:
|
| 287 |
camera = init_kal_camera(camera_params)
|
| 288 |
mv = camera.view_matrix() # (n 4 4) world2cam
|
|
@@ -301,6 +349,14 @@ class TextureBacker:
|
|
| 301 |
)
|
| 302 |
|
| 303 |
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
| 304 |
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
| 305 |
self.scale, self.center = scale, center
|
| 306 |
|
|
@@ -318,6 +374,16 @@ class TextureBacker:
|
|
| 318 |
scale: float = None,
|
| 319 |
center: np.ndarray = None,
|
| 320 |
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 321 |
vertices = mesh.vertices.copy()
|
| 322 |
faces = mesh.faces.copy()
|
| 323 |
uv_map = mesh.visual.uv.copy()
|
|
@@ -331,6 +397,14 @@ class TextureBacker:
|
|
| 331 |
return vertices, faces, uv_map
|
| 332 |
|
| 333 |
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
|
| 334 |
depth_image_np = depth_image.cpu().numpy()
|
| 335 |
depth_image_np = (depth_image_np * 255).astype(np.uint8)
|
| 336 |
depth_edges = cv2.Canny(depth_image_np, 30, 80)
|
|
@@ -344,6 +418,16 @@ class TextureBacker:
|
|
| 344 |
def compute_enhanced_viewnormal(
|
| 345 |
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
|
| 346 |
) -> torch.Tensor:
|
| 347 |
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
|
| 348 |
rendered_view_normals = []
|
| 349 |
for idx in range(len(mv_mtx)):
|
|
@@ -376,6 +460,18 @@ class TextureBacker:
|
|
| 376 |
def back_project(
|
| 377 |
self, image, vis_mask, depth, normal, uv
|
| 378 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 379 |
image = np.array(image)
|
| 380 |
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
|
| 381 |
if image.ndim == 2:
|
|
@@ -418,6 +514,17 @@ class TextureBacker:
|
|
| 418 |
)
|
| 419 |
|
| 420 |
def _scatter_texture(self, uv, data, mask):
|
| 421 |
def __filter_data(data, mask):
|
| 422 |
return data.view(-1, data.shape[-1])[mask]
|
| 423 |
|
|
@@ -432,6 +539,15 @@ class TextureBacker:
|
|
| 432 |
def fast_bake_texture(
|
| 433 |
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
|
| 434 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 435 |
channel = textures[0].shape[-1]
|
| 436 |
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
|
| 437 |
self.device
|
|
@@ -451,6 +567,16 @@ class TextureBacker:
|
|
| 451 |
def uv_inpaint(
|
| 452 |
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
| 453 |
) -> np.ndarray:
|
| 454 |
if self.inpaint_smooth:
|
| 455 |
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
| 456 |
texture, mask = _texture_inpaint_smooth(
|
|
@@ -473,6 +599,15 @@ class TextureBacker:
|
|
| 473 |
colors: list[Image.Image],
|
| 474 |
mesh: trimesh.Trimesh,
|
| 475 |
) -> trimesh.Trimesh:
|
| 476 |
self._lazy_init_render(self.camera_params, self.mask_thresh)
|
| 477 |
|
| 478 |
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
|
|
@@ -517,7 +652,7 @@ class TextureBacker:
|
|
| 517 |
Args:
|
| 518 |
colors (list[Image.Image]): List of input view images.
|
| 519 |
mesh (trimesh.Trimesh): Input mesh to be textured.
|
| 520 |
-
output_path (str): Path to save the output textured mesh
|
| 521 |
|
| 522 |
Returns:
|
| 523 |
trimesh.Trimesh: The textured mesh with UV and texture image.
|
|
@@ -540,6 +675,11 @@ class TextureBacker:
|
|
| 540 |
|
| 541 |
|
| 542 |
def parse_args():
|
| 543 |
parser = argparse.ArgumentParser(description="Backproject texture")
|
| 544 |
parser.add_argument(
|
| 545 |
"--color_path",
|
|
@@ -636,6 +776,16 @@ def entrypoint(
|
|
| 636 |
imagesr_model: ImageRealESRGAN = None,
|
| 637 |
**kwargs,
|
| 638 |
) -> trimesh.Trimesh:
|
| 639 |
args = parse_args()
|
| 640 |
for k, v in kwargs.items():
|
| 641 |
if hasattr(args, k) and v is not None:
|
|
|
|
| 58 |
def _transform_vertices(
|
| 59 |
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
|
| 60 |
) -> torch.Tensor:
|
| 61 |
+
"""Transforms 3D vertices using a projection matrix.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
mtx (torch.Tensor): Projection matrix.
|
| 65 |
+
pos (torch.Tensor): Vertex positions.
|
| 66 |
+
keepdim (bool, optional): If True, keeps the batch dimension.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
torch.Tensor: Transformed vertices.
|
| 70 |
+
"""
|
| 71 |
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
|
| 72 |
if pos.size(-1) == 3:
|
| 73 |
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
|
|
|
|
| 80 |
def _bilinear_interpolation_scattering(
|
| 81 |
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
|
| 82 |
) -> torch.Tensor:
|
| 83 |
+
"""Performs bilinear interpolation scattering for grid-based value accumulation.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
image_h (int): Image height.
|
| 87 |
+
image_w (int): Image width.
|
| 88 |
+
coords (torch.Tensor): Normalized coordinates.
|
| 89 |
+
values (torch.Tensor): Values to scatter.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
torch.Tensor: Interpolated grid.
|
| 93 |
+
"""
|
| 94 |
device = values.device
|
| 95 |
dtype = values.dtype
|
| 96 |
C = values.shape[-1]
|
|
|
|
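To make the scattering step concrete, here is a standalone sketch of bilinear splatting into an `(H, W, C)` grid. It assumes pixel-space coordinates and simple weight normalization; it is an illustration of the technique, not the repository's `_bilinear_interpolation_scattering`.

```py
import torch

def bilinear_scatter(image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Splat (N, C) values at fractional pixel coords (x, y) into an (H, W, C) grid."""
    c = values.shape[-1]
    grid = torch.zeros(image_h * image_w, c, device=values.device, dtype=values.dtype)
    weight = torch.zeros(image_h * image_w, 1, device=values.device, dtype=values.dtype)
    x0 = coords[:, 0].floor().long()
    y0 = coords[:, 1].floor().long()
    fx = coords[:, 0] - x0
    fy = coords[:, 1] - y0
    # Distribute each value over its four neighboring texels with bilinear weights.
    for dx, dy in ((0, 0), (1, 0), (0, 1), (1, 1)):
        w = ((1 - fx) if dx == 0 else fx) * ((1 - fy) if dy == 0 else fy)
        xi = (x0 + dx).clamp(0, image_w - 1)
        yi = (y0 + dy).clamp(0, image_h - 1)
        idx = yi * image_w + xi
        grid.index_add_(0, idx, values * w[:, None])
        weight.index_add_(0, idx, w[:, None])
    return (grid / weight.clamp(min=1e-8)).view(image_h, image_w, c)
```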
| 154 |
faces: np.ndarray,
|
| 155 |
uv_map: np.ndarray,
|
| 156 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 157 |
+
"""Performs texture inpainting using vertex-based color propagation.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
texture (np.ndarray): Texture image.
|
| 161 |
+
mask (np.ndarray): Mask image.
|
| 162 |
+
vertices (np.ndarray): Mesh vertices.
|
| 163 |
+
faces (np.ndarray): Mesh faces.
|
| 164 |
+
uv_map (np.ndarray): UV coordinates.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
tuple[np.ndarray, np.ndarray]: Inpainted texture and updated mask.
|
| 168 |
+
"""
|
| 169 |
image_h, image_w, C = texture.shape
|
| 170 |
N = vertices.shape[0]
|
| 171 |
|
|
|
|
| 261 |
class TextureBacker:
|
| 262 |
"""Texture baking pipeline for multi-view projection and fusion.
|
| 263 |
|
| 264 |
+
This class generates UV-based textures for a 3D mesh using multi-view images,
|
| 265 |
+
depth, and normal information. It includes mesh normalization, UV unwrapping,
|
| 266 |
+
visibility-aware back-projection, confidence-weighted fusion, and inpainting.
|
|
|
|
|
|
|
| 267 |
|
| 268 |
Args:
|
| 269 |
+
camera_params (CameraSetting): Camera intrinsics and extrinsics.
|
| 270 |
+
view_weights (list[float]): Weights for each view in texture fusion.
|
| 271 |
+
render_wh (tuple[int, int], optional): Intermediate rendering resolution.
|
| 272 |
+
texture_wh (tuple[int, int], optional): Output texture resolution.
|
| 273 |
+
bake_angle_thresh (int, optional): Max angle for valid projection.
|
| 274 |
+
mask_thresh (float, optional): Threshold for visibility masks.
|
| 275 |
+
smooth_texture (bool, optional): Apply post-processing to texture.
|
| 276 |
+
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
| 277 |
+
|
| 278 |
+
Example:
|
| 279 |
+
```py
|
| 280 |
+
from embodied_gen.data.backproject_v2 import TextureBacker
|
| 281 |
+
from embodied_gen.data.utils import CameraSetting
|
| 282 |
+
import trimesh
|
| 283 |
+
from PIL import Image
|
| 284 |
+
|
| 285 |
+
camera_params = CameraSetting(
|
| 286 |
+
num_images=6,
|
| 287 |
+
elevation=[20, -10],
|
| 288 |
+
distance=5,
|
| 289 |
+
resolution_hw=(2048,2048),
|
| 290 |
+
fov=math.radians(30),
|
| 291 |
+
device='cuda',
|
| 292 |
+
)
|
| 293 |
+
view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
|
| 294 |
+
mesh = trimesh.load('mesh.obj')
|
| 295 |
+
images = [Image.open(f'view_{i}.png') for i in range(6)]
|
| 296 |
+
texture_backer = TextureBacker(camera_params, view_weights)
|
| 297 |
+
textured_mesh = texture_backer(images, mesh, 'output.obj')
|
| 298 |
+
```
|
| 299 |
"""
|
| 300 |
|
| 301 |
def __init__(
|
|
|
|
| 325 |
)
|
| 326 |
|
| 327 |
def _lazy_init_render(self, camera_params, mask_thresh):
|
| 328 |
+
"""Lazily initializes the renderer.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
camera_params (CameraSetting): Camera settings.
|
| 332 |
+
mask_thresh (float): Mask threshold.
|
| 333 |
+
"""
|
| 334 |
if self.renderer is None:
|
| 335 |
camera = init_kal_camera(camera_params)
|
| 336 |
mv = camera.view_matrix() # (n 4 4) world2cam
|
|
|
|
| 349 |
)
|
| 350 |
|
| 351 |
def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
|
| 352 |
+
"""Normalizes mesh and unwraps UVs.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
mesh (trimesh.Trimesh): Input mesh.
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
trimesh.Trimesh: Mesh with normalized vertices and UVs.
|
| 359 |
+
"""
|
| 360 |
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
| 361 |
self.scale, self.center = scale, center
|
| 362 |
|
|
|
|
| 374 |
scale: float = None,
|
| 375 |
center: np.ndarray = None,
|
| 376 |
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 377 |
+
"""Gets mesh attributes as numpy arrays.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
mesh (trimesh.Trimesh): Input mesh.
|
| 381 |
+
scale (float, optional): Scale factor.
|
| 382 |
+
center (np.ndarray, optional): Center offset.
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
tuple: (vertices, faces, uv_map)
|
| 386 |
+
"""
|
| 387 |
vertices = mesh.vertices.copy()
|
| 388 |
faces = mesh.faces.copy()
|
| 389 |
uv_map = mesh.visual.uv.copy()
|
|
|
|
| 397 |
return vertices, faces, uv_map
|
| 398 |
|
| 399 |
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
|
| 400 |
+
"""Computes edge image from depth map.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
depth_image (torch.Tensor): Depth map.
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
torch.Tensor: Edge image.
|
| 407 |
+
"""
|
| 408 |
depth_image_np = depth_image.cpu().numpy()
|
| 409 |
depth_image_np = (depth_image_np * 255).astype(np.uint8)
|
| 410 |
depth_edges = cv2.Canny(depth_image_np, 30, 80)
|
|
|
|
| 418 |
def compute_enhanced_viewnormal(
|
| 419 |
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
|
| 420 |
) -> torch.Tensor:
|
| 421 |
+
"""Computes enhanced view normals for mesh faces.
|
| 422 |
+
|
| 423 |
+
Args:
|
| 424 |
+
mv_mtx (torch.Tensor): View matrices.
|
| 425 |
+
vertices (torch.Tensor): Mesh vertices.
|
| 426 |
+
faces (torch.Tensor): Mesh faces.
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
torch.Tensor: View normals.
|
| 430 |
+
"""
|
| 431 |
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
|
| 432 |
rendered_view_normals = []
|
| 433 |
for idx in range(len(mv_mtx)):
|
|
|
|
| 460 |
def back_project(
|
| 461 |
self, image, vis_mask, depth, normal, uv
|
| 462 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 463 |
+
"""Back-projects image and confidence to UV texture space.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
image (PIL.Image or np.ndarray): Input image.
|
| 467 |
+
vis_mask (torch.Tensor): Visibility mask.
|
| 468 |
+
depth (torch.Tensor): Depth map.
|
| 469 |
+
normal (torch.Tensor): Normal map.
|
| 470 |
+
uv (torch.Tensor): UV coordinates.
|
| 471 |
+
|
| 472 |
+
Returns:
|
| 473 |
+
tuple[torch.Tensor, torch.Tensor]: Texture and confidence map.
|
| 474 |
+
"""
|
| 475 |
image = np.array(image)
|
| 476 |
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
|
| 477 |
if image.ndim == 2:
|
|
|
|
| 514 |
)
|
| 515 |
|
| 516 |
def _scatter_texture(self, uv, data, mask):
|
| 517 |
+
"""Scatters data to texture using UV coordinates and mask.
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
uv (torch.Tensor): UV coordinates.
|
| 521 |
+
data (torch.Tensor): Data to scatter.
|
| 522 |
+
mask (torch.Tensor): Mask for valid pixels.
|
| 523 |
+
|
| 524 |
+
Returns:
|
| 525 |
+
torch.Tensor: Scattered texture.
|
| 526 |
+
"""
|
| 527 |
+
|
| 528 |
def __filter_data(data, mask):
|
| 529 |
return data.view(-1, data.shape[-1])[mask]
|
| 530 |
|
|
|
|
| 539 |
def fast_bake_texture(
|
| 540 |
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
|
| 541 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 542 |
+
"""Fuses multiple textures and confidence maps.
|
| 543 |
+
|
| 544 |
+
Args:
|
| 545 |
+
textures (list[torch.Tensor]): List of textures.
|
| 546 |
+
confidence_maps (list[torch.Tensor]): List of confidence maps.
|
| 547 |
+
|
| 548 |
+
Returns:
|
| 549 |
+
tuple[torch.Tensor, torch.Tensor]: Fused texture and mask.
|
| 550 |
+
"""
|
| 551 |
channel = textures[0].shape[-1]
|
| 552 |
texture_merge = torch.zeros(self.texture_wh + [channel]).to(
|
| 553 |
self.device
|
|
|
|
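The fusion documented in `fast_bake_texture` amounts to a confidence-weighted average over the per-view textures. A minimal, self-contained sketch of that idea (shapes assumed broadcastable, e.g. `(H, W, C)` textures with `(H, W, 1)` confidence maps):

```py
import torch

def fuse_textures(textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]):
    # Accumulate each view's texture scaled by its confidence, then normalize.
    weighted = torch.stack([t * c for t, c in zip(textures, confidence_maps)]).sum(dim=0)
    total = torch.stack(confidence_maps).sum(dim=0)
    fused = weighted / total.clamp(min=1e-8)  # confidence-weighted average
    valid = total > 0  # texels covered by at least one view
    return fused, valid
```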
| 567 |
def uv_inpaint(
|
| 568 |
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
| 569 |
) -> np.ndarray:
|
| 570 |
+
"""Inpaints missing regions in the UV texture.
|
| 571 |
+
|
| 572 |
+
Args:
|
| 573 |
+
mesh (trimesh.Trimesh): Mesh.
|
| 574 |
+
texture (np.ndarray): Texture image.
|
| 575 |
+
mask (np.ndarray): Mask image.
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
np.ndarray: Inpainted texture.
|
| 579 |
+
"""
|
| 580 |
if self.inpaint_smooth:
|
| 581 |
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
| 582 |
texture, mask = _texture_inpaint_smooth(
|
|
|
|
| 599 |
colors: list[Image.Image],
|
| 600 |
mesh: trimesh.Trimesh,
|
| 601 |
) -> trimesh.Trimesh:
|
| 602 |
+
"""Computes the fused texture for the mesh from multi-view images.
|
| 603 |
+
|
| 604 |
+
Args:
|
| 605 |
+
colors (list[Image.Image]): List of view images.
|
| 606 |
+
mesh (trimesh.Trimesh): Mesh to texture.
|
| 607 |
+
|
| 608 |
+
Returns:
|
| 609 |
+
tuple[np.ndarray, np.ndarray]: Texture and mask.
|
| 610 |
+
"""
|
| 611 |
self._lazy_init_render(self.camera_params, self.mask_thresh)
|
| 612 |
|
| 613 |
vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
|
|
|
|
| 652 |
Args:
|
| 653 |
colors (list[Image.Image]): List of input view images.
|
| 654 |
mesh (trimesh.Trimesh): Input mesh to be textured.
|
| 655 |
+
output_path (str): Path to save the output textured mesh.
|
| 656 |
|
| 657 |
Returns:
|
| 658 |
trimesh.Trimesh: The textured mesh with UV and texture image.
|
|
|
|
| 675 |
|
| 676 |
|
| 677 |
def parse_args():
|
| 678 |
+
"""Parses command-line arguments for texture backprojection.
|
| 679 |
+
|
| 680 |
+
Returns:
|
| 681 |
+
argparse.Namespace: Parsed arguments.
|
| 682 |
+
"""
|
| 683 |
parser = argparse.ArgumentParser(description="Backproject texture")
|
| 684 |
parser.add_argument(
|
| 685 |
"--color_path",
|
|
|
|
| 776 |
imagesr_model: ImageRealESRGAN = None,
|
| 777 |
**kwargs,
|
| 778 |
) -> trimesh.Trimesh:
|
| 779 |
+
"""Entrypoint for texture backprojection from multi-view images.
|
| 780 |
+
|
| 781 |
+
Args:
|
| 782 |
+
delight_model (DelightingModel, optional): Delighting model.
|
| 783 |
+
imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
|
| 784 |
+
**kwargs: Additional arguments to override CLI.
|
| 785 |
+
|
| 786 |
+
Returns:
|
| 787 |
+
trimesh.Trimesh: Textured mesh.
|
| 788 |
+
"""
|
| 789 |
args = parse_args()
|
| 790 |
for k, v in kwargs.items():
|
| 791 |
if hasattr(args, k) and v is not None:
|
embodied_gen/data/backproject_v3.py
ADDED
|
@@ -0,0 +1,557 @@
| 1 |
+
# Project EmbodiedGen
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 14 |
+
# implied. See the License for the specific language governing
|
| 15 |
+
# permissions and limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import logging
|
| 20 |
+
import math
import os  # used by entrypoint() for os.makedirs when saving the optional GLB
|
| 21 |
+
from typing import Literal, Union
|
| 22 |
+
|
| 23 |
+
import cv2
|
| 24 |
+
import numpy as np
|
| 25 |
+
import nvdiffrast.torch as dr
|
| 26 |
+
import spaces
|
| 27 |
+
import torch
|
| 28 |
+
import trimesh
|
| 29 |
+
import utils3d
|
| 30 |
+
import xatlas
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
from embodied_gen.data.mesh_operator import MeshFixer
|
| 34 |
+
from embodied_gen.data.utils import (
|
| 35 |
+
CameraSetting,
|
| 36 |
+
init_kal_camera,
|
| 37 |
+
kaolin_to_opencv_view,
|
| 38 |
+
normalize_vertices_array,
|
| 39 |
+
post_process_texture,
|
| 40 |
+
save_mesh_with_mtl,
|
| 41 |
+
)
|
| 42 |
+
from embodied_gen.models.delight_model import DelightingModel
|
| 43 |
+
from embodied_gen.models.gs_model import load_gs_model
|
| 44 |
+
from embodied_gen.models.sr_model import ImageRealESRGAN
|
| 45 |
+
|
| 46 |
+
logging.basicConfig(
|
| 47 |
+
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
| 48 |
+
)
|
| 49 |
+
logger = logging.getLogger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
__all__ = [
|
| 53 |
+
"TextureBaker",
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TextureBaker(object):
|
| 58 |
+
"""Baking textures onto a mesh from multiple observations.
|
| 59 |
+
|
| 60 |
+
This class takes 3D mesh data, camera settings, and texture baking parameters
|
| 61 |
+
to generate a texture map by projecting images onto the mesh from different views.
|
| 62 |
+
It supports both a fast texture baking approach and a more optimized method
|
| 63 |
+
with total variation regularization.
|
| 64 |
+
|
| 65 |
+
Attributes:
|
| 66 |
+
vertices (torch.Tensor): The vertices of the mesh.
|
| 67 |
+
faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
|
| 68 |
+
uvs (torch.Tensor): The UV coordinates of the mesh.
|
| 69 |
+
camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
|
| 70 |
+
device (str): The device to run computations on ("cpu" or "cuda").
|
| 71 |
+
w2cs (torch.Tensor): World-to-camera transformation matrices.
|
| 72 |
+
projections (torch.Tensor): Camera projection matrices.
|
| 73 |
+
|
| 74 |
+
Example:
|
| 75 |
+
>>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa
|
| 76 |
+
>>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
|
| 77 |
+
>>> images = get_images_from_grid(args.color_path, image_size)
|
| 78 |
+
>>> texture = texture_backer.bake_texture(
|
| 79 |
+
... images, texture_size=args.texture_size, mode=args.baker_mode
|
| 80 |
+
... )
|
| 81 |
+
>>> texture = post_process_texture(texture)
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
vertices: np.ndarray,
|
| 87 |
+
faces: np.ndarray,
|
| 88 |
+
uvs: np.ndarray,
|
| 89 |
+
camera_params: CameraSetting,
|
| 90 |
+
device: str = "cuda",
|
| 91 |
+
) -> None:
|
| 92 |
+
self.vertices = (
|
| 93 |
+
torch.tensor(vertices, device=device)
|
| 94 |
+
if isinstance(vertices, np.ndarray)
|
| 95 |
+
else vertices.to(device)
|
| 96 |
+
)
|
| 97 |
+
self.faces = (
|
| 98 |
+
torch.tensor(faces.astype(np.int32), device=device)
|
| 99 |
+
if isinstance(faces, np.ndarray)
|
| 100 |
+
else faces.to(device)
|
| 101 |
+
)
|
| 102 |
+
self.uvs = (
|
| 103 |
+
torch.tensor(uvs, device=device)
|
| 104 |
+
if isinstance(uvs, np.ndarray)
|
| 105 |
+
else uvs.to(device)
|
| 106 |
+
)
|
| 107 |
+
self.camera_params = camera_params
|
| 108 |
+
self.device = device
|
| 109 |
+
|
| 110 |
+
camera = init_kal_camera(camera_params)
|
| 111 |
+
matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
|
| 112 |
+
matrix_mv = kaolin_to_opencv_view(matrix_mv)
|
| 113 |
+
matrix_p = (
|
| 114 |
+
camera.intrinsics.projection_matrix()
|
| 115 |
+
) # (n_cam 4 4) cam2pixel
|
| 116 |
+
self.w2cs = matrix_mv.to(self.device)
|
| 117 |
+
self.projections = matrix_p.to(self.device)
|
| 118 |
+
|
| 119 |
+
@staticmethod
|
| 120 |
+
def parametrize_mesh(
|
| 121 |
+
vertices: np.array, faces: np.array
|
| 122 |
+
) -> Union[np.array, np.array, np.array]:
|
| 123 |
+
vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
|
| 124 |
+
|
| 125 |
+
vertices = vertices[vmapping]
|
| 126 |
+
faces = indices
|
| 127 |
+
|
| 128 |
+
return vertices, faces, uvs
|
| 129 |
+
|
| 130 |
+
def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
|
| 131 |
+
texture = torch.zeros(
|
| 132 |
+
(texture_size * texture_size, 3), dtype=torch.float32
|
| 133 |
+
).cuda()
|
| 134 |
+
texture_weights = torch.zeros(
|
| 135 |
+
(texture_size * texture_size), dtype=torch.float32
|
| 136 |
+
).cuda()
|
| 137 |
+
rastctx = utils3d.torch.RastContext(backend="cuda")
|
| 138 |
+
for observation, w2c, projection in tqdm(
|
| 139 |
+
zip(observations, w2cs, projections),
|
| 140 |
+
total=len(observations),
|
| 141 |
+
desc="Texture baking (fast)",
|
| 142 |
+
):
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
rast = utils3d.torch.rasterize_triangle_faces(
|
| 145 |
+
rastctx,
|
| 146 |
+
self.vertices[None],
|
| 147 |
+
self.faces,
|
| 148 |
+
observation.shape[1],
|
| 149 |
+
observation.shape[0],
|
| 150 |
+
uv=self.uvs[None],
|
| 151 |
+
view=w2c,
|
| 152 |
+
projection=projection,
|
| 153 |
+
)
|
| 154 |
+
uv_map = rast["uv"][0].detach().flip(0)
|
| 155 |
+
mask = rast["mask"][0].detach().bool() & masks[0]
|
| 156 |
+
|
| 157 |
+
# nearest neighbor interpolation
|
| 158 |
+
uv_map = (uv_map * texture_size).floor().long()
|
| 159 |
+
obs = observation[mask]
|
| 160 |
+
uv_map = uv_map[mask]
|
| 161 |
+
idx = (
|
| 162 |
+
uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
|
| 163 |
+
)
|
| 164 |
+
texture = texture.scatter_add(
|
| 165 |
+
0, idx.view(-1, 1).expand(-1, 3), obs
|
| 166 |
+
)
|
| 167 |
+
texture_weights = texture_weights.scatter_add(
|
| 168 |
+
0,
|
| 169 |
+
idx,
|
| 170 |
+
torch.ones(
|
| 171 |
+
(obs.shape[0]), dtype=torch.float32, device=texture.device
|
| 172 |
+
),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
mask = texture_weights > 0
|
| 176 |
+
texture[mask] /= texture_weights[mask][:, None]
|
| 177 |
+
texture = np.clip(
|
| 178 |
+
texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
|
| 179 |
+
0,
|
| 180 |
+
255,
|
| 181 |
+
).astype(np.uint8)
|
| 182 |
+
|
| 183 |
+
# inpaint
|
| 184 |
+
mask = (
|
| 185 |
+
(texture_weights == 0)
|
| 186 |
+
.cpu()
|
| 187 |
+
.numpy()
|
| 188 |
+
.astype(np.uint8)
|
| 189 |
+
.reshape(texture_size, texture_size)
|
| 190 |
+
)
|
| 191 |
+
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
|
| 192 |
+
|
| 193 |
+
return texture
|
| 194 |
+
|
| 195 |
+
def _bake_opt(
|
| 196 |
+
self,
|
| 197 |
+
observations,
|
| 198 |
+
w2cs,
|
| 199 |
+
projections,
|
| 200 |
+
texture_size,
|
| 201 |
+
lambda_tv,
|
| 202 |
+
masks,
|
| 203 |
+
total_steps,
|
| 204 |
+
):
|
| 205 |
+
rastctx = utils3d.torch.RastContext(backend="cuda")
|
| 206 |
+
observations = [observations.flip(0) for observations in observations]
|
| 207 |
+
masks = [m.flip(0) for m in masks]
|
| 208 |
+
_uv = []
|
| 209 |
+
_uv_dr = []
|
| 210 |
+
for observation, w2c, projection in tqdm(
|
| 211 |
+
zip(observations, w2cs, projections),
|
| 212 |
+
total=len(w2cs),
|
| 213 |
+
):
|
| 214 |
+
with torch.no_grad():
|
| 215 |
+
rast = utils3d.torch.rasterize_triangle_faces(
|
| 216 |
+
rastctx,
|
| 217 |
+
self.vertices[None],
|
| 218 |
+
self.faces,
|
| 219 |
+
observation.shape[1],
|
| 220 |
+
observation.shape[0],
|
| 221 |
+
uv=self.uvs[None],
|
| 222 |
+
view=w2c,
|
| 223 |
+
projection=projection,
|
| 224 |
+
)
|
| 225 |
+
_uv.append(rast["uv"].detach())
|
| 226 |
+
_uv_dr.append(rast["uv_dr"].detach())
|
| 227 |
+
|
| 228 |
+
texture = torch.nn.Parameter(
|
| 229 |
+
torch.zeros(
|
| 230 |
+
(1, texture_size, texture_size, 3), dtype=torch.float32
|
| 231 |
+
).cuda()
|
| 232 |
+
)
|
| 233 |
+
optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
|
| 234 |
+
|
| 235 |
+
def cosine_anealing(step, total_steps, start_lr, end_lr):
|
| 236 |
+
return end_lr + 0.5 * (start_lr - end_lr) * (
|
| 237 |
+
1 + np.cos(np.pi * step / total_steps)
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
def tv_loss(texture):
|
| 241 |
+
return torch.nn.functional.l1_loss(
|
| 242 |
+
texture[:, :-1, :, :], texture[:, 1:, :, :]
|
| 243 |
+
) + torch.nn.functional.l1_loss(
|
| 244 |
+
texture[:, :, :-1, :], texture[:, :, 1:, :]
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
with tqdm(total=total_steps, desc="Texture baking") as pbar:
|
| 248 |
+
for step in range(total_steps):
|
| 249 |
+
optimizer.zero_grad()
|
| 250 |
+
selected = np.random.randint(0, len(w2cs))
|
| 251 |
+
uv, uv_dr, observation, mask = (
|
| 252 |
+
_uv[selected],
|
| 253 |
+
_uv_dr[selected],
|
| 254 |
+
observations[selected],
|
| 255 |
+
masks[selected],
|
| 256 |
+
)
|
| 257 |
+
render = dr.texture(texture, uv, uv_dr)[0]
|
| 258 |
+
loss = torch.nn.functional.l1_loss(
|
| 259 |
+
render[mask], observation[mask]
|
| 260 |
+
)
|
| 261 |
+
if lambda_tv > 0:
|
| 262 |
+
loss += lambda_tv * tv_loss(texture)
|
| 263 |
+
loss.backward()
|
| 264 |
+
optimizer.step()
|
| 265 |
+
|
| 266 |
+
optimizer.param_groups[0]["lr"] = cosine_anealing(
|
| 267 |
+
step, total_steps, 1e-2, 1e-5
|
| 268 |
+
)
|
| 269 |
+
pbar.set_postfix({"loss": loss.item()})
|
| 270 |
+
pbar.update()
|
| 271 |
+
|
| 272 |
+
texture = np.clip(
|
| 273 |
+
texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
|
| 274 |
+
).astype(np.uint8)
|
| 275 |
+
mask = 1 - utils3d.torch.rasterize_triangle_faces(
|
| 276 |
+
rastctx,
|
| 277 |
+
(self.uvs * 2 - 1)[None],
|
| 278 |
+
self.faces,
|
| 279 |
+
texture_size,
|
| 280 |
+
texture_size,
|
| 281 |
+
)["mask"][0].detach().cpu().numpy().astype(np.uint8)
|
| 282 |
+
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
|
| 283 |
+
|
| 284 |
+
return texture
|
| 285 |
+
|
| 286 |
+
def bake_texture(
|
| 287 |
+
self,
|
| 288 |
+
images: list[np.array],
|
| 289 |
+
texture_size: int = 1024,
|
| 290 |
+
mode: Literal["fast", "opt"] = "opt",
|
| 291 |
+
lambda_tv: float = 1e-2,
|
| 292 |
+
opt_step: int = 2000,
|
| 293 |
+
):
|
| 294 |
+
masks = [np.any(img > 0, axis=-1) for img in images]
|
| 295 |
+
masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
|
| 296 |
+
images = [
|
| 297 |
+
torch.tensor(obs / 255.0).float().to(self.device) for obs in images
|
| 298 |
+
]
|
| 299 |
+
|
| 300 |
+
if mode == "fast":
|
| 301 |
+
return self._bake_fast(
|
| 302 |
+
images, self.w2cs, self.projections, texture_size, masks
|
| 303 |
+
)
|
| 304 |
+
elif mode == "opt":
|
| 305 |
+
return self._bake_opt(
|
| 306 |
+
images,
|
| 307 |
+
self.w2cs,
|
| 308 |
+
self.projections,
|
| 309 |
+
texture_size,
|
| 310 |
+
lambda_tv,
|
| 311 |
+
masks,
|
| 312 |
+
opt_step,
|
| 313 |
+
)
|
| 314 |
+
else:
|
| 315 |
+
raise ValueError(f"Unknown mode: {mode}")
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def parse_args():
|
| 319 |
+
"""Parses command-line arguments for texture backprojection.
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
argparse.Namespace: Parsed arguments.
|
| 323 |
+
"""
|
| 324 |
+
parser = argparse.ArgumentParser(description="Backproject texture")
|
| 325 |
+
parser.add_argument(
|
| 326 |
+
"--gs_path",
|
| 327 |
+
type=str,
|
| 328 |
+
help="Path to the GS.ply gaussian splatting model",
|
| 329 |
+
)
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--mesh_path",
|
| 332 |
+
type=str,
|
| 333 |
+
help="Mesh path, .obj, .glb or .ply",
|
| 334 |
+
)
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
"--output_path",
|
| 337 |
+
type=str,
|
| 338 |
+
help="Output mesh path with suffix",
|
| 339 |
+
)
|
| 340 |
+
parser.add_argument(
|
| 341 |
+
"--num_images",
|
| 342 |
+
type=int,
|
| 343 |
+
default=180,
|
| 344 |
+
help="Number of images to render.",
|
| 345 |
+
)
|
| 346 |
+
parser.add_argument(
|
| 347 |
+
"--elevation",
|
| 348 |
+
nargs="+",
|
| 349 |
+
type=float,
|
| 350 |
+
default=list(range(85, -90, -10)),
|
| 351 |
+
help="Elevation angles for the camera",
|
| 352 |
+
)
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--distance",
|
| 355 |
+
type=float,
|
| 356 |
+
default=5,
|
| 357 |
+
help="Camera distance (default: 5)",
|
| 358 |
+
)
|
| 359 |
+
parser.add_argument(
|
| 360 |
+
"--resolution_hw",
|
| 361 |
+
type=int,
|
| 362 |
+
nargs=2,
|
| 363 |
+
default=(512, 512),
|
| 364 |
+
help="Resolution of the render images (default: (512, 512))",
|
| 365 |
+
)
|
| 366 |
+
parser.add_argument(
|
| 367 |
+
"--fov",
|
| 368 |
+
type=float,
|
| 369 |
+
default=30,
|
| 370 |
+
help="Field of view in degrees (default: 30)",
|
| 371 |
+
)
|
| 372 |
+
parser.add_argument(
|
| 373 |
+
"--device",
|
| 374 |
+
type=str,
|
| 375 |
+
choices=["cpu", "cuda"],
|
| 376 |
+
default="cuda",
|
| 377 |
+
help="Device to run on (default: `cuda`)",
|
| 378 |
+
)
|
| 379 |
+
parser.add_argument(
|
| 380 |
+
"--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
|
| 381 |
+
)
|
| 382 |
+
parser.add_argument(
|
| 383 |
+
"--texture_size",
|
| 384 |
+
type=int,
|
| 385 |
+
default=2048,
|
| 386 |
+
help="Texture size for texture baking (default: 1024)",
|
| 387 |
+
)
|
| 388 |
+
parser.add_argument(
|
| 389 |
+
"--baker_mode",
|
| 390 |
+
type=str,
|
| 391 |
+
default="opt",
|
| 392 |
+
help="Texture baking mode, `fast` or `opt` (default: opt)",
|
| 393 |
+
)
|
| 394 |
+
parser.add_argument(
|
| 395 |
+
"--opt_step",
|
| 396 |
+
type=int,
|
| 397 |
+
default=3000,
|
| 398 |
+
help="Optimization steps for texture baking (default: 3000)",
|
| 399 |
+
)
|
| 400 |
+
parser.add_argument(
|
| 401 |
+
"--mesh_sipmlify_ratio",
|
| 402 |
+
type=float,
|
| 403 |
+
default=0.9,
|
| 404 |
+
help="Mesh simplification ratio (default: 0.9)",
|
| 405 |
+
)
|
| 406 |
+
parser.add_argument(
|
| 407 |
+
"--delight", action="store_true", help="Use delighting model."
|
| 408 |
+
)
|
| 409 |
+
parser.add_argument(
|
| 410 |
+
"--no_smooth_texture",
|
| 411 |
+
action="store_true",
|
| 412 |
+
help="Do not smooth the texture.",
|
| 413 |
+
)
|
| 414 |
+
parser.add_argument(
|
| 415 |
+
"--no_coor_trans",
|
| 416 |
+
action="store_true",
|
| 417 |
+
help="Do not transform the asset coordinate system.",
|
| 418 |
+
)
|
| 419 |
+
parser.add_argument(
|
| 420 |
+
"--save_glb_path", type=str, default=None, help="Save glb path."
|
| 421 |
+
)
|
| 422 |
+
parser.add_argument("--n_max_faces", type=int, default=30000)
|
| 423 |
+
args, unknown = parser.parse_known_args()
|
| 424 |
+
|
| 425 |
+
return args
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def entrypoint(
|
| 429 |
+
delight_model: DelightingModel = None,
|
| 430 |
+
imagesr_model: ImageRealESRGAN = None,
|
| 431 |
+
**kwargs,
|
| 432 |
+
) -> trimesh.Trimesh:
|
| 433 |
+
"""Entrypoint for texture backprojection from multi-view images.
|
| 434 |
+
|
| 435 |
+
Args:
|
| 436 |
+
delight_model (DelightingModel, optional): Delighting model.
|
| 437 |
+
imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
|
| 438 |
+
**kwargs: Additional arguments to override CLI.
|
| 439 |
+
|
| 440 |
+
Returns:
|
| 441 |
+
trimesh.Trimesh: Textured mesh.
|
| 442 |
+
"""
|
| 443 |
+
args = parse_args()
|
| 444 |
+
for k, v in kwargs.items():
|
| 445 |
+
if hasattr(args, k) and v is not None:
|
| 446 |
+
setattr(args, k, v)
|
| 447 |
+
|
| 448 |
+
# Setup camera parameters.
|
| 449 |
+
camera_params = CameraSetting(
|
| 450 |
+
num_images=args.num_images,
|
| 451 |
+
elevation=args.elevation,
|
| 452 |
+
distance=args.distance,
|
| 453 |
+
resolution_hw=args.resolution_hw,
|
| 454 |
+
fov=math.radians(args.fov),
|
| 455 |
+
device=args.device,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
# GS render.
|
| 459 |
+
camera = init_kal_camera(camera_params, flip_az=True)
|
| 460 |
+
matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
|
| 461 |
+
matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
|
| 462 |
+
w2cs = matrix_mv.to(camera_params.device)
|
| 463 |
+
c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
|
| 464 |
+
Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
|
| 465 |
+
gs_model = load_gs_model(args.gs_path, pre_quat=[0.0, 0.0, 1.0, 0.0])
|
| 466 |
+
multiviews = []
|
| 467 |
+
for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
|
| 468 |
+
result = gs_model.render(
|
| 469 |
+
c2ws[idx],
|
| 470 |
+
Ks=Ks,
|
| 471 |
+
image_width=camera_params.resolution_hw[1],
|
| 472 |
+
image_height=camera_params.resolution_hw[0],
|
| 473 |
+
)
|
| 474 |
+
color = cv2.cvtColor(result.rgba, cv2.COLOR_BGRA2RGBA)
|
| 475 |
+
multiviews.append(Image.fromarray(color))
|
| 476 |
+
|
| 477 |
+
if args.delight and delight_model is None:
|
| 478 |
+
delight_model = DelightingModel()
|
| 479 |
+
|
| 480 |
+
if args.delight:
|
| 481 |
+
for idx in range(len(multiviews)):
|
| 482 |
+
multiviews[idx] = delight_model(multiviews[idx])
|
| 483 |
+
|
| 484 |
+
multiviews = [img.convert("RGB") for img in multiviews]
|
| 485 |
+
|
| 486 |
+
mesh = trimesh.load(args.mesh_path)
|
| 487 |
+
if isinstance(mesh, trimesh.Scene):
|
| 488 |
+
mesh = mesh.dump(concatenate=True)
|
| 489 |
+
|
| 490 |
+
vertices, scale, center = normalize_vertices_array(mesh.vertices)
|
| 491 |
+
|
| 492 |
+
# Transform mesh coordinate system by default.
|
| 493 |
+
if not args.no_coor_trans:
|
| 494 |
+
x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
|
| 495 |
+
z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
|
| 496 |
+
vertices = vertices @ x_rot
|
| 497 |
+
vertices = vertices @ z_rot
|
| 498 |
+
|
| 499 |
+
faces = mesh.faces.astype(np.int32)
|
| 500 |
+
vertices = vertices.astype(np.float32)
|
| 501 |
+
|
| 502 |
+
if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
|
| 503 |
+
mesh_fixer = MeshFixer(vertices, faces, args.device)
|
| 504 |
+
vertices, faces = mesh_fixer(
|
| 505 |
+
filter_ratio=args.mesh_sipmlify_ratio,
|
| 506 |
+
max_hole_size=0.04,
|
| 507 |
+
resolution=1024,
|
| 508 |
+
num_views=1000,
|
| 509 |
+
norm_mesh_ratio=0.5,
|
| 510 |
+
)
|
| 511 |
+
if len(faces) > args.n_max_faces:
|
| 512 |
+
mesh_fixer = MeshFixer(vertices, faces, args.device)
|
| 513 |
+
vertices, faces = mesh_fixer(
|
| 514 |
+
filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
|
| 515 |
+
max_hole_size=0.04,
|
| 516 |
+
resolution=1024,
|
| 517 |
+
num_views=1000,
|
| 518 |
+
norm_mesh_ratio=0.5,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
|
| 522 |
+
texture_backer = TextureBaker(
|
| 523 |
+
vertices,
|
| 524 |
+
faces,
|
| 525 |
+
uvs,
|
| 526 |
+
camera_params,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
multiviews = [np.array(img) for img in multiviews]
|
| 530 |
+
texture = texture_backer.bake_texture(
|
| 531 |
+
images=[img[..., :3] for img in multiviews],
|
| 532 |
+
texture_size=args.texture_size,
|
| 533 |
+
mode=args.baker_mode,
|
| 534 |
+
opt_step=args.opt_step,
|
| 535 |
+
)
|
| 536 |
+
if not args.no_smooth_texture:
|
| 537 |
+
texture = post_process_texture(texture)
|
| 538 |
+
|
| 539 |
+
# Recover mesh original orientation, scale and center.
|
| 540 |
+
if not args.no_coor_trans:
|
| 541 |
+
vertices = vertices @ np.linalg.inv(z_rot)
|
| 542 |
+
vertices = vertices @ np.linalg.inv(x_rot)
|
| 543 |
+
vertices = vertices / scale
|
| 544 |
+
vertices = vertices + center
|
| 545 |
+
|
| 546 |
+
textured_mesh = save_mesh_with_mtl(
|
| 547 |
+
vertices, faces, uvs, texture, args.output_path
|
| 548 |
+
)
|
| 549 |
+
if args.save_glb_path is not None:
|
| 550 |
+
os.makedirs(os.path.dirname(args.save_glb_path), exist_ok=True)
|
| 551 |
+
textured_mesh.export(args.save_glb_path)
|
| 552 |
+
|
| 553 |
+
return textured_mesh
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
if __name__ == "__main__":
|
| 557 |
+
entrypoint()
|
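Since `entrypoint` copies any keyword arguments over the parsed CLI defaults, the new `backproject_v3` pipeline can also be driven from Python. A hypothetical usage sketch, with placeholder paths:

```py
# Hypothetical usage sketch; paths are placeholders, not files from the repo.
from embodied_gen.data.backproject_v3 import entrypoint

textured_mesh = entrypoint(
    gs_path="outputs/example/gs_model.ply",
    mesh_path="outputs/example/mesh.obj",
    output_path="outputs/example/textured_mesh.obj",
    baker_mode="fast",  # or "opt" for the TV-regularized optimization
)
```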
embodied_gen/data/convex_decomposer.py
CHANGED
|
@@ -39,6 +39,22 @@ def decompose_convex_coacd(
|
|
| 39 |
auto_scale: bool = True,
|
| 40 |
scale_factor: float = 1.0,
|
| 41 |
) -> None:
|
|
| 42 |
coacd.set_log_level("info" if verbose else "warn")
|
| 43 |
|
| 44 |
mesh = trimesh.load(filename, force="mesh")
|
|
@@ -83,7 +99,38 @@ def decompose_convex_mesh(
|
|
| 83 |
scale_factor: float = 1.005,
|
| 84 |
verbose: bool = False,
|
| 85 |
) -> str:
|
| 86 |
-
"""
|
| 87 |
coacd.set_log_level("info" if verbose else "warn")
|
| 88 |
|
| 89 |
if os.path.exists(outfile):
|
|
@@ -148,9 +195,37 @@ def decompose_convex_mp(
|
|
| 148 |
verbose: bool = False,
|
| 149 |
auto_scale: bool = True,
|
| 150 |
) -> str:
|
| 151 |
-
"""
|
| 152 |
|
| 153 |
See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
|
| 154 |
"""
|
| 155 |
params = dict(
|
| 156 |
threshold=threshold,
|
|
|
|
| 39 |
auto_scale: bool = True,
|
| 40 |
scale_factor: float = 1.0,
|
| 41 |
) -> None:
|
| 42 |
+
"""Decomposes a mesh using CoACD and saves the result.
|
| 43 |
+
|
| 44 |
+
This function loads a mesh from a file, runs the CoACD algorithm with the
|
| 45 |
+
given parameters, optionally scales the resulting convex hulls to match the
|
| 46 |
+
original mesh's bounding box, and exports the combined result to a file.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
filename: Path to the input mesh file.
|
| 50 |
+
outfile: Path to save the decomposed output mesh.
|
| 51 |
+
params: A dictionary of parameters for the CoACD algorithm.
|
| 52 |
+
verbose: If True, sets the CoACD log level to 'info'.
|
| 53 |
+
auto_scale: If True, automatically computes a scale factor to match the
|
| 54 |
+
decomposed mesh's bounding box to the visual mesh's bounding box.
|
| 55 |
+
scale_factor: An additional scaling factor applied to the vertices of
|
| 56 |
+
the decomposed mesh parts.
|
| 57 |
+
"""
|
| 58 |
coacd.set_log_level("info" if verbose else "warn")
|
| 59 |
|
| 60 |
mesh = trimesh.load(filename, force="mesh")
|
|
|
|
| 99 |
scale_factor: float = 1.005,
|
| 100 |
verbose: bool = False,
|
| 101 |
) -> str:
|
| 102 |
+
"""Decomposes a mesh into convex parts with retry logic.
|
| 103 |
+
|
| 104 |
+
This function serves as a wrapper for `decompose_convex_coacd`, providing
|
| 105 |
+
explicit parameters for the CoACD algorithm and implementing a retry
|
| 106 |
+
mechanism. If the initial decomposition fails, it attempts again with
|
| 107 |
+
`preprocess_mode` set to 'on'.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
filename: Path to the input mesh file.
|
| 111 |
+
outfile: Path to save the decomposed output mesh.
|
| 112 |
+
threshold: CoACD parameter. See CoACD documentation for details.
|
| 113 |
+
max_convex_hull: CoACD parameter. See CoACD documentation for details.
|
| 114 |
+
preprocess_mode: CoACD parameter. See CoACD documentation for details.
|
| 115 |
+
preprocess_resolution: CoACD parameter. See CoACD documentation for details.
|
| 116 |
+
resolution: CoACD parameter. See CoACD documentation for details.
|
| 117 |
+
mcts_nodes: CoACD parameter. See CoACD documentation for details.
|
| 118 |
+
mcts_iterations: CoACD parameter. See CoACD documentation for details.
|
| 119 |
+
mcts_max_depth: CoACD parameter. See CoACD documentation for details.
|
| 120 |
+
pca: CoACD parameter. See CoACD documentation for details.
|
| 121 |
+
merge: CoACD parameter. See CoACD documentation for details.
|
| 122 |
+
seed: CoACD parameter. See CoACD documentation for details.
|
| 123 |
+
auto_scale: If True, automatically scale the output to match the input
|
| 124 |
+
bounding box.
|
| 125 |
+
scale_factor: Additional scaling factor to apply.
|
| 126 |
+
verbose: If True, enables detailed logging.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
The path to the output file if decomposition is successful.
|
| 130 |
+
|
| 131 |
+
Raises:
|
| 132 |
+
RuntimeError: If convex decomposition fails after all attempts.
|
| 133 |
+
"""
|
| 134 |
coacd.set_log_level("info" if verbose else "warn")
|
| 135 |
|
| 136 |
if os.path.exists(outfile):
|
|
|
|
| 195 |
verbose: bool = False,
|
| 196 |
auto_scale: bool = True,
|
| 197 |
) -> str:
|
| 198 |
+
"""Decomposes a mesh into convex parts in a separate process.
|
| 199 |
+
|
| 200 |
+
This function uses the `multiprocessing` module to run the CoACD algorithm
|
| 201 |
+
in a spawned subprocess. This is useful for isolating the decomposition
|
| 202 |
+
process to prevent potential memory leaks or crashes in the main process.
|
| 203 |
+
It includes a retry mechanism similar to `decompose_convex_mesh`.
|
| 204 |
|
| 205 |
See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
filename: Path to the input mesh file.
|
| 209 |
+
outfile: Path to save the decomposed output mesh.
|
| 210 |
+
threshold: CoACD parameter.
|
| 211 |
+
max_convex_hull: CoACD parameter.
|
| 212 |
+
preprocess_mode: CoACD parameter.
|
| 213 |
+
preprocess_resolution: CoACD parameter.
|
| 214 |
+
resolution: CoACD parameter.
|
| 215 |
+
mcts_nodes: CoACD parameter.
|
| 216 |
+
mcts_iterations: CoACD parameter.
|
| 217 |
+
mcts_max_depth: CoACD parameter.
|
| 218 |
+
pca: CoACD parameter.
|
| 219 |
+
merge: CoACD parameter.
|
| 220 |
+
seed: CoACD parameter.
|
| 221 |
+
verbose: If True, enables detailed logging in the subprocess.
|
| 222 |
+
auto_scale: If True, automatically scale the output.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
The path to the output file if decomposition is successful.
|
| 226 |
+
|
| 227 |
+
Raises:
|
| 228 |
+
RuntimeError: If convex decomposition fails after all attempts.
|
| 229 |
"""
|
| 230 |
params = dict(
|
| 231 |
threshold=threshold,
|
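Given the docstrings added above, a minimal call to the retry-based wrapper might look like the sketch below. Paths are placeholders, and the remaining CoACD parameters are assumed to keep their defaults.

```py
# Hypothetical usage sketch of the wrapper documented above; CoACD parameters
# not passed here are assumed to have sensible defaults in the signature.
from embodied_gen.data.convex_decomposer import decompose_convex_mesh

collision_path = decompose_convex_mesh(
    filename="outputs/example/mesh.obj",
    outfile="outputs/example/mesh_collision.obj",
    auto_scale=True,
    verbose=False,
)
```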
embodied_gen/data/differentiable_render.py
CHANGED
|
@@ -66,6 +66,14 @@ def create_mp4_from_images(
|
|
| 66 |
fps: int = 10,
|
| 67 |
prompt: str = None,
|
| 68 |
):
|
|
| 69 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 70 |
font_scale = 0.5
|
| 71 |
font_thickness = 1
|
|
@@ -96,6 +104,13 @@ def create_mp4_from_images(
|
|
| 96 |
def create_gif_from_images(
|
| 97 |
images: list[np.ndarray], output_path: str, fps: int = 10
|
| 98 |
) -> None:
|
| 99 |
pil_images = []
|
| 100 |
for image in images:
|
| 101 |
image = image.clip(min=0, max=1)
|
|
@@ -116,32 +131,47 @@ def create_gif_from_images(
|
|
| 116 |
|
| 117 |
|
| 118 |
class ImageRender(object):
|
| 119 |
-
"""
|
| 120 |
|
| 121 |
-
This class wraps
|
| 122 |
-
|
|
|
|
| 123 |
|
| 124 |
Args:
|
| 125 |
-
render_items (list[RenderItems]):
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
"""
|
| 146 |
|
| 147 |
def __init__(
|
|
@@ -198,6 +228,14 @@ class ImageRender(object):
|
|
| 198 |
uuid: Union[str, List[str]] = None,
|
| 199 |
prompts: List[str] = None,
|
| 200 |
) -> None:
|
| 201 |
mesh_path = as_list(mesh_path)
|
| 202 |
if uuid is None:
|
| 203 |
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
|
|
@@ -227,18 +265,15 @@ class ImageRender(object):
|
|
| 227 |
def __call__(
|
| 228 |
self, mesh_path: str, output_dir: str, prompt: str = None
|
| 229 |
) -> dict[str, str]:
|
| 230 |
-
"""
|
| 231 |
-
|
| 232 |
-
Processes the input mesh, renders multiple modalities (e.g., normals,
|
| 233 |
-
depth, albedo), and optionally saves video or image sequences.
|
| 234 |
|
| 235 |
Args:
|
| 236 |
-
mesh_path (str): Path to
|
| 237 |
-
output_dir (str): Directory to save
|
| 238 |
-
prompt (str, optional):
|
| 239 |
|
| 240 |
Returns:
|
| 241 |
-
dict[str, str]:
|
| 242 |
"""
|
| 243 |
try:
|
| 244 |
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
|
|
|
|
| 66 |
fps: int = 10,
|
| 67 |
prompt: str = None,
|
| 68 |
):
|
| 69 |
+
"""Creates an MP4 video from a list of images.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
images (list[np.ndarray]): List of images as numpy arrays.
|
| 73 |
+
output_path (str): Path to save the MP4 file.
|
| 74 |
+
fps (int, optional): Frames per second. Defaults to 10.
|
| 75 |
+
prompt (str, optional): Optional text prompt overlay.
|
| 76 |
+
"""
|
| 77 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 78 |
font_scale = 0.5
|
| 79 |
font_thickness = 1
|
|
|
|
| 104 |
def create_gif_from_images(
|
| 105 |
images: list[np.ndarray], output_path: str, fps: int = 10
|
| 106 |
) -> None:
|
| 107 |
+
"""Creates a GIF animation from a list of images.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
images (list[np.ndarray]): List of images as numpy arrays.
|
| 111 |
+
output_path (str): Path to save the GIF file.
|
| 112 |
+
fps (int, optional): Frames per second. Defaults to 10.
|
| 113 |
+
"""
|
| 114 |
pil_images = []
|
| 115 |
for image in images:
|
| 116 |
image = image.clip(min=0, max=1)
|
|
|
|
| 131 |
|
| 132 |
|
| 133 |
class ImageRender(object):
|
| 134 |
+
"""Differentiable mesh renderer supporting multi-view rendering.
|
| 135 |
|
| 136 |
+
This class wraps differentiable rasterization using `nvdiffrast` to render mesh
|
| 137 |
+
geometry to various maps (normal, depth, alpha, albedo, etc.) and supports
|
| 138 |
+
saving images and videos.
|
| 139 |
|
| 140 |
Args:
|
| 141 |
+
render_items (list[RenderItems]): List of rendering targets.
|
| 142 |
+
camera_params (CameraSetting): Camera parameters for rendering.
|
| 143 |
+
recompute_vtx_normal (bool, optional): Recompute vertex normals. Defaults to True.
|
| 144 |
+
with_mtl (bool, optional): Load mesh material files. Defaults to False.
|
| 145 |
+
gen_color_gif (bool, optional): Generate GIF of color images. Defaults to False.
|
| 146 |
+
gen_color_mp4 (bool, optional): Generate MP4 of color images. Defaults to False.
|
| 147 |
+
gen_viewnormal_mp4 (bool, optional): Generate MP4 of view-space normals. Defaults to False.
|
| 148 |
+
gen_glonormal_mp4 (bool, optional): Generate MP4 of global-space normals. Defaults to False.
|
| 149 |
+
no_index_file (bool, optional): Skip saving index file. Defaults to False.
|
| 150 |
+
light_factor (float, optional): PBR light intensity multiplier. Defaults to 1.0.
|
| 151 |
+
|
| 152 |
+
Example:
|
| 153 |
+
```py
|
| 154 |
+
from embodied_gen.data.differentiable_render import ImageRender
|
| 155 |
+
from embodied_gen.data.utils import CameraSetting
|
| 156 |
+
from embodied_gen.utils.enum import RenderItems
|
| 157 |
+
|
| 158 |
+
camera_params = CameraSetting(
|
| 159 |
+
num_images=6,
|
| 160 |
+
elevation=[20, -10],
|
| 161 |
+
distance=5,
|
| 162 |
+
resolution_hw=(512,512),
|
| 163 |
+
fov=math.radians(30),
|
| 164 |
+
device='cuda',
|
| 165 |
+
)
|
| 166 |
+
render_items = [RenderItems.IMAGE.value, RenderItems.DEPTH.value]
|
| 167 |
+
renderer = ImageRender(
|
| 168 |
+
render_items,
|
| 169 |
+
camera_params,
|
| 170 |
+
with_mtl=args.with_mtl,
|
| 171 |
+
gen_color_mp4=True,
|
| 172 |
+
)
|
| 173 |
+
renderer.render_mesh(mesh_path='mesh.obj', output_root='./renders')
|
| 174 |
+
```
|
| 175 |
"""
|
| 176 |
|
| 177 |
def __init__(
|
|
|
|
| 228 |
uuid: Union[str, List[str]] = None,
|
| 229 |
prompts: List[str] = None,
|
| 230 |
) -> None:
|
| 231 |
+
"""Renders one or more meshes and saves outputs.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
mesh_path (Union[str, List[str]]): Path(s) to mesh files.
|
| 235 |
+
output_root (str): Directory to save outputs.
|
| 236 |
+
uuid (Union[str, List[str]], optional): Unique IDs for outputs.
|
| 237 |
+
prompts (List[str], optional): Text prompts for videos.
|
| 238 |
+
"""
|
| 239 |
mesh_path = as_list(mesh_path)
|
| 240 |
if uuid is None:
|
| 241 |
uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
|
|
|
|
| 265 |
def __call__(
|
| 266 |
self, mesh_path: str, output_dir: str, prompt: str = None
|
| 267 |
) -> dict[str, str]:
|
| 268 |
+
"""Renders a single mesh and returns output paths.
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
Args:
|
| 271 |
+
mesh_path (str): Path to mesh file.
|
| 272 |
+
output_dir (str): Directory to save outputs.
|
| 273 |
+
prompt (str, optional): Caption prompt for MP4 metadata.
|
| 274 |
|
| 275 |
Returns:
|
| 276 |
+
dict[str, str]: Mapping of render types to saved image paths.
|
| 277 |
"""
|
| 278 |
try:
|
| 279 |
mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
|
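The docstrings added above document `create_mp4_from_images` and `create_gif_from_images` as consuming float images (clipped to `[0, 1]` internally) plus an output path and fps. A small usage sketch under those assumptions, with synthetic frames standing in for real renders:

```py
# Hedged sketch: feeding synthetic frames to the video/GIF helpers documented above.
import numpy as np

from embodied_gen.data.differentiable_render import (
    create_gif_from_images,
    create_mp4_from_images,
)

# 24 fake frames, float RGB in [0, 1], as the docstrings above expect.
frames = [np.random.rand(512, 512, 3).astype(np.float32) for _ in range(24)]

create_mp4_from_images(frames, "renders/turntable.mp4", fps=12, prompt="demo asset")
create_gif_from_images(frames, "renders/turntable.gif", fps=12)
```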
embodied_gen/data/mesh_operator.py
CHANGED
|
@@ -16,17 +16,13 @@
 
 
 import logging
-import multiprocessing as mp
-import os
 from typing import Tuple, Union
 
-import coacd
 import igraph
 import numpy as np
 import pyvista as pv
 import spaces
 import torch
-import trimesh
 import utils3d
 from pymeshfix import _meshfix
 from tqdm import tqdm
embodied_gen/data/utils.py
CHANGED
|
@@ -66,6 +66,7 @@ __all__ = [
     "resize_pil",
     "trellis_preprocess",
     "delete_dir",
+    "kaolin_to_opencv_view",
 ]
 
 
@@ -373,10 +374,18 @@ def _compute_az_el_by_views(
 def _compute_cam_pts_by_az_el(
     azs: np.ndarray,
     els: np.ndarray,
-    distance: float,
+    distance: float | list[float] | np.ndarray,
     extra_pts: np.ndarray = None,
 ) -> np.ndarray:
-
+    if np.isscalar(distance) or isinstance(distance, (float, int)):
+        distances = np.full(len(azs), distance)
+    else:
+        distances = np.array(distance)
+        if len(distances) != len(azs):
+            raise ValueError(
+                f"Length of distances ({len(distances)}) must match length of azs ({len(azs)})"
+            )
+
     cam_pts = _az_el_to_points(azs, els) * distances[:, None]
 
     if extra_pts is not None:
@@ -710,7 +719,7 @@ class CameraSetting:
 
     num_images: int
     elevation: list[float]
-    distance: float
+    distance: float | list[float]
     resolution_hw: tuple[int, int]
     fov: float
     at: tuple[float, float, float] = field(
@@ -824,6 +833,28 @@ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
     return mesh
 
 
+def kaolin_to_opencv_view(raw_matrix):
+    R_orig = raw_matrix[:, :3, :3]
+    t_orig = raw_matrix[:, :3, 3]
+
+    R_target = torch.zeros_like(R_orig)
+    R_target[:, :, 0] = R_orig[:, :, 2]
+    R_target[:, :, 1] = R_orig[:, :, 0]
+    R_target[:, :, 2] = R_orig[:, :, 1]
+
+    t_target = t_orig
+
+    target_matrix = (
+        torch.eye(4, device=raw_matrix.device)
+        .unsqueeze(0)
+        .repeat(raw_matrix.size(0), 1, 1)
+    )
+    target_matrix[:, :3, :3] = R_target
+    target_matrix[:, :3, 3] = t_target
+
+    return target_matrix
+
+
 def save_mesh_with_mtl(
     vertices: np.ndarray,
     faces: np.ndarray,
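Two of the changes above loosen `CameraSetting.distance` to accept one value per view and add `kaolin_to_opencv_view` for converting a batch of Kaolin camera matrices to OpenCV-style view matrices. A short sketch of both; the camera values are placeholders and the per-view distance count is assumed to match the number of rendered views.

```py
# Sketch of the new per-view distance option and the matrix conversion helper.
import math

import torch

from embodied_gen.data.utils import CameraSetting, kaolin_to_opencv_view

camera_params = CameraSetting(
    num_images=6,
    elevation=[20, -10],
    distance=[5.0, 5.0, 5.0, 4.5, 4.5, 4.5],  # assumed: one distance per view
    resolution_hw=(512, 512),
    fov=math.radians(30),
    device="cuda",
)

# Convert a batch of Kaolin camera matrices (N, 4, 4) into the OpenCV-style
# axis ordering; the identity batch here only demonstrates the expected shape.
raw_matrix = torch.eye(4).unsqueeze(0).repeat(4, 1, 1)
opencv_views = kaolin_to_opencv_view(raw_matrix)
print(opencv_views.shape)  # torch.Size([4, 4, 4])
```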
embodied_gen/envs/pick_embodiedgen.py
CHANGED
|
@@ -51,6 +51,33 @@ __all__ = ["PickEmbodiedGen"]
@@ -63,6 +90,19 @@ class PickEmbodiedGen(BaseEnv):
@@ -116,6 +156,22 @@ class PickEmbodiedGen(BaseEnv):
@@ -136,6 +192,18 @@ class PickEmbodiedGen(BaseEnv):
@@ -148,6 +216,11 @@ class PickEmbodiedGen(BaseEnv):
@@ -163,6 +236,11 @@ class PickEmbodiedGen(BaseEnv):
@@ -171,6 +249,11 @@ class PickEmbodiedGen(BaseEnv):
@@ -187,10 +270,24 @@ class PickEmbodiedGen(BaseEnv):
@@ -222,6 +319,15 @@ class PickEmbodiedGen(BaseEnv):
@@ -256,6 +362,21 @@ class PickEmbodiedGen(BaseEnv):
@@ -293,6 +414,15 @@ class PickEmbodiedGen(BaseEnv):
@@ -315,6 +445,17 @@ class PickEmbodiedGen(BaseEnv):
@@ -335,6 +476,11 @@ class PickEmbodiedGen(BaseEnv):
@@ -343,6 +489,14 @@ class PickEmbodiedGen(BaseEnv):
@@ -362,6 +516,16 @@ class PickEmbodiedGen(BaseEnv):
@@ -381,10 +545,31 @@ class PickEmbodiedGen(BaseEnv):
@@ -417,4 +602,14 @@ class PickEmbodiedGen(BaseEnv):
| 51 |
|
| 52 |
@register_env("PickEmbodiedGen-v1", max_episode_steps=100)
|
| 53 |
class PickEmbodiedGen(BaseEnv):
|
| 54 |
+
"""PickEmbodiedGen as gym env example for object pick-and-place tasks.
|
| 55 |
+
|
| 56 |
+
This environment simulates a robot interacting with 3D assets in the
|
| 57 |
+
embodiedgen generated scene in SAPIEN. It supports multi-environment setups,
|
| 58 |
+
dynamic reconfiguration, and hybrid rendering with 3D Gaussian Splatting.
|
| 59 |
+
|
| 60 |
+
Example:
|
| 61 |
+
Use `gym.make` to create the `PickEmbodiedGen-v1` parallel environment.
|
| 62 |
+
```python
|
| 63 |
+
import gymnasium as gym
|
| 64 |
+
env = gym.make(
|
| 65 |
+
"PickEmbodiedGen-v1",
|
| 66 |
+
num_envs=cfg.num_envs,
|
| 67 |
+
render_mode=cfg.render_mode,
|
| 68 |
+
enable_shadow=cfg.enable_shadow,
|
| 69 |
+
layout_file=cfg.layout_file,
|
| 70 |
+
control_mode=cfg.control_mode,
|
| 71 |
+
camera_cfg=dict(
|
| 72 |
+
camera_eye=cfg.camera_eye,
|
| 73 |
+
camera_target_pt=cfg.camera_target_pt,
|
| 74 |
+
image_hw=cfg.image_hw,
|
| 75 |
+
fovy_deg=cfg.fovy_deg,
|
| 76 |
+
),
|
| 77 |
+
)
|
| 78 |
+
```
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
|
| 82 |
goal_thresh = 0.0
|
| 83 |
|
|
|
|
| 90 |
reconfiguration_freq: int = None,
|
| 91 |
**kwargs,
|
| 92 |
):
|
| 93 |
+
"""Initializes the PickEmbodiedGen environment.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
*args: Variable length argument list for the base class.
|
| 97 |
+
robot_uids: The robot(s) to use in the environment.
|
| 98 |
+
robot_init_qpos_noise: Noise added to the robot's initial joint
|
| 99 |
+
positions.
|
| 100 |
+
num_envs: The number of parallel environments to create.
|
| 101 |
+
reconfiguration_freq: How often to reconfigure the scene. If None,
|
| 102 |
+
it is set based on num_envs.
|
| 103 |
+
**kwargs: Additional keyword arguments for environment setup,
|
| 104 |
+
including layout_file, replace_objs, enable_grasp, etc.
|
| 105 |
+
"""
|
| 106 |
self.robot_init_qpos_noise = robot_init_qpos_noise
|
| 107 |
if reconfiguration_freq is None:
|
| 108 |
if num_envs == 1:
|
|
|
|
| 156 |
def init_env_layouts(
|
| 157 |
layout_file: str, num_envs: int, replace_objs: bool
|
| 158 |
) -> list[LayoutInfo]:
|
| 159 |
+
"""Initializes and saves layout files for each environment instance.
|
| 160 |
+
|
| 161 |
+
For each environment, this method creates a layout configuration. If
|
| 162 |
+
`replace_objs` is True, it generates new object placements for each
|
| 163 |
+
subsequent environment. The generated layouts are saved as new JSON
|
| 164 |
+
files.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
layout_file: Path to the base layout JSON file.
|
| 168 |
+
num_envs: The number of environments to create layouts for.
|
| 169 |
+
replace_objs: If True, generates new object placements for each
|
| 170 |
+
environment after the first one using BFS placement.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
A list of file paths to the generated layout for each environment.
|
| 174 |
+
"""
|
| 175 |
layouts = []
|
| 176 |
for env_idx in range(num_envs):
|
| 177 |
if replace_objs and env_idx > 0:
|
|
|
|
| 192 |
def compute_robot_init_pose(
|
| 193 |
layouts: list[str], num_envs: int, z_offset: float = 0.0
|
| 194 |
) -> list[list[float]]:
|
| 195 |
+
"""Computes the initial pose for the robot in each environment.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
layouts: A list of file paths to the environment layouts.
|
| 199 |
+
num_envs: The number of environments.
|
| 200 |
+
z_offset: An optional vertical offset to apply to the robot's
|
| 201 |
+
position to prevent collisions.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
A list of initial poses ([x, y, z, qw, qx, qy, qz]) for the robot
|
| 205 |
+
in each environment.
|
| 206 |
+
"""
|
| 207 |
robot_pose = []
|
| 208 |
for env_idx in range(num_envs):
|
| 209 |
layout = json.load(open(layouts[env_idx], "r"))
|
|
|
|
| 216 |
|
| 217 |
@property
|
| 218 |
def _default_sim_config(self):
|
| 219 |
+
"""Returns the default simulation configuration.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
The default simulation configuration object.
|
| 223 |
+
"""
|
| 224 |
return SimConfig(
|
| 225 |
scene_config=SceneConfig(
|
| 226 |
solver_position_iterations=30,
|
|
|
|
| 236 |
|
| 237 |
@property
|
| 238 |
def _default_sensor_configs(self):
|
| 239 |
+
"""Returns the default sensor configurations for the agent.
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
A list containing the default camera configuration.
|
| 243 |
+
"""
|
| 244 |
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
|
| 245 |
|
| 246 |
return [
|
|
|
|
| 249 |
|
| 250 |
@property
|
| 251 |
def _default_human_render_camera_configs(self):
|
| 252 |
+
"""Returns the default camera configuration for human-friendly rendering.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
The default camera configuration for the renderer.
|
| 256 |
+
"""
|
| 257 |
pose = sapien_utils.look_at(
|
| 258 |
eye=self.camera_cfg["camera_eye"],
|
| 259 |
target=self.camera_cfg["camera_target_pt"],
|
|
|
|
| 270 |
)
|
| 271 |
|
| 272 |
def _load_agent(self, options: dict):
|
| 273 |
+
"""Loads the agent (robot) and a ground plane into the scene.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
options: A dictionary of options for loading the agent.
|
| 277 |
+
"""
|
| 278 |
self.ground = build_ground(self.scene)
|
| 279 |
super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
|
| 280 |
|
| 281 |
def _load_scene(self, options: dict):
|
| 282 |
+
"""Loads all assets, objects, and the goal site into the scene.
|
| 283 |
+
|
| 284 |
+
This method iterates through the layouts for each environment, loads the
|
| 285 |
+
specified assets, and adds them to the simulation. It also creates a
|
| 286 |
+
kinematic sphere to represent the goal site.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
options: A dictionary of options for loading the scene.
|
| 290 |
+
"""
|
| 291 |
all_objects = []
|
| 292 |
logger.info(f"Loading EmbodiedGen assets...")
|
| 293 |
for env_idx in range(self.num_envs):
|
|
|
|
| 319 |
self._hidden_objects.append(self.goal_site)
|
| 320 |
|
| 321 |
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
|
| 322 |
+
"""Initializes an episode for a given set of environments.
|
| 323 |
+
|
| 324 |
+
This method sets the goal position, resets the robot's joint positions
|
| 325 |
+
with optional noise, and sets its root pose.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
env_idx: A tensor of environment indices to initialize.
|
| 329 |
+
options: A dictionary of options for initialization.
|
| 330 |
+
"""
|
| 331 |
with torch.device(self.device):
|
| 332 |
b = len(env_idx)
|
| 333 |
goal_xyz = torch.zeros((b, 3))
|
|
|
|
| 362 |
def render_gs3d_images(
|
| 363 |
self, layouts: list[str], num_envs: int, init_quat: list[float]
|
| 364 |
) -> dict[str, np.ndarray]:
|
| 365 |
+
"""Renders background images using a pre-trained Gaussian Splatting model.
|
| 366 |
+
|
| 367 |
+
This method pre-renders the static background for each environment from
|
| 368 |
+
the perspective of all cameras to be used for hybrid rendering.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
layouts: A list of file paths to the environment layouts.
|
| 372 |
+
num_envs: The number of environments.
|
| 373 |
+
init_quat: An initial quaternion to orient the Gaussian Splatting
|
| 374 |
+
model.
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
A dictionary mapping a unique key (e.g., 'camera-env_idx') to the
|
| 378 |
+
rendered background image as a numpy array.
|
| 379 |
+
"""
|
| 380 |
sim_coord_align = (
|
| 381 |
torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
|
| 382 |
)
|
|
|
|
| 414 |
return bg_images
|
| 415 |
|
| 416 |
def render(self):
|
| 417 |
+
"""Renders the environment based on the configured render_mode.
|
| 418 |
+
|
| 419 |
+
Raises:
|
| 420 |
+
RuntimeError: If `render_mode` is not set.
|
| 421 |
+
NotImplementedError: If the `render_mode` is not supported.
|
| 422 |
+
|
| 423 |
+
Returns:
|
| 424 |
+
The rendered output, which varies depending on the render mode.
|
| 425 |
+
"""
|
| 426 |
if self.render_mode is None:
|
| 427 |
raise RuntimeError("render_mode is not set.")
|
| 428 |
if self.render_mode == "human":
|
|
|
|
| 445 |
def render_rgb_array(
|
| 446 |
self, camera_name: str = None, return_alpha: bool = False
|
| 447 |
):
|
| 448 |
+
"""Renders an RGB image from the human-facing render camera.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
camera_name: The name of the camera to render from. If None, uses
|
| 452 |
+
all human render cameras.
|
| 453 |
+
return_alpha: Whether to include the alpha channel in the output.
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
A numpy array representing the rendered image(s). If multiple
|
| 457 |
+
cameras are used, the images are tiled.
|
| 458 |
+
"""
|
| 459 |
for obj in self._hidden_objects:
|
| 460 |
obj.show_visual()
|
| 461 |
self.scene.update_render(
|
|
|
|
| 476 |
return tile_images(images)
|
| 477 |
|
| 478 |
def render_sensors(self):
|
| 479 |
+
"""Renders images from all on-board sensor cameras.
|
| 480 |
+
|
| 481 |
+
Returns:
|
| 482 |
+
A tiled image of all sensor outputs as a numpy array.
|
| 483 |
+
"""
|
| 484 |
images = []
|
| 485 |
sensor_images = self.get_sensor_images()
|
| 486 |
for image in sensor_images.values():
|
|
|
|
| 489 |
return tile_images(images)
|
| 490 |
|
| 491 |
def hybrid_render(self):
|
| 492 |
+
"""Renders a hybrid image by blending simulated foreground with a background.
|
| 493 |
+
|
| 494 |
+
The foreground is rendered with an alpha channel and then blended with
|
| 495 |
+
the pre-rendered Gaussian Splatting background image.
|
| 496 |
+
|
| 497 |
+
Returns:
|
| 498 |
+
A torch tensor of the final blended RGB images.
|
| 499 |
+
"""
|
| 500 |
fg_images = self.render_rgb_array(
|
| 501 |
return_alpha=True
|
| 502 |
) # (n_env, h, w, 3)
|
|
|
|
| 516 |
return images[..., :3]
|
| 517 |
|
| 518 |
def evaluate(self):
|
| 519 |
+
"""Evaluates the current state of the environment.
|
| 520 |
+
|
| 521 |
+
Checks for task success criteria such as whether the object is grasped,
|
| 522 |
+
placed at the goal, and if the robot is static.
|
| 523 |
+
|
| 524 |
+
Returns:
|
| 525 |
+
A dictionary containing boolean tensors for various success
|
| 526 |
+
metrics, including 'is_grasped', 'is_obj_placed', and overall
|
| 527 |
+
'success'.
|
| 528 |
+
"""
|
| 529 |
obj_to_goal_pos = (
|
| 530 |
self.obj.pose.p
|
| 531 |
) # self.goal_site.pose.p - self.obj.pose.p
|
|
|
|
| 545 |
)
|
| 546 |
|
| 547 |
def _get_obs_extra(self, info: dict):
|
| 548 |
+
"""Gets extra information for the observation dictionary.
|
| 549 |
+
|
| 550 |
+
Args:
|
| 551 |
+
info: A dictionary containing evaluation information.
|
| 552 |
+
|
| 553 |
+
Returns:
|
| 554 |
+
An empty dictionary, as no extra observations are added.
|
| 555 |
+
"""
|
| 556 |
|
| 557 |
return dict()
|
| 558 |
|
| 559 |
def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
|
| 560 |
+
"""Computes a dense reward for the current step.
|
| 561 |
+
|
| 562 |
+
The reward is a composite of reaching, grasping, placing, and
|
| 563 |
+
maintaining a static final pose.
|
| 564 |
+
|
| 565 |
+
Args:
|
| 566 |
+
obs: The current observation.
|
| 567 |
+
action: The action taken in the current step.
|
| 568 |
+
info: A dictionary containing evaluation information from `evaluate()`.
|
| 569 |
+
|
| 570 |
+
Returns:
|
| 571 |
+
A tensor containing the dense reward for each environment.
|
| 572 |
+
"""
|
| 573 |
tcp_to_obj_dist = torch.linalg.norm(
|
| 574 |
self.obj.pose.p - self.agent.tcp.pose.p, axis=1
|
| 575 |
)
|
|
|
|
| 602 |
def compute_normalized_dense_reward(
|
| 603 |
self, obs: any, action: torch.Tensor, info: dict
|
| 604 |
):
|
| 605 |
+
"""Computes a dense reward normalized to be between 0 and 1.
|
| 606 |
+
|
| 607 |
+
Args:
|
| 608 |
+
obs: The current observation.
|
| 609 |
+
action: The action taken in the current step.
|
| 610 |
+
info: A dictionary containing evaluation information from `evaluate()`.
|
| 611 |
+
|
| 612 |
+
Returns:
|
| 613 |
+
A tensor containing the normalized dense reward for each environment.
|
| 614 |
+
"""
|
| 615 |
return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
|
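The class docstring above already shows construction via `gym.make`; the sketch below continues from that point with a short random-action rollout and a call to the `hybrid_render` method documented in this diff. The layout path, camera values, and episode length are placeholders.

```py
# Hedged sketch: rollout against PickEmbodiedGen-v1; config values are placeholders.
import gymnasium as gym

import embodied_gen.envs.pick_embodiedgen  # noqa: F401, importing registers the env

env = gym.make(
    "PickEmbodiedGen-v1",
    num_envs=2,
    render_mode="rgb_array",
    layout_file="outputs/layouts_gens_demo/task_0000/layout.json",
    camera_cfg=dict(
        camera_eye=[1.0, 0.0, 0.8],
        camera_target_pt=[0.0, 0.0, 0.2],
        image_hw=(480, 640),
        fovy_deg=60,
    ),
)
obs, info = env.reset(seed=0)
for _ in range(20):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)

# Blend the simulated foreground with the pre-rendered 3DGS background.
frames = env.unwrapped.hybrid_render()  # (n_env, H, W, 3) tensor
env.close()
```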
embodied_gen/models/delight_model.py
CHANGED
|
@@ -40,7 +40,7 @@ class DelightingModel(object):
     """A model to remove the lighting in image space.
 
     This model is encapsulated based on the Hunyuan3D-Delight model
-    from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa
+    from `https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0` # noqa
 
     Attributes:
         image_guide_scale (float): Weight of image guidance in diffusion process.
embodied_gen/models/gs_model.py
CHANGED
|
@@ -21,14 +21,18 @@ import struct
 from dataclasses import dataclass
 from typing import Optional
 
-import cv2
 import numpy as np
 import torch
 from gsplat.cuda._wrapper import spherical_harmonics
 from gsplat.rendering import rasterization
 from plyfile import PlyData
 from scipy.spatial.transform import Rotation
-from embodied_gen.data.utils import
+from embodied_gen.data.utils import (
+    gamma_shs,
+    normalize_vertices_array,
+    quat_mult,
+    quat_to_rotmat,
+)
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -494,6 +498,21 @@ class GaussianOperator(GaussianBase):
     )
 
 
+def load_gs_model(
+    input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
+) -> GaussianOperator:
+    gs_model = GaussianOperator.load_from_ply(input_gs)
+    # Normalize vertices to [-1, 1], center to (0, 0, 0).
+    _, scale, center = normalize_vertices_array(gs_model._means)
+    scale, center = float(scale), center.tolist()
+    transpose = [*[v for v in center], *pre_quat]
+    instance_pose = torch.tensor(transpose).to(gs_model.device)
+    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
+    gs_model.rescale(scale)
+
+    return gs_model
+
+
 if __name__ == "__main__":
     input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
     output_gs = "./gs_model.ply"
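The new `load_gs_model` helper above recenters and rescales a Gaussian Splatting `.ply` after applying a pre-rotation quaternion. A usage sketch reusing the paths from the `__main__` block; the quaternion is the function's own default.

```py
# Sketch of the load_gs_model helper added in this diff.
from embodied_gen.models.gs_model import load_gs_model

input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
gs_model = load_gs_model(input_gs, pre_quat=[0.0, 0.7071, 0.0, -0.7071])

# The returned GaussianOperator is centered at the origin and rescaled to
# roughly [-1, 1]; inspect the normalized means as a quick sanity check.
print(gs_model._means.shape, gs_model._means.mean(dim=0))
```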
embodied_gen/models/image_comm_model.py
CHANGED
|
@@ -38,26 +38,61 @@ __all__ = [
@@ -70,12 +105,25 @@ class SD35Loader(BasePipelineLoader):
@@ -87,6 +135,8 @@ class CosmosLoader(BasePipelineLoader):
@@ -110,6 +160,11 @@ class CosmosLoader(BasePipelineLoader):
@@ -141,7 +196,19 @@ class CosmosLoader(BasePipelineLoader):
@@ -149,7 +216,14 @@ class CosmosRunner(BasePipelineRunner):
@@ -164,13 +238,31 @@ class KolorsLoader(BasePipelineLoader):
@@ -182,20 +274,50 @@ class FluxLoader(BasePipelineLoader):
@@ -211,6 +333,22 @@ PIPELINE_REGISTRY = {
| 38 |
|
| 39 |
|
| 40 |
class BasePipelineLoader(ABC):
|
| 41 |
+
"""Abstract base class for loading Hugging Face image generation pipelines.
|
| 42 |
+
|
| 43 |
+
Attributes:
|
| 44 |
+
device (str): Device to load the pipeline on.
|
| 45 |
+
|
| 46 |
+
Methods:
|
| 47 |
+
load(): Loads and returns the pipeline.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
def __init__(self, device="cuda"):
|
| 51 |
self.device = device
|
| 52 |
|
| 53 |
@abstractmethod
|
| 54 |
def load(self):
|
| 55 |
+
"""Load and return the pipeline instance."""
|
| 56 |
pass
|
| 57 |
|
| 58 |
|
| 59 |
class BasePipelineRunner(ABC):
|
| 60 |
+
"""Abstract base class for running image generation pipelines.
|
| 61 |
+
|
| 62 |
+
Attributes:
|
| 63 |
+
pipe: The loaded pipeline.
|
| 64 |
+
|
| 65 |
+
Methods:
|
| 66 |
+
run(prompt, **kwargs): Runs the pipeline with a prompt.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
def __init__(self, pipe):
|
| 70 |
self.pipe = pipe
|
| 71 |
|
| 72 |
@abstractmethod
|
| 73 |
def run(self, prompt: str, **kwargs) -> Image.Image:
|
| 74 |
+
"""Run the pipeline with the given prompt.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
prompt (str): Text prompt for image generation.
|
| 78 |
+
**kwargs: Additional pipeline arguments.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Image.Image: Generated image(s).
|
| 82 |
+
"""
|
| 83 |
pass
|
| 84 |
|
| 85 |
|
| 86 |
# ===== SD3.5-medium =====
|
| 87 |
class SD35Loader(BasePipelineLoader):
|
| 88 |
+
"""Loader for Stable Diffusion 3.5 medium pipeline."""
|
| 89 |
+
|
| 90 |
def load(self):
|
| 91 |
+
"""Load the Stable Diffusion 3.5 medium pipeline.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
StableDiffusion3Pipeline: Loaded pipeline.
|
| 95 |
+
"""
|
| 96 |
pipe = StableDiffusion3Pipeline.from_pretrained(
|
| 97 |
"stabilityai/stable-diffusion-3.5-medium",
|
| 98 |
torch_dtype=torch.float16,
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
class SD35Runner(BasePipelineRunner):
|
| 108 |
+
"""Runner for Stable Diffusion 3.5 medium pipeline."""
|
| 109 |
+
|
| 110 |
def run(self, prompt: str, **kwargs) -> Image.Image:
|
| 111 |
+
"""Generate images using Stable Diffusion 3.5 medium.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
prompt (str): Text prompt.
|
| 115 |
+
**kwargs: Additional arguments.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Image.Image: Generated image(s).
|
| 119 |
+
"""
|
| 120 |
return self.pipe(prompt=prompt, **kwargs).images
|
| 121 |
|
| 122 |
|
| 123 |
# ===== Cosmos2 =====
|
| 124 |
class CosmosLoader(BasePipelineLoader):
|
| 125 |
+
"""Loader for Cosmos2 text-to-image pipeline."""
|
| 126 |
+
|
| 127 |
def __init__(
|
| 128 |
self,
|
| 129 |
model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
|
|
|
|
| 135 |
self.local_dir = local_dir
|
| 136 |
|
| 137 |
def _patch(self):
|
| 138 |
+
"""Patch model and processor for optimized loading."""
|
| 139 |
+
|
| 140 |
def patch_model(cls):
|
| 141 |
orig = cls.from_pretrained
|
| 142 |
|
|
|
|
| 160 |
patch_processor(SiglipProcessor)
|
| 161 |
|
| 162 |
def load(self):
|
| 163 |
+
"""Load the Cosmos2 text-to-image pipeline.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Cosmos2TextToImagePipeline: Loaded pipeline.
|
| 167 |
+
"""
|
| 168 |
self._patch()
|
| 169 |
snapshot_download(
|
| 170 |
repo_id=self.model_id,
|
|
|
|
| 196 |
|
| 197 |
|
| 198 |
class CosmosRunner(BasePipelineRunner):
|
| 199 |
+
"""Runner for Cosmos2 text-to-image pipeline."""
|
| 200 |
+
|
| 201 |
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
| 202 |
+
"""Generate images using Cosmos2 pipeline.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
prompt (str): Text prompt.
|
| 206 |
+
negative_prompt (str, optional): Negative prompt.
|
| 207 |
+
**kwargs: Additional arguments.
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
Image.Image: Generated image(s).
|
| 211 |
+
"""
|
| 212 |
return self.pipe(
|
| 213 |
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
| 214 |
).images
|
|
|
|
| 216 |
|
| 217 |
# ===== Kolors =====
|
| 218 |
class KolorsLoader(BasePipelineLoader):
|
| 219 |
+
"""Loader for Kolors pipeline."""
|
| 220 |
+
|
| 221 |
def load(self):
|
| 222 |
+
"""Load the Kolors pipeline.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
KolorsPipeline: Loaded pipeline.
|
| 226 |
+
"""
|
| 227 |
pipe = KolorsPipeline.from_pretrained(
|
| 228 |
"Kwai-Kolors/Kolors-diffusers",
|
| 229 |
torch_dtype=torch.float16,
|
|
|
|
| 238 |
|
| 239 |
|
| 240 |
class KolorsRunner(BasePipelineRunner):
|
| 241 |
+
"""Runner for Kolors pipeline."""
|
| 242 |
+
|
| 243 |
def run(self, prompt: str, **kwargs) -> Image.Image:
|
| 244 |
+
"""Generate images using Kolors pipeline.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
prompt (str): Text prompt.
|
| 248 |
+
**kwargs: Additional arguments.
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Image.Image: Generated image(s).
|
| 252 |
+
"""
|
| 253 |
return self.pipe(prompt=prompt, **kwargs).images
|
| 254 |
|
| 255 |
|
| 256 |
# ===== Flux =====
|
| 257 |
class FluxLoader(BasePipelineLoader):
|
| 258 |
+
"""Loader for Flux pipeline."""
|
| 259 |
+
|
| 260 |
def load(self):
|
| 261 |
+
"""Load the Flux pipeline.
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
FluxPipeline: Loaded pipeline.
|
| 265 |
+
"""
|
| 266 |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
| 267 |
pipe = FluxPipeline.from_pretrained(
|
| 268 |
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
|
|
|
| 274 |
|
| 275 |
|
| 276 |
class FluxRunner(BasePipelineRunner):
|
| 277 |
+
"""Runner for Flux pipeline."""
|
| 278 |
+
|
| 279 |
def run(self, prompt: str, **kwargs) -> Image.Image:
|
| 280 |
+
"""Generate images using Flux pipeline.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
prompt (str): Text prompt.
|
| 284 |
+
**kwargs: Additional arguments.
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
Image.Image: Generated image(s).
|
| 288 |
+
"""
|
| 289 |
return self.pipe(prompt=prompt, **kwargs).images
|
| 290 |
|
| 291 |
|
| 292 |
# ===== Chroma =====
|
| 293 |
class ChromaLoader(BasePipelineLoader):
|
| 294 |
+
"""Loader for Chroma pipeline."""
|
| 295 |
+
|
| 296 |
def load(self):
|
| 297 |
+
"""Load the Chroma pipeline.
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
ChromaPipeline: Loaded pipeline.
|
| 301 |
+
"""
|
| 302 |
return ChromaPipeline.from_pretrained(
|
| 303 |
"lodestones/Chroma", torch_dtype=torch.bfloat16
|
| 304 |
).to(self.device)
|
| 305 |
|
| 306 |
|
| 307 |
class ChromaRunner(BasePipelineRunner):
|
| 308 |
+
"""Runner for Chroma pipeline."""
|
| 309 |
+
|
| 310 |
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
|
| 311 |
+
"""Generate images using Chroma pipeline.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
prompt (str): Text prompt.
|
| 315 |
+
negative_prompt (str, optional): Negative prompt.
|
| 316 |
+
**kwargs: Additional arguments.
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
Image.Image: Generated image(s).
|
| 320 |
+
"""
|
| 321 |
return self.pipe(
|
| 322 |
prompt=prompt, negative_prompt=negative_prompt, **kwargs
|
| 323 |
).images
|
|
|
|
| 333 |
|
| 334 |
|
| 335 |
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
|
| 336 |
+
"""Build a Hugging Face image generation pipeline runner by name.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
name (str): Name of the pipeline (e.g., "sd35", "cosmos").
|
| 340 |
+
device (str): Device to load the pipeline on.
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
BasePipelineRunner: Pipeline runner instance.
|
| 344 |
+
|
| 345 |
+
Example:
|
| 346 |
+
```py
|
| 347 |
+
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
| 348 |
+
runner = build_hf_image_pipeline("sd35")
|
| 349 |
+
images = runner.run(prompt="A robot holding a sign that says 'Hello'")
|
| 350 |
+
```
|
| 351 |
+
"""
|
| 352 |
if name not in PIPELINE_REGISTRY:
|
| 353 |
raise ValueError(f"Unsupported model: {name}")
|
| 354 |
loader_cls, runner_cls = PIPELINE_REGISTRY[name]
|
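`build_hf_image_pipeline` already carries a short example in its docstring; the sketch below extends it by saving the returned PIL images. The extra sampler kwargs are standard diffusers arguments forwarded through `run`, not parameters defined by this module.

```py
# Sketch: generate and save images through the pipeline registry.
from embodied_gen.models.image_comm_model import build_hf_image_pipeline

runner = build_hf_image_pipeline("sd35", device="cuda")
images = runner.run(
    prompt="A wooden table with a red mug on top, studio lighting",
    num_inference_steps=28,  # forwarded to the underlying diffusers pipeline
    guidance_scale=4.5,
)
for idx, image in enumerate(images):
    image.save(f"outputs/sd35_{idx:02d}.png")
```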
embodied_gen/models/layout.py
CHANGED
|
@@ -376,6 +376,21 @@ LAYOUT_DESCRIBER_PROMPT = """
@@ -387,6 +402,15 @@ class LayoutDesigner(object):
@@ -400,6 +424,17 @@ class LayoutDesigner(object):
@@ -411,9 +446,23 @@ class LayoutDesigner(object):
@@ -421,6 +470,16 @@ class LayoutDesigner(object):
@@ -442,6 +501,29 @@ LAYOUT_DESCRIBER = LayoutDesigner(
| 376 |
|
| 377 |
|
| 378 |
class LayoutDesigner(object):
|
| 379 |
+
"""A class for querying GPT-based scene layout reasoning and formatting responses.
|
| 380 |
+
|
| 381 |
+
Attributes:
|
| 382 |
+
prompt (str): The system prompt for GPT.
|
| 383 |
+
verbose (bool): Whether to log responses.
|
| 384 |
+
gpt_client (GPTclient): The GPT client instance.
|
| 385 |
+
|
| 386 |
+
Methods:
|
| 387 |
+
query(prompt, params): Query GPT with a prompt and parameters.
|
| 388 |
+
format_response(response): Parse and clean JSON response.
|
| 389 |
+
format_response_repair(response): Repair and parse JSON response.
|
| 390 |
+
save_output(output, save_path): Save output to file.
|
| 391 |
+
__call__(prompt, save_path, params): Query and process output.
|
| 392 |
+
"""
|
| 393 |
+
|
| 394 |
def __init__(
|
| 395 |
self,
|
| 396 |
gpt_client: GPTclient,
|
|
|
|
| 402 |
self.gpt_client = gpt_client
|
| 403 |
|
| 404 |
def query(self, prompt: str, params: dict = None) -> str:
|
| 405 |
+
"""Query GPT with the system prompt and user prompt.
|
| 406 |
+
|
| 407 |
+
Args:
|
| 408 |
+
prompt (str): User prompt.
|
| 409 |
+
params (dict, optional): GPT parameters.
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
str: GPT response.
|
| 413 |
+
"""
|
| 414 |
full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
|
| 415 |
|
| 416 |
response = self.gpt_client.query(
|
|
|
|
| 424 |
return response
|
| 425 |
|
| 426 |
def format_response(self, response: str) -> dict:
|
| 427 |
+
"""Format and parse GPT response as JSON.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
response (str): Raw GPT response.
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
dict: Parsed JSON output.
|
| 434 |
+
|
| 435 |
+
Raises:
|
| 436 |
+
json.JSONDecodeError: If parsing fails.
|
| 437 |
+
"""
|
| 438 |
cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
|
| 439 |
try:
|
| 440 |
output = json.loads(cleaned)
|
|
|
|
| 446 |
return output
|
| 447 |
|
| 448 |
def format_response_repair(self, response: str) -> dict:
|
| 449 |
+
"""Repair and parse possibly broken JSON response.
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
response (str): Raw GPT response.
|
| 453 |
+
|
| 454 |
+
Returns:
|
| 455 |
+
dict: Parsed JSON output.
|
| 456 |
+
"""
|
| 457 |
return json_repair.loads(response)
|
| 458 |
|
| 459 |
def save_output(self, output: dict, save_path: str) -> None:
|
| 460 |
+
"""Save output dictionary to a file.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
output (dict): Output data.
|
| 464 |
+
save_path (str): Path to save the file.
|
| 465 |
+
"""
|
| 466 |
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 467 |
with open(save_path, 'w') as f:
|
| 468 |
json.dump(output, f, indent=4)
|
|
|
|
| 470 |
def __call__(
|
| 471 |
self, prompt: str, save_path: str = None, params: dict = None
|
| 472 |
) -> dict | str:
|
| 473 |
+
"""Query GPT and process the output.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
prompt (str): User prompt.
|
| 477 |
+
save_path (str, optional): Path to save output.
|
| 478 |
+
params (dict, optional): GPT parameters.
|
| 479 |
+
|
| 480 |
+
Returns:
|
| 481 |
+
dict | str: Output data.
|
| 482 |
+
"""
|
| 483 |
response = self.query(prompt, params=params)
|
| 484 |
output = self.format_response_repair(response)
|
| 485 |
self.save_output(output, save_path) if save_path else None
|
|
|
|
| 501 |
def build_scene_layout(
|
| 502 |
task_desc: str, output_path: str = None, gpt_params: dict = None
|
| 503 |
) -> LayoutInfo:
|
| 504 |
+
"""Build a 3D scene layout from a natural language task description.
|
| 505 |
+
|
| 506 |
+
This function uses GPT-based reasoning to generate a structured scene layout,
|
| 507 |
+
including object hierarchy, spatial relations, and style descriptions.
|
| 508 |
+
|
| 509 |
+
Args:
|
| 510 |
+
task_desc (str): Natural language description of the robotic task.
|
| 511 |
+
output_path (str, optional): Path to save the visualized scene tree.
|
| 512 |
+
gpt_params (dict, optional): Parameters for GPT queries.
|
| 513 |
+
|
| 514 |
+
Returns:
|
| 515 |
+
LayoutInfo: Structured layout information for the scene.
|
| 516 |
+
|
| 517 |
+
Example:
|
| 518 |
+
```py
|
| 519 |
+
from embodied_gen.models.layout import build_scene_layout
|
| 520 |
+
layout_info = build_scene_layout(
|
| 521 |
+
task_desc="Put the apples on the table on the plate",
|
| 522 |
+
output_path="outputs/scene_tree.jpg",
|
| 523 |
+
)
|
| 524 |
+
print(layout_info)
|
| 525 |
+
```
|
| 526 |
+
"""
|
| 527 |
layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
|
| 528 |
layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
|
| 529 |
object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
|
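Besides `build_scene_layout`, the module-level `LayoutDesigner` instances referenced in the hunk headers (such as `LAYOUT_DESCRIBER`) can be called directly through `__call__(prompt, save_path, params)`. A sketch under that assumption; the `params` keys are placeholders for whatever the configured GPT client accepts.

```py
# Hedged sketch: calling a module-level LayoutDesigner directly.
from embodied_gen.models.layout import LAYOUT_DESCRIBER

description = LAYOUT_DESCRIBER(
    "Put the apples on the table on the plate",
    save_path="outputs/layout_description.json",
    params={"temperature": 0.2},  # placeholder GPT sampling parameter
)
print(description)
```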
embodied_gen/models/segment_model.py
CHANGED
|
@@ -48,12 +48,19 @@ __all__ = [
 class SAMRemover(object):
-    """
     Attributes:
         checkpoint (str): Path to the model checkpoint.
-        model_type (str): Type of the SAM model to load
-        area_ratio (float): Area ratio filtering small connected components.
     """
@@ -78,6 +85,14 @@ class SAMRemover(object):
@@ -89,13 +104,11 @@ class SAMRemover(object):
     """Removes the background from an image using the SAM model.
     Args:
-            image (Union[str, Image.Image, np.ndarray]): Input image
-
-            save_path (str): Path to save the output image (default: None).
     Returns:
-            Image.Image:
-                including an alpha channel.
     """
@@ -134,6 +147,15 @@ class SAMRemover(object):
@@ -157,12 +179,28 @@ class SAMPredictor(object):
@@ -178,6 +216,15 @@ class SAMPredictor(object):
@@ -220,6 +267,15 @@ class SAMPredictor(object):
@@ -241,6 +297,15 @@ class SAMPredictor(object):
@@ -249,12 +314,32 @@ class SAMPredictor(object):
@@ -271,7 +356,18 @@ class RembgRemover(object):
@@ -281,6 +377,15 @@ class BMGG14Remover(object):
@@ -299,6 +404,16 @@ class BMGG14Remover(object):
@@ -318,6 +433,20 @@ def get_segmented_image_by_agent(
| 48 |
|
| 49 |
|
| 50 |
class SAMRemover(object):
|
| 51 |
+
"""Loads SAM models and performs background removal on images.
|
| 52 |
|
| 53 |
Attributes:
|
| 54 |
checkpoint (str): Path to the model checkpoint.
|
| 55 |
+
model_type (str): Type of the SAM model to load.
|
| 56 |
+
area_ratio (float): Area ratio for filtering small connected components.
|
| 57 |
+
|
| 58 |
+
Example:
|
| 59 |
+
```py
|
| 60 |
+
from embodied_gen.models.segment_model import SAMRemover
|
| 61 |
+
remover = SAMRemover(model_type="vit_h")
|
| 62 |
+
result = remover("input.jpg", "output.png")
|
| 63 |
+
```
|
| 64 |
"""
|
| 65 |
|
| 66 |
def __init__(
|
|
|
|
| 85 |
self.mask_generator = self._load_sam_model(checkpoint)
|
| 86 |
|
| 87 |
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
|
| 88 |
+
"""Loads the SAM model and returns a mask generator.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
checkpoint (str): Path to model checkpoint.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
SamAutomaticMaskGenerator: Mask generator instance.
|
| 95 |
+
"""
|
| 96 |
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
| 97 |
sam.to(device=self.device)
|
| 98 |
|
|
|
|
| 104 |
"""Removes the background from an image using the SAM model.
|
| 105 |
|
| 106 |
Args:
|
| 107 |
+
image (Union[str, Image.Image, np.ndarray]): Input image.
|
| 108 |
+
save_path (str, optional): Path to save the output image.
|
|
|
|
| 109 |
|
| 110 |
Returns:
|
| 111 |
+
Image.Image: Image with background removed (RGBA).
|
|
|
|
| 112 |
"""
|
| 113 |
# Convert input to numpy array
|
| 114 |
if isinstance(image, str):
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
class SAMPredictor(object):
|
| 150 |
+
"""Loads SAM models and predicts segmentation masks from user points.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
checkpoint (str, optional): Path to model checkpoint.
|
| 154 |
+
model_type (str, optional): SAM model type.
|
| 155 |
+
binary_thresh (float, optional): Threshold for binary mask.
|
| 156 |
+
device (str, optional): Device for inference.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
def __init__(
|
| 160 |
self,
|
| 161 |
checkpoint: str = None,
|
|
|
|
| 179 |
self.binary_thresh = binary_thresh
|
| 180 |
|
| 181 |
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
|
| 182 |
+
"""Loads the SAM model and returns a predictor.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
checkpoint (str): Path to model checkpoint.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
SamPredictor: Predictor instance.
|
| 189 |
+
"""
|
| 190 |
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
| 191 |
sam.to(device=self.device)
|
| 192 |
|
| 193 |
return SamPredictor(sam)
|
| 194 |
|
| 195 |
def preprocess_image(self, image: Image.Image) -> np.ndarray:
|
| 196 |
+
"""Preprocesses input image for SAM prediction.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
image (Image.Image): Input image.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
np.ndarray: Preprocessed image array.
|
| 203 |
+
"""
|
| 204 |
if isinstance(image, str):
|
| 205 |
image = Image.open(image)
|
| 206 |
elif isinstance(image, np.ndarray):
|
|
|
|
| 216 |
image: np.ndarray,
|
| 217 |
selected_points: list[list[int]],
|
| 218 |
) -> np.ndarray:
|
| 219 |
+
"""Generates segmentation masks from selected points.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
image (np.ndarray): Input image array.
|
| 223 |
+
selected_points (list[list[int]]): List of points and labels.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
list[tuple[np.ndarray, str]]: List of masks and names.
|
| 227 |
+
"""
|
| 228 |
if len(selected_points) == 0:
|
| 229 |
return []
|
| 230 |
|
|
|
|
| 267 |
def get_segmented_image(
|
| 268 |
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
|
| 269 |
) -> Image.Image:
|
| 270 |
+
"""Combines masks and returns segmented image with alpha channel.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
image (np.ndarray): Input image array.
|
| 274 |
+
masks (list[tuple[np.ndarray, str]]): List of masks.
|
| 275 |
+
|
| 276 |
+
Returns:
|
| 277 |
+
Image.Image: Segmented RGBA image.
|
| 278 |
+
"""
|
| 279 |
seg_image = Image.fromarray(image, mode="RGB")
|
| 280 |
alpha_channel = np.zeros(
|
| 281 |
(seg_image.height, seg_image.width), dtype=np.uint8
|
|
|
|
| 297 |
image: Union[str, Image.Image, np.ndarray],
|
| 298 |
selected_points: list[list[int]],
|
| 299 |
) -> Image.Image:
|
| 300 |
+
"""Segments image using selected points.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
image (Union[str, Image.Image, np.ndarray]): Input image.
|
| 304 |
+
selected_points (list[list[int]]): List of points and labels.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
Image.Image: Segmented RGBA image.
|
| 308 |
+
"""
|
| 309 |
image = self.preprocess_image(image)
|
| 310 |
self.predictor.set_image(image)
|
| 311 |
masks = self.generate_masks(image, selected_points)
|
|
|
|
| 314 |
|
| 315 |
|
| 316 |
class RembgRemover(object):
|
| 317 |
+
"""Removes background from images using the rembg library.
|
| 318 |
+
|
| 319 |
+
Example:
|
| 320 |
+
```py
|
| 321 |
+
from embodied_gen.models.segment_model import RembgRemover
|
| 322 |
+
remover = RembgRemover()
|
| 323 |
+
result = remover("input.jpg", "output.png")
|
| 324 |
+
```
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
def __init__(self):
|
| 328 |
+
"""Initializes the RembgRemover."""
|
| 329 |
self.rembg_session = rembg.new_session("u2net")
|
| 330 |
|
| 331 |
def __call__(
|
| 332 |
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
| 333 |
) -> Image.Image:
|
| 334 |
+
"""Removes background from an image.
|
| 335 |
+
|
| 336 |
+
Args:
|
| 337 |
+
image (Union[str, Image.Image, np.ndarray]): Input image.
|
| 338 |
+
save_path (str, optional): Path to save the output image.
|
| 339 |
+
|
| 340 |
+
Returns:
|
| 341 |
+
Image.Image: Image with background removed (RGBA).
|
| 342 |
+
"""
|
| 343 |
if isinstance(image, str):
|
| 344 |
image = Image.open(image)
|
| 345 |
elif isinstance(image, np.ndarray):
|
|
|
|
| 356 |
|
| 357 |
|
| 358 |
class BMGG14Remover(object):
|
| 359 |
+
"""Removes background using the RMBG-1.4 segmentation model.
|
| 360 |
+
|
| 361 |
+
Example:
|
| 362 |
+
```py
|
| 363 |
+
from embodied_gen.models.segment_model import BMGG14Remover
|
| 364 |
+
remover = BMGG14Remover()
|
| 365 |
+
result = remover("input.jpg", "output.png")
|
| 366 |
+
```
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
def __init__(self) -> None:
|
| 370 |
+
"""Initializes the BMGG14Remover."""
|
| 371 |
self.model = pipeline(
|
| 372 |
"image-segmentation",
|
| 373 |
model="briaai/RMBG-1.4",
|
|
|
|
| 377 |
def __call__(
|
| 378 |
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
| 379 |
):
|
| 380 |
+
"""Removes background from an image.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
image (Union[str, Image.Image, np.ndarray]): Input image.
|
| 384 |
+
save_path (str, optional): Path to save the output image.
|
| 385 |
+
|
| 386 |
+
Returns:
|
| 387 |
+
Image.Image: Image with background removed.
|
| 388 |
+
"""
|
| 389 |
if isinstance(image, str):
|
| 390 |
image = Image.open(image)
|
| 391 |
elif isinstance(image, np.ndarray):
|
|
|
|
| 404 |
def invert_rgba_pil(
|
| 405 |
image: Image.Image, mask: Image.Image, save_path: str = None
|
| 406 |
) -> Image.Image:
|
| 407 |
+
"""Inverts the alpha channel of an RGBA image using a mask.
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
image (Image.Image): Input RGB image.
|
| 411 |
+
mask (Image.Image): Mask image for alpha inversion.
|
| 412 |
+
save_path (str, optional): Path to save the output image.
|
| 413 |
+
|
| 414 |
+
Returns:
|
| 415 |
+
Image.Image: RGBA image with inverted alpha.
|
| 416 |
+
"""
|
| 417 |
mask = (255 - np.array(mask))[..., None]
|
| 418 |
image_array = np.concatenate([np.array(image), mask], axis=-1)
|
| 419 |
inverted_image = Image.fromarray(image_array, "RGBA")
|
|
|
|
| 433 |
save_path: str = None,
|
| 434 |
mode: Literal["loose", "strict"] = "loose",
|
| 435 |
) -> Image.Image:
|
| 436 |
+
"""Segments an image using SAM and rembg, with quality checking.
|
| 437 |
+
|
| 438 |
+
Args:
|
| 439 |
+
image (Image.Image): Input image.
|
| 440 |
+
sam_remover (SAMRemover): SAM-based remover.
|
| 441 |
+
rbg_remover (RembgRemover): rembg-based remover.
|
| 442 |
+
seg_checker (ImageSegChecker, optional): Quality checker.
|
| 443 |
+
save_path (str, optional): Path to save the output image.
|
| 444 |
+
mode (Literal["loose", "strict"], optional): Segmentation mode.
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
Image.Image: Segmented RGBA image.
|
| 448 |
+
"""
|
| 449 |
+
|
| 450 |
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
|
| 451 |
if seg_checker is None:
|
| 452 |
return True
|
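The `invert_rgba_pil` helper added above only flips an existing mask into an alpha channel. A minimal self-contained sketch of that same numpy/PIL operation, with hypothetical file paths, for testing the idea outside the repo:

```python
import numpy as np
from PIL import Image


def invert_alpha(rgb_path: str, mask_path: str, save_path: str) -> Image.Image:
    # Load an RGB image and a single-channel mask (255 = region to hide).
    image = Image.open(rgb_path).convert("RGB")
    mask = Image.open(mask_path).convert("L")
    # Invert the mask and attach it as the alpha channel, mirroring invert_rgba_pil.
    alpha = (255 - np.array(mask))[..., None].astype(np.uint8)
    rgba = np.concatenate([np.array(image), alpha], axis=-1)
    out = Image.fromarray(rgba, "RGBA")
    out.save(save_path)
    return out
```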
embodied_gen/models/sr_model.py
CHANGED

@@ -39,13 +39,38 @@ __all__ = [
 class ImageStableSR:
-    """Super-resolution image upscaler using Stable Diffusion x4 upscaling model
+    """Super-resolution image upscaler using Stable Diffusion x4 upscaling model.
+
+    This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality
+    image super-resolution.
+
+    Args:
+        model_path (str, optional): Path or HuggingFace repo for the model.
+        device (str, optional): Device for inference.
+
+    Example:
+        ```py
+        from embodied_gen.models.sr_model import ImageStableSR
+        from PIL import Image
+
+        sr_model = ImageStableSR()
+        img = Image.open("input.png")
+        upscaled = sr_model(img)
+        upscaled.save("output.png")
+        ```
+    """
 
     def __init__(
         self,
         model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
         device="cuda",
     ) -> None:
+        """Initializes the Stable Diffusion x4 upscaler.
+
+        Args:
+            model_path (str, optional): Model path or repo.
+            device (str, optional): Device for inference.
+        """
         from diffusers import StableDiffusionUpscalePipeline
 
         self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(

@@ -62,6 +87,16 @@ class ImageStableSR:
         prompt: str = "",
         infer_step: int = 20,
     ) -> Image.Image:
+        """Performs super-resolution on the input image.
+
+        Args:
+            image (Union[Image.Image, np.ndarray]): Input image.
+            prompt (str, optional): Text prompt for upscaling.
+            infer_step (int, optional): Number of inference steps.
+
+        Returns:
+            Image.Image: Upscaled image.
+        """
         if isinstance(image, np.ndarray):
             image = Image.fromarray(image)

@@ -86,9 +121,26 @@ class ImageRealESRGAN:
     Attributes:
         outscale (int): The output image scale factor (e.g., 2, 4).
         model_path (str): Path to the pre-trained model weights.
+
+    Example:
+        ```py
+        from embodied_gen.models.sr_model import ImageRealESRGAN
+        from PIL import Image
+
+        sr_model = ImageRealESRGAN(outscale=4)
+        img = Image.open("input.png")
+        upscaled = sr_model(img)
+        upscaled.save("output.png")
+        ```
     """
 
     def __init__(self, outscale: int, model_path: str = None) -> None:
+        """Initializes the RealESRGAN upscaler.
+
+        Args:
+            outscale (int): Output scale factor.
+            model_path (str, optional): Path to model weights.
+        """
         # monkey patch to support torchvision>=0.16
         import torchvision
         from packaging import version

@@ -122,6 +174,7 @@ class ImageRealESRGAN:
         self.model_path = model_path
 
     def _lazy_init(self):
+        """Lazily initializes the RealESRGAN model."""
         if self.upsampler is None:
             from basicsr.archs.rrdbnet_arch import RRDBNet
             from realesrgan import RealESRGANer

@@ -145,6 +198,14 @@ class ImageRealESRGAN:
     @spaces.GPU
     def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
+        """Performs super-resolution on the input image.
+
+        Args:
+            image (Union[Image.Image, np.ndarray]): Input image.
+
+        Returns:
+            Image.Image: Upscaled image.
+        """
         self._lazy_init()
 
         if isinstance(image, Image.Image):
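`ImageRealESRGAN` defers building the RealESRGAN upsampler until the first call via `_lazy_init`. A minimal sketch of that pattern, with the heavy model replaced by a stand-in so the snippet runs without `realesrgan` installed:

```python
from typing import Any, Optional


class LazyUpscaler:
    """Sketch of the lazy-initialization pattern used by ImageRealESRGAN."""

    def __init__(self, outscale: int) -> None:
        self.outscale = outscale
        self.upsampler: Optional[Any] = None  # heavy model not built yet

    def _lazy_init(self) -> None:
        # Build the expensive model only once, on first use.
        if self.upsampler is None:
            self.upsampler = lambda img: img  # stand-in for the real upsampler

    def __call__(self, image: Any) -> Any:
        self._lazy_init()
        return self.upsampler(image)


upscaler = LazyUpscaler(outscale=4)
print(upscaler.upsampler is None)       # True: nothing loaded at construction
upscaler("fake-image")
print(upscaler.upsampler is not None)   # True: built on first call
```

This keeps construction cheap when the upscaler is created eagerly at app start but only used on demand.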
embodied_gen/models/text_model.py
CHANGED

@@ -60,6 +60,11 @@ PROMPT_KAPPEND = "Single {object}, in the center of the image, white background,
 def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
+    """Downloads Kolors model weights from HuggingFace.
+
+    Args:
+        local_dir (str, optional): Local directory to store weights.
+    """
     logger.info(f"Download kolors weights from huggingface...")
     os.makedirs(local_dir, exist_ok=True)
     subprocess.run(

@@ -93,6 +98,22 @@ def build_text2img_ip_pipeline(
     ref_scale: float,
     device: str = "cuda",
 ) -> StableDiffusionXLPipelineIP:
+    """Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.
+
+    Args:
+        ckpt_dir (str): Directory containing model checkpoints.
+        ref_scale (float): Reference scale for IP-Adapter.
+        device (str, optional): Device for inference.
+
+    Returns:
+        StableDiffusionXLPipelineIP: Configured pipeline.
+
+    Example:
+        ```py
+        from embodied_gen.models.text_model import build_text2img_ip_pipeline
+        pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
+        ```
+    """
     download_kolors_weights(ckpt_dir)
 
     text_encoder = ChatGLMModel.from_pretrained(

@@ -146,6 +167,21 @@ def build_text2img_pipeline(
     ckpt_dir: str,
     device: str = "cuda",
 ) -> StableDiffusionXLPipeline:
+    """Builds a Stable Diffusion XL pipeline for text-to-image generation.
+
+    Args:
+        ckpt_dir (str): Directory containing model checkpoints.
+        device (str, optional): Device for inference.
+
+    Returns:
+        StableDiffusionXLPipeline: Configured pipeline.
+
+    Example:
+        ```py
+        from embodied_gen.models.text_model import build_text2img_pipeline
+        pipe = build_text2img_pipeline("weights/Kolors")
+        ```
+    """
     download_kolors_weights(ckpt_dir)
 
     text_encoder = ChatGLMModel.from_pretrained(

@@ -185,6 +221,29 @@ def text2img_gen(
     ip_image_size: int = 512,
     seed: int = None,
 ) -> list[Image.Image]:
+    """Generates images from text prompts using a Stable Diffusion XL pipeline.
+
+    Args:
+        prompt (str): Text prompt for image generation.
+        n_sample (int): Number of images to generate.
+        guidance_scale (float): Guidance scale for diffusion.
+        pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance.
+        ip_image (Image.Image | str, optional): Reference image for IP-Adapter.
+        image_wh (tuple[int, int], optional): Output image size (width, height).
+        infer_step (int, optional): Number of inference steps.
+        ip_image_size (int, optional): Size for IP-Adapter image.
+        seed (int, optional): Random seed.
+
+    Returns:
+        list[Image.Image]: List of generated images.
+
+    Example:
+        ```py
+        from embodied_gen.models.text_model import text2img_gen
+        images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5)
+        images[0].save("banana.png")
+        ```
+    """
     prompt = PROMPT_KAPPEND.format(object=prompt.strip())
     logger.info(f"Processing prompt: {prompt}")
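The `seed` argument documented for `text2img_gen` implies a `torch.Generator` wired into the diffusers call. A hedged sketch of how such a seeded call typically looks; the keyword names follow the standard Stable Diffusion XL pipeline API and are assumptions about this wrapper's internals, not the repo's exact code:

```python
import torch


def seeded_text2img(pipeline, prompt: str, n_sample: int = 3,
                    guidance_scale: float = 7.5, infer_step: int = 50,
                    image_wh: tuple[int, int] = (1024, 1024),
                    seed: int | None = None):
    # A fixed generator makes the diffusion trajectory reproducible.
    generator = None
    if seed is not None:
        generator = torch.Generator(device="cpu").manual_seed(seed)
    return pipeline(
        prompt=prompt,
        num_images_per_prompt=n_sample,
        guidance_scale=guidance_scale,
        num_inference_steps=infer_step,
        width=image_wh[0],
        height=image_wh[1],
        generator=generator,
    ).images
```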
embodied_gen/models/texture_model.py
CHANGED

@@ -42,6 +42,56 @@ def build_texture_gen_pipe(
     ip_adapt_scale: float = 0,
     device: str = "cuda",
 ) -> DiffusionPipeline:
+    """Build and initialize the Kolors + ControlNet (optional IP-Adapter) texture generation pipeline.
+
+    Loads Kolors tokenizer, text encoder (ChatGLM), VAE, UNet, scheduler and (optionally)
+    a ControlNet checkpoint plus IP-Adapter vision encoder. If ``controlnet_ckpt`` is
+    not provided, the default multi-view texture ControlNet weights are downloaded
+    automatically from the hub. When ``ip_adapt_scale > 0`` an IP-Adapter vision
+    encoder and its weights are also loaded and activated.
+
+    Args:
+        base_ckpt_dir (str):
+            Root directory where Kolors (and optionally Kolors-IP-Adapter-Plus) weights
+            are or will be stored. Required subfolders: ``Kolors/{text_encoder,vae,unet,scheduler}``.
+        controlnet_ckpt (str, optional):
+            Directory containing a ControlNet checkpoint (safetensors). If ``None``,
+            downloads the default ``texture_gen_mv_v1`` snapshot.
+        ip_adapt_scale (float, optional):
+            Strength (>=0) of IP-Adapter conditioning. Set >0 to enable IP-Adapter;
+            typical values: 0.4-0.8. Default: 0 (disabled).
+        device (str, optional):
+            Target device to move the pipeline to (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``).
+            Default: ``"cuda"``.
+
+    Returns:
+        DiffusionPipeline: A configured
+            ``StableDiffusionXLControlNetImg2ImgPipeline`` ready for multi-view texture
+            generation (with optional IP-Adapter support).
+
+    Example:
+        Initialize pipeline with IP-Adapter enabled.
+        ```python
+        from embodied_gen.models.texture_model import build_texture_gen_pipe
+        ip_adapt_scale = 0.7
+        PIPELINE = build_texture_gen_pipe(
+            base_ckpt_dir="./weights",
+            ip_adapt_scale=ip_adapt_scale,
+            device="cuda",
+        )
+        PIPELINE.set_ip_adapter_scale([ip_adapt_scale])
+        ```
+        Initialize pipeline without IP-Adapter.
+        ```python
+        from embodied_gen.models.texture_model import build_texture_gen_pipe
+        PIPELINE = build_texture_gen_pipe(
+            base_ckpt_dir="./weights",
+            ip_adapt_scale=0,
+            device="cuda",
+        )
+        ```
+    """
+
     download_kolors_weights(f"{base_ckpt_dir}/Kolors")
     logger.info(f"Load Kolors weights...")
     tokenizer = ChatGLMTokenizer.from_pretrained(
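The docstring above describes gating: the IP-Adapter branch is only constructed when `ip_adapt_scale > 0`. A tiny control-flow sketch of that idea; the dictionary keys and paths below are illustrative placeholders, not the real loader logic:

```python
def build_pipeline_sketch(base_ckpt_dir: str, ip_adapt_scale: float = 0.0) -> dict:
    # Always load the base Kolors components.
    components = {"base": f"{base_ckpt_dir}/Kolors"}
    if ip_adapt_scale > 0:
        # Only now pay the cost of the vision encoder and adapter weights.
        components["ip_adapter"] = f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus"
        components["scale"] = ip_adapt_scale
    return components


print(build_pipeline_sketch("./weights", ip_adapt_scale=0.7))
```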
embodied_gen/scripts/render_gs.py
CHANGED

@@ -29,7 +29,7 @@ from embodied_gen.data.utils import (
     init_kal_camera,
     normalize_vertices_array,
 )
-from embodied_gen.models.gs_model import
+from embodied_gen.models.gs_model import load_gs_model
 from embodied_gen.utils.process_media import combine_images_to_grid
 
 logging.basicConfig(

@@ -97,21 +97,6 @@ def parse_args():
     return args
 
 
-def load_gs_model(
-    input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
-) -> GaussianOperator:
-    gs_model = GaussianOperator.load_from_ply(input_gs)
-    # Normalize vertices to [-1, 1], center to (0, 0, 0).
-    _, scale, center = normalize_vertices_array(gs_model._means)
-    scale, center = float(scale), center.tolist()
-    transpose = [*[v for v in center], *pre_quat]
-    instance_pose = torch.tensor(transpose).to(gs_model.device)
-    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
-    gs_model.rescale(scale)
-
-    return gs_model
-
-
 @spaces.GPU
 def entrypoint(**kwargs) -> None:
     args = parse_args()
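For reference, the `load_gs_model` helper deleted above is now imported from `embodied_gen.models.gs_model`. Its body, reproduced from the removed lines with the imports the script already uses, normalized the splat means and baked an initial pose before rendering:

```python
import torch

from embodied_gen.data.utils import normalize_vertices_array
from embodied_gen.models.gs_model import GaussianOperator


def load_gs_model(
    input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
) -> GaussianOperator:
    gs_model = GaussianOperator.load_from_ply(input_gs)
    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, scale, center = normalize_vertices_array(gs_model._means)
    scale, center = float(scale), center.tolist()
    transpose = [*[v for v in center], *pre_quat]
    instance_pose = torch.tensor(transpose).to(gs_model.device)
    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
    gs_model.rescale(scale)

    return gs_model
```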
embodied_gen/trainer/pono2mesh_trainer.py
CHANGED

@@ -53,26 +53,31 @@ from thirdparty.pano2room.utils.functions import (
 class Pano2MeshSRPipeline:
-    """
+    """Pipeline for converting panoramic RGB images into 3D mesh representations.
 
-    This class integrates
-
-    - Inpainting of missing regions under offsets
-    - RGB-D to mesh conversion
-    - Multi-view mesh repair
-    - 3D Gaussian Splatting (3DGS) dataset generation
+    This class integrates depth estimation, inpainting, mesh conversion, multi-view mesh repair,
+    and 3D Gaussian Splatting (3DGS) dataset generation.
 
     Args:
         config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
 
     Example:
-        ```
+        ```py
+        from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
+        from embodied_gen.utils.config import Pano2MeshSRConfig
+
+        config = Pano2MeshSRConfig()
         pipeline = Pano2MeshSRPipeline(config)
         pipeline(pano_image='example.png', output_dir='./output')
         ```
     """
 
     def __init__(self, config: Pano2MeshSRConfig) -> None:
+        """Initializes the pipeline with models and camera poses.
+
+        Args:
+            config (Pano2MeshSRConfig): Configuration object.
+        """
         self.cfg = config
         self.device = config.device

@@ -93,6 +98,7 @@ class Pano2MeshSRPipeline:
         self.kernel = torch.from_numpy(kernel).float().to(self.device)
 
     def init_mesh_params(self) -> None:
+        """Initializes mesh parameters and inpaint mask."""
         torch.set_default_device(self.device)
         self.inpaint_mask = torch.ones(
             (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool

@@ -103,6 +109,14 @@ class Pano2MeshSRPipeline:
     @staticmethod
     def read_camera_pose_file(filepath: str) -> np.ndarray:
+        """Reads a camera pose file and returns the pose matrix.
+
+        Args:
+            filepath (str): Path to the camera pose file.
+
+        Returns:
+            np.ndarray: 4x4 camera pose matrix.
+        """
         with open(filepath, "r") as f:
             values = [float(num) for line in f for num in line.split()]

@@ -111,6 +125,14 @@ class Pano2MeshSRPipeline:
     def load_camera_poses(
         self, trajectory_dir: str
     ) -> tuple[np.ndarray, list[torch.Tensor]]:
+        """Loads camera poses from a directory.
+
+        Args:
+            trajectory_dir (str): Directory containing camera pose files.
+
+        Returns:
+            tuple[np.ndarray, list[torch.Tensor]]: List of relative camera poses.
+        """
         pose_filenames = sorted(
             [
                 fname

@@ -148,6 +170,14 @@ class Pano2MeshSRPipeline:
     def load_inpaint_poses(
         self, poses: torch.Tensor
     ) -> dict[int, torch.Tensor]:
+        """Samples and loads poses for inpainting.
+
+        Args:
+            poses (torch.Tensor): Tensor of camera poses.
+
+        Returns:
+            dict[int, torch.Tensor]: Dictionary mapping indices to pose tensors.
+        """
         inpaint_poses = dict()
         sampled_views = poses[:: self.cfg.inpaint_frame_stride]
         init_pose = torch.eye(4)

@@ -162,6 +192,14 @@ class Pano2MeshSRPipeline:
         return inpaint_poses
 
     def project(self, world_to_cam: torch.Tensor):
+        """Projects the mesh to an image using the given camera pose.
+
+        Args:
+            world_to_cam (torch.Tensor): World-to-camera transformation matrix.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected RGB image, inpaint mask, and depth map.
+        """
         (
             project_image,
             project_depth,

@@ -185,6 +223,14 @@ class Pano2MeshSRPipeline:
         return project_image[:3, ...], inpaint_mask, project_depth
 
     def render_pano(self, pose: torch.Tensor):
+        """Renders a panorama from the mesh using the given pose.
+
+        Args:
+            pose (torch.Tensor): Camera pose.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: RGB panorama, depth map, and mask.
+        """
         cubemap_list = []
         for cubemap_pose in self.cubemap_w2cs:
             project_pose = cubemap_pose @ pose

@@ -213,6 +259,15 @@ class Pano2MeshSRPipeline:
         world_to_cam: torch.Tensor = None,
         using_distance_map: bool = True,
     ) -> None:
+        """Converts RGB-D images to mesh and updates mesh parameters.
+
+        Args:
+            rgb (torch.Tensor): RGB image tensor.
+            depth (torch.Tensor): Depth map tensor.
+            inpaint_mask (torch.Tensor): Inpaint mask tensor.
+            world_to_cam (torch.Tensor, optional): Camera pose.
+            using_distance_map (bool, optional): Whether to use distance map.
+        """
         if world_to_cam is None:
             world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)

@@ -239,6 +294,15 @@ class Pano2MeshSRPipeline:
     def get_edge_image_by_depth(
         self, depth: torch.Tensor, dilate_iter: int = 1
     ) -> np.ndarray:
+        """Computes edge image from depth map.
+
+        Args:
+            depth (torch.Tensor): Depth map tensor.
+            dilate_iter (int, optional): Number of dilation iterations.
+
+        Returns:
+            np.ndarray: Edge image.
+        """
         if isinstance(depth, torch.Tensor):
             depth = depth.cpu().detach().numpy()

@@ -253,6 +317,15 @@ class Pano2MeshSRPipeline:
     def mesh_repair_by_greedy_view_selection(
         self, pose_dict: dict[str, torch.Tensor], output_dir: str
     ) -> list:
+        """Repairs mesh by selecting views greedily and inpainting missing regions.
+
+        Args:
+            pose_dict (dict[str, torch.Tensor]): Dictionary of poses for inpainting.
+            output_dir (str): Directory to save visualizations.
+
+        Returns:
+            list: List of inpainted panoramas with poses.
+        """
         inpainted_panos_w_pose = []
         while len(pose_dict) > 0:
             logger.info(f"Repairing mesh left rounds {len(pose_dict)}")

@@ -343,6 +416,17 @@ class Pano2MeshSRPipeline:
         distances: torch.Tensor,
         pano_mask: torch.Tensor,
     ) -> tuple[torch.Tensor]:
+        """Inpaints missing regions in a panorama.
+
+        Args:
+            idx (int): Index of the panorama.
+            colors (torch.Tensor): RGB image tensor.
+            distances (torch.Tensor): Distance map tensor.
+            pano_mask (torch.Tensor): Mask tensor.
+
+        Returns:
+            tuple[torch.Tensor]: Inpainted RGB image, distances, and normals.
+        """
         mask = (pano_mask[None, ..., None] > 0.5).float()
         mask = mask.permute(0, 3, 1, 2)
         mask = dilation(mask, kernel=self.kernel)

@@ -364,6 +448,14 @@ class Pano2MeshSRPipeline:
     def preprocess_pano(
         self, image: Image.Image | str
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Preprocesses a panoramic image for mesh generation.
+
+        Args:
+            image (Image.Image | str): Input image or path.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Preprocessed RGB and depth tensors.
+        """
         if isinstance(image, str):
             image = Image.open(image)

@@ -387,6 +479,17 @@ class Pano2MeshSRPipeline:
     def pano_to_perpective(
         self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
     ) -> torch.Tensor:
+        """Converts a panoramic image to a perspective view.
+
+        Args:
+            pano_image (torch.Tensor): Panoramic image tensor.
+            pitch (float): Pitch angle.
+            yaw (float): Yaw angle.
+            fov (float): Field of view.
+
+        Returns:
+            torch.Tensor: Perspective image tensor.
+        """
         rots = dict(
             roll=0,
             pitch=pitch,

@@ -404,6 +507,14 @@ class Pano2MeshSRPipeline:
         return perspective
 
     def pano_to_cubemap(self, pano_rgb: torch.Tensor):
+        """Converts a panoramic RGB image to six cubemap views.
+
+        Args:
+            pano_rgb (torch.Tensor): Panoramic RGB image tensor.
+
+        Returns:
+            list: List of cubemap RGB tensors.
+        """
         # Define six canonical cube directions in (pitch, yaw)
         directions = [
             (0, 0),

@@ -424,6 +535,11 @@ class Pano2MeshSRPipeline:
         return cubemaps_rgb
 
     def save_mesh(self, output_path: str) -> None:
+        """Saves the mesh to a file.
+
+        Args:
+            output_path (str): Path to save the mesh file.
+        """
         vertices_np = self.vertices.T.cpu().numpy()
         colors_np = self.colors.T.cpu().numpy()
         faces_np = self.faces.T.cpu().numpy()

@@ -434,6 +550,14 @@ class Pano2MeshSRPipeline:
         mesh.export(output_path)
 
     def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
+        """Converts mesh pose to 3D Gaussian Splatting pose.
+
+        Args:
+            mesh_pose (torch.Tensor): Mesh pose tensor.
+
+        Returns:
+            np.ndarray: Converted pose matrix.
+        """
         pose = mesh_pose.clone()
         pose[0, :] *= -1
         pose[1, :] *= -1

@@ -450,6 +574,15 @@ class Pano2MeshSRPipeline:
         return c2w
 
     def __call__(self, pano_image: Image.Image | str, output_dir: str):
+        """Runs the pipeline to generate mesh and 3DGS data from a panoramic image.
+
+        Args:
+            pano_image (Image.Image | str): Input panoramic image or path.
+            output_dir (str): Directory to save outputs.
+
+        Returns:
+            None
+        """
         self.init_mesh_params()
         pano_rgb, pano_depth = self.preprocess_pano(pano_image)
         self.sup_pool = SupInfoPool()
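The `read_camera_pose_file` docstring above promises a 4x4 matrix parsed from a whitespace-separated text file. A minimal standalone sketch of that parsing; the final reshape to 4x4 is an assumption based on the documented return type:

```python
import numpy as np


def read_camera_pose_file(filepath: str) -> np.ndarray:
    # Collect all whitespace-separated floats in the file, then view them
    # as a row-major 4x4 camera pose matrix.
    with open(filepath, "r") as f:
        values = [float(num) for line in f for num in line.split()]
    return np.asarray(values, dtype=np.float64).reshape(4, 4)
```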
embodied_gen/utils/enum.py
CHANGED

@@ -24,11 +24,27 @@ __all__ = [
     "Scene3DItemEnum",
     "SpatialRelationEnum",
     "RobotItemEnum",
+    "LayoutInfo",
+    "AssetType",
+    "SimAssetMapper",
 ]
 
 
 @dataclass
 class RenderItems(str, Enum):
+    """Enumeration of render item types for 3D scenes.
+
+    Attributes:
+        IMAGE: Color image.
+        ALPHA: Mask image.
+        VIEW_NORMAL: View-space normal image.
+        GLOBAL_NORMAL: World-space normal image.
+        POSITION_MAP: Position map image.
+        DEPTH: Depth image.
+        ALBEDO: Albedo image.
+        DIFFUSE: Diffuse image.
+    """
+
     IMAGE = "image_color"
     ALPHA = "image_mask"
     VIEW_NORMAL = "image_view_normal"

@@ -41,6 +57,21 @@ class RenderItems(str, Enum):
 @dataclass
 class Scene3DItemEnum(str, Enum):
+    """Enumeration of 3D scene item categories.
+
+    Attributes:
+        BACKGROUND: Background objects.
+        CONTEXT: Contextual objects.
+        ROBOT: Robot entity.
+        MANIPULATED_OBJS: Objects manipulated by the robot.
+        DISTRACTOR_OBJS: Distractor objects.
+        OTHERS: Other objects.
+
+    Methods:
+        object_list(layout_relation): Returns a list of objects in the scene.
+        object_mapping(layout_relation): Returns a mapping from object to category.
+    """
+
     BACKGROUND = "background"
     CONTEXT = "context"
     ROBOT = "robot"

@@ -50,6 +81,14 @@ class Scene3DItemEnum(str, Enum):
     @classmethod
     def object_list(cls, layout_relation: dict) -> list:
+        """Returns a list of objects in the scene.
+
+        Args:
+            layout_relation: Dictionary mapping categories to objects.
+
+        Returns:
+            List of objects in the scene.
+        """
         return (
             [
                 layout_relation[cls.BACKGROUND.value],

@@ -61,6 +100,14 @@ class Scene3DItemEnum(str, Enum):
     @classmethod
     def object_mapping(cls, layout_relation):
+        """Returns a mapping from object to category.
+
+        Args:
+            layout_relation: Dictionary mapping categories to objects.
+
+        Returns:
+            Dictionary mapping object names to their category.
+        """
         relation_mapping = {
             # layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
             layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,

@@ -84,6 +131,15 @@ class Scene3DItemEnum(str, Enum):
 @dataclass
 class SpatialRelationEnum(str, Enum):
+    """Enumeration of spatial relations for objects in a scene.
+
+    Attributes:
+        ON: Objects on a surface (e.g., table).
+        IN: Objects in a container or room.
+        INSIDE: Objects inside a shelf or rack.
+        FLOOR: Objects on the floor.
+    """
+
     ON = "ON"  # objects on the table
     IN = "IN"  # objects in the room
     INSIDE = "INSIDE"  # objects inside the shelf/rack

@@ -92,6 +148,14 @@ class SpatialRelationEnum(str, Enum):
 @dataclass
 class RobotItemEnum(str, Enum):
+    """Enumeration of supported robot types.
+
+    Attributes:
+        FRANKA: Franka robot.
+        UR5: UR5 robot.
+        PIPER: Piper robot.
+    """
+
     FRANKA = "franka"
     UR5 = "ur5"
     PIPER = "piper"

@@ -99,6 +163,18 @@ class RobotItemEnum(str, Enum):
 @dataclass
 class LayoutInfo(DataClassJsonMixin):
+    """Data structure for layout information in a 3D scene.
+
+    Attributes:
+        tree: Hierarchical structure of scene objects.
+        relation: Spatial relations between objects.
+        objs_desc: Descriptions of objects.
+        objs_mapping: Mapping from object names to categories.
+        assets: Asset file paths for objects.
+        quality: Quality information for assets.
+        position: Position coordinates for objects.
+    """
+
     tree: dict[str, list]
     relation: dict[str, str | list[str]]
     objs_desc: dict[str, str] = field(default_factory=dict)

@@ -106,3 +182,64 @@ class LayoutInfo(DataClassJsonMixin):
     assets: dict[str, str] = field(default_factory=dict)
     quality: dict[str, str] = field(default_factory=dict)
     position: dict[str, list[float]] = field(default_factory=dict)
+
+
+@dataclass
+class AssetType(str):
+    """Enumeration for asset types.
+
+    Supported types:
+        MJCF: MuJoCo XML format.
+        USD: Universal Scene Description format.
+        URDF: Unified Robot Description Format.
+        MESH: Mesh file format.
+    """
+
+    MJCF = "mjcf"
+    USD = "usd"
+    URDF = "urdf"
+    MESH = "mesh"
+
+
+class SimAssetMapper:
+    """Maps simulator names to asset types.
+
+    Provides a mapping from simulator names to their corresponding asset type.
+
+    Example:
+        ```py
+        from embodied_gen.utils.enum import SimAssetMapper
+        asset_type = SimAssetMapper["isaacsim"]
+        print(asset_type)  # Output: 'usd'
+        ```
+
+    Methods:
+        __class_getitem__(key): Returns the asset type for a given simulator name.
+    """
+
+    _mapping = dict(
+        ISAACSIM=AssetType.USD,
+        ISAACGYM=AssetType.URDF,
+        MUJOCO=AssetType.MJCF,
+        GENESIS=AssetType.MJCF,
+        SAPIEN=AssetType.URDF,
+        PYBULLET=AssetType.URDF,
+    )
+
+    @classmethod
+    def __class_getitem__(cls, key: str):
+        """Returns the asset type for a given simulator name.
+
+        Args:
+            key: Name of the simulator.
+
+        Returns:
+            AssetType corresponding to the simulator.
+
+        Raises:
+            KeyError: If the simulator name is not recognized.
+        """
+        key = key.upper()
+        if key.startswith("SAPIEN"):
+            key = "SAPIEN"
+        return cls._mapping[key]
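Because `LayoutInfo` mixes in `DataClassJsonMixin`, it round-trips through plain dicts/JSON, which is how `bfs_placement` in geometry.py rebuilds it with `from_dict`. A small sketch of that round trip; the field values are made up for illustration:

```python
import json

from embodied_gen.utils.enum import LayoutInfo

layout = LayoutInfo(
    tree={"table": [["mug", "ON"]]},
    relation={"background": "kitchen", "context": "table"},
)
# DataClassJsonMixin supplies to_dict/from_dict for JSON serialization.
payload = json.dumps(layout.to_dict())
restored = LayoutInfo.from_dict(json.loads(payload))
assert restored == layout
```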
embodied_gen/utils/geometry.py
CHANGED
|
@@ -45,13 +45,13 @@ __all__ = [
|
|
| 45 |
|
| 46 |
|
| 47 |
def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
| 48 |
-
"""
|
| 49 |
|
| 50 |
Args:
|
| 51 |
matrix (np.ndarray): 4x4 transformation matrix.
|
| 52 |
|
| 53 |
Returns:
|
| 54 |
-
|
| 55 |
"""
|
| 56 |
x, y, z = matrix[:3, 3]
|
| 57 |
rot_mat = matrix[:3, :3]
|
|
@@ -62,13 +62,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
|
| 62 |
|
| 63 |
|
| 64 |
def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
| 65 |
-
"""
|
| 66 |
|
| 67 |
Args:
|
| 68 |
-
|
| 69 |
|
| 70 |
Returns:
|
| 71 |
-
|
| 72 |
"""
|
| 73 |
x, y, z, qx, qy, qz, qw = pose
|
| 74 |
r = R.from_quat([qx, qy, qz, qw])
|
|
@@ -82,6 +82,16 @@ def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
|
| 82 |
def compute_xy_bbox(
|
| 83 |
vertices: np.ndarray, col_x: int = 0, col_y: int = 1
|
| 84 |
) -> list[float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
x_vals = vertices[:, col_x]
|
| 86 |
y_vals = vertices[:, col_y]
|
| 87 |
return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
|
|
@@ -92,6 +102,16 @@ def has_iou_conflict(
|
|
| 92 |
placed_boxes: list[list[float]],
|
| 93 |
iou_threshold: float = 0.0,
|
| 94 |
) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
new_min_x, new_max_x, new_min_y, new_max_y = new_box
|
| 96 |
for min_x, max_x, min_y, max_y in placed_boxes:
|
| 97 |
ix1 = max(new_min_x, min_x)
|
|
@@ -105,7 +125,14 @@ def has_iou_conflict(
|
|
| 105 |
|
| 106 |
|
| 107 |
def with_seed(seed_attr_name: str = "seed"):
|
| 108 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
def decorator(func):
|
| 111 |
@wraps(func)
|
|
@@ -143,6 +170,20 @@ def compute_convex_hull_path(
|
|
| 143 |
y_axis: int = 1,
|
| 144 |
z_axis: int = 2,
|
| 145 |
) -> Path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
top_vertices = vertices[
|
| 147 |
vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
|
| 148 |
]
|
|
@@ -170,6 +211,15 @@ def compute_convex_hull_path(
|
|
| 170 |
|
| 171 |
|
| 172 |
def find_parent_node(node: str, tree: dict) -> str | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
for parent, children in tree.items():
|
| 174 |
if any(child[0] == node for child in children):
|
| 175 |
return parent
|
|
@@ -177,6 +227,16 @@ def find_parent_node(node: str, tree: dict) -> str | None:
|
|
| 177 |
|
| 178 |
|
| 179 |
def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
x1, x2, y1, y2 = box
|
| 181 |
corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
|
| 182 |
|
|
@@ -187,6 +247,15 @@ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
|
| 187 |
def compute_axis_rotation_quat(
|
| 188 |
axis: Literal["x", "y", "z"], angle_rad: float
|
| 189 |
) -> list[float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
if axis.lower() == "x":
|
| 191 |
q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
|
| 192 |
elif axis.lower() == "y":
|
|
@@ -202,6 +271,15 @@ def compute_axis_rotation_quat(
|
|
| 202 |
def quaternion_multiply(
|
| 203 |
init_quat: list[float], rotate_quat: list[float]
|
| 204 |
) -> list[float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
qx, qy, qz, qw = init_quat
|
| 206 |
q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
|
| 207 |
qx, qy, qz, qw = rotate_quat
|
|
@@ -217,7 +295,17 @@ def check_reachable(
|
|
| 217 |
min_reach: float = 0.25,
|
| 218 |
max_reach: float = 0.85,
|
| 219 |
) -> bool:
|
| 220 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
distance = np.linalg.norm(reach_xyz - base_xyz)
|
| 222 |
|
| 223 |
return min_reach < distance < max_reach
|
|
@@ -238,26 +326,31 @@ def bfs_placement(
|
|
| 238 |
robot_dim: float = 0.12,
|
| 239 |
seed: int = None,
|
| 240 |
) -> LayoutInfo:
|
| 241 |
-
"""
|
| 242 |
|
| 243 |
Args:
|
| 244 |
-
layout_file: Path to
|
| 245 |
-
floor_margin: Z-offset for
|
| 246 |
-
beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
| 247 |
-
max_attempts
|
| 248 |
-
init_rpy: Initial
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
seed: Random seed for reproducible placement.
|
| 257 |
|
| 258 |
Returns:
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
"""
|
| 262 |
layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
|
| 263 |
asset_dir = os.path.dirname(layout_file)
|
|
@@ -478,6 +571,13 @@ def bfs_placement(
|
|
| 478 |
def compose_mesh_scene(
|
| 479 |
layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
|
| 480 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
|
| 482 |
scene = trimesh.Scene()
|
| 483 |
for node in layout_info.assets:
|
|
@@ -505,6 +605,16 @@ def compose_mesh_scene(
|
|
| 505 |
def compute_pinhole_intrinsics(
|
| 506 |
image_w: int, image_h: int, fov_deg: float
|
| 507 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
fov_rad = np.deg2rad(fov_deg)
|
| 509 |
fx = image_w / (2 * np.tan(fov_rad / 2))
|
| 510 |
fy = fx # assuming square pixels
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def matrix_to_pose(matrix: np.ndarray) -> list[float]:
|
| 48 |
+
"""Converts a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
|
| 49 |
|
| 50 |
Args:
|
| 51 |
matrix (np.ndarray): 4x4 transformation matrix.
|
| 52 |
|
| 53 |
Returns:
|
| 54 |
+
list[float]: Pose as [x, y, z, qx, qy, qz, qw].
|
| 55 |
"""
|
| 56 |
x, y, z = matrix[:3, 3]
|
| 57 |
rot_mat = matrix[:3, :3]
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
def pose_to_matrix(pose: list[float]) -> np.ndarray:
|
| 65 |
+
"""Converts pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
|
| 66 |
|
| 67 |
Args:
|
| 68 |
+
pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw].
|
| 69 |
|
| 70 |
Returns:
|
| 71 |
+
np.ndarray: 4x4 transformation matrix.
|
| 72 |
"""
|
| 73 |
x, y, z, qx, qy, qz, qw = pose
|
| 74 |
r = R.from_quat([qx, qy, qz, qw])
|
|
|
|
| 82 |
def compute_xy_bbox(
|
| 83 |
vertices: np.ndarray, col_x: int = 0, col_y: int = 1
|
| 84 |
) -> list[float]:
|
| 85 |
+
"""Computes the bounding box in XY plane for given vertices.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
vertices (np.ndarray): Vertex coordinates.
|
| 89 |
+
col_x (int, optional): Column index for X.
|
| 90 |
+
col_y (int, optional): Column index for Y.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
list[float]: [min_x, max_x, min_y, max_y]
|
| 94 |
+
"""
|
| 95 |
x_vals = vertices[:, col_x]
|
| 96 |
y_vals = vertices[:, col_y]
|
| 97 |
return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
|
|
|
|
| 102 |
placed_boxes: list[list[float]],
|
| 103 |
iou_threshold: float = 0.0,
|
| 104 |
) -> bool:
|
| 105 |
+
"""Checks for intersection-over-union conflict between boxes.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
new_box (list[float]): New box coordinates.
|
| 109 |
+
placed_boxes (list[list[float]]): List of placed box coordinates.
|
| 110 |
+
iou_threshold (float, optional): IOU threshold.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
bool: True if conflict exists, False otherwise.
|
| 114 |
+
"""
|
| 115 |
new_min_x, new_max_x, new_min_y, new_max_y = new_box
|
| 116 |
for min_x, max_x, min_y, max_y in placed_boxes:
|
| 117 |
ix1 = max(new_min_x, min_x)
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def with_seed(seed_attr_name: str = "seed"):
|
| 128 |
+
"""Decorator to temporarily set the random seed for reproducibility.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
seed_attr_name (str, optional): Name of the seed argument.
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
function: Decorator function.
|
| 135 |
+
"""
|
| 136 |
|
| 137 |
def decorator(func):
|
| 138 |
@wraps(func)
|
|
|
|
| 170 |
y_axis: int = 1,
|
| 171 |
z_axis: int = 2,
|
| 172 |
) -> Path:
|
| 173 |
+
"""Computes a dense convex hull path for the top surface of a mesh.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
vertices (np.ndarray): Mesh vertices.
|
| 177 |
+
z_threshold (float, optional): Z threshold for top surface.
|
| 178 |
+
interp_per_edge (int, optional): Interpolation points per edge.
|
| 179 |
+
margin (float, optional): Margin for polygon buffer.
|
| 180 |
+
x_axis (int, optional): X axis index.
|
| 181 |
+
y_axis (int, optional): Y axis index.
|
| 182 |
+
z_axis (int, optional): Z axis index.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
Path: Matplotlib path object for the convex hull.
|
| 186 |
+
"""
|
| 187 |
top_vertices = vertices[
|
| 188 |
vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
|
| 189 |
]
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
def find_parent_node(node: str, tree: dict) -> str | None:
|
| 214 |
+
"""Finds the parent node of a given node in a tree.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
node (str): Node name.
|
| 218 |
+
tree (dict): Tree structure.
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
str | None: Parent node name or None.
|
| 222 |
+
"""
|
| 223 |
for parent, children in tree.items():
|
| 224 |
if any(child[0] == node for child in children):
|
| 225 |
return parent
|
|
|
|
| 227 |
|
| 228 |
|
| 229 |
def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
|
| 230 |
+
"""Checks if at least `threshold` corners of a box are inside a hull.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
hull (Path): Convex hull path.
|
| 234 |
+
box (list): Box coordinates [x1, x2, y1, y2].
|
| 235 |
+
threshold (int, optional): Minimum corners inside.
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
bool: True if enough corners are inside.
|
| 239 |
+
"""
|
| 240 |
x1, x2, y1, y2 = box
|
| 241 |
corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
|
| 242 |
|
|
|
|
| 247 |
def compute_axis_rotation_quat(
|
| 248 |
axis: Literal["x", "y", "z"], angle_rad: float
|
| 249 |
) -> list[float]:
|
| 250 |
+
"""Computes quaternion for rotation around a given axis.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
axis (Literal["x", "y", "z"]): Axis of rotation.
|
| 254 |
+
angle_rad (float): Rotation angle in radians.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
list[float]: Quaternion [x, y, z, w].
|
| 258 |
+
"""
|
| 259 |
if axis.lower() == "x":
|
| 260 |
q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
|
| 261 |
elif axis.lower() == "y":
|
|
|
|
| 271 |
def quaternion_multiply(
|
| 272 |
init_quat: list[float], rotate_quat: list[float]
|
| 273 |
) -> list[float]:
|
| 274 |
+
"""Multiplies two quaternions.
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
init_quat (list[float]): Initial quaternion [x, y, z, w].
|
| 278 |
+
rotate_quat (list[float]): Rotation quaternion [x, y, z, w].
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
list[float]: Resulting quaternion [x, y, z, w].
|
| 282 |
+
"""
|
| 283 |
qx, qy, qz, qw = init_quat
|
| 284 |
q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
|
| 285 |
qx, qy, qz, qw = rotate_quat
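Both quaternion helpers wrap pyquaternion; here is a small standalone sketch of the same idea. Note that pyquaternion stores components in w-x-y-z order while these helpers exchange x-y-z-w lists, and the composition order below is illustrative:

```py
import numpy as np
from pyquaternion import Quaternion

def axis_quat_xyzw(axis: str, angle_rad: float) -> list[float]:
    """Quaternion (x, y, z, w) for a rotation about a principal axis."""
    unit = {"x": [1, 0, 0], "y": [0, 1, 0], "z": [0, 0, 1]}[axis.lower()]
    q = Quaternion(axis=unit, angle=angle_rad)
    return [q.x, q.y, q.z, q.w]

def quat_multiply_xyzw(q1_xyzw: list[float], q2_xyzw: list[float]) -> list[float]:
    """Compose two rotations given in (x, y, z, w) order."""
    x1, y1, z1, w1 = q1_xyzw
    x2, y2, z2, w2 = q2_xyzw
    q = Quaternion(w=w1, x=x1, y=y1, z=z1) * Quaternion(w=w2, x=x2, y=y2, z=z2)
    return [q.x, q.y, q.z, q.w]

# Two 45-degree rotations about Z compose to a 90-degree rotation about Z.
q45 = axis_quat_xyzw("z", np.pi / 4)
print(np.round(quat_multiply_xyzw(q45, q45), 3))  # ~[0, 0, 0.707, 0.707]
```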
|
|
|
|
| 295 |
min_reach: float = 0.25,
|
| 296 |
max_reach: float = 0.85,
|
| 297 |
) -> bool:
|
| 298 |
+
"""Checks if the target point is within the reachable range.
|
| 299 |
+
|
| 300 |
+
Args:
|
| 301 |
+
base_xyz (np.ndarray): Base position.
|
| 302 |
+
reach_xyz (np.ndarray): Target position.
|
| 303 |
+
min_reach (float, optional): Minimum reach distance.
|
| 304 |
+
max_reach (float, optional): Maximum reach distance.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
bool: True if reachable, False otherwise.
|
| 308 |
+
"""
|
| 309 |
distance = np.linalg.norm(reach_xyz - base_xyz)
|
| 310 |
|
| 311 |
return min_reach < distance < max_reach
|
|
|
|
| 326 |
robot_dim: float = 0.12,
|
| 327 |
seed: int = None,
|
| 328 |
) -> LayoutInfo:
|
| 329 |
+
"""Places objects in a scene layout using BFS traversal.
|
| 330 |
|
| 331 |
Args:
|
| 332 |
+
layout_file (str): Path to layout JSON file generated from `layout-cli`.
|
| 333 |
+
floor_margin (float, optional): Z-offset for objects placed on the floor.
|
| 334 |
+
beside_margin (float, optional): Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
|
| 335 |
+
max_attempts (int, optional): Max attempts for a non-overlapping placement.
|
| 336 |
+
init_rpy (tuple, optional): Initial rotation (rpy).
|
| 337 |
+
rotate_objs (bool, optional): Whether to randomly rotate objects.
|
| 338 |
+
rotate_bg (bool, optional): Whether to randomly rotate the background.
|
| 339 |
+
rotate_context (bool, optional): Whether to randomly rotate the context asset.
|
| 340 |
+
limit_reach_range (tuple[float, float] | None, optional): If set, enforce a check that manipulated objects are within the robot's reach range, in meters.
|
| 341 |
+
max_orient_diff (float | None, optional): If set, enforce a check that manipulated objects are within the robot's orientation range, in degrees.
|
| 342 |
+
robot_dim (float, optional): The approximate robot size.
|
| 343 |
+
seed (int, optional): Random seed for reproducible placement.
|
|
|
|
| 344 |
|
| 345 |
Returns:
|
| 346 |
+
LayoutInfo: Layout information with object poses.
|
| 347 |
+
|
| 348 |
+
Example:
|
| 349 |
+
```py
|
| 350 |
+
from embodied_gen.utils.geometry import bfs_placement
|
| 351 |
+
layout = bfs_placement("scene_layout.json", seed=42)
|
| 352 |
+
print(layout.position)
|
| 353 |
+
```
|
| 354 |
"""
|
| 355 |
layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
|
| 356 |
asset_dir = os.path.dirname(layout_file)
|
|
|
|
| 571 |
def compose_mesh_scene(
|
| 572 |
layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
|
| 573 |
) -> None:
|
| 574 |
+
"""Composes a mesh scene from layout information and saves to file.
|
| 575 |
+
|
| 576 |
+
Args:
|
| 577 |
+
layout_info (LayoutInfo): Layout information.
|
| 578 |
+
out_scene_path (str): Output scene file path.
|
| 579 |
+
with_bg (bool, optional): Include background mesh.
|
| 580 |
+
"""
|
| 581 |
object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
|
| 582 |
scene = trimesh.Scene()
|
| 583 |
for node in layout_info.assets:
|
|
|
|
| 605 |
def compute_pinhole_intrinsics(
|
| 606 |
image_w: int, image_h: int, fov_deg: float
|
| 607 |
) -> np.ndarray:
|
| 608 |
+
"""Computes pinhole camera intrinsic matrix from image size and FOV.
|
| 609 |
+
|
| 610 |
+
Args:
|
| 611 |
+
image_w (int): Image width.
|
| 612 |
+
image_h (int): Image height.
|
| 613 |
+
fov_deg (float): Field of view in degrees.
|
| 614 |
+
|
| 615 |
+
Returns:
|
| 616 |
+
np.ndarray: Intrinsic matrix K.
|
| 617 |
+
"""
|
| 618 |
fov_rad = np.deg2rad(fov_deg)
|
| 619 |
fx = image_w / (2 * np.tan(fov_rad / 2))
|
| 620 |
fy = fx # assuming square pixels
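For reference, the full 3x3 intrinsic matrix that follows from the focal-length formula above; this is a standalone sketch, and the centered principal point is an assumption since that part of the function is not shown:

```py
import numpy as np

def pinhole_intrinsics(image_w: int, image_h: int, fov_deg: float) -> np.ndarray:
    """Build K from a FOV applied to the image width, assuming square pixels."""
    fov_rad = np.deg2rad(fov_deg)
    fx = image_w / (2 * np.tan(fov_rad / 2))
    fy = fx
    cx, cy = image_w / 2.0, image_h / 2.0  # Assumed centered principal point.
    return np.array([[fx, 0.0, cx],
                     [0.0, fy, cy],
                     [0.0, 0.0, 1.0]])

print(pinhole_intrinsics(640, 480, 60.0).round(1))
# fx = fy = 320 / tan(30 deg) ~= 554.3, cx = 320, cy = 240
```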
|
embodied_gen/utils/gpt_clients.py
CHANGED
|
@@ -45,7 +45,35 @@ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
|
|
| 45 |
|
| 46 |
|
| 47 |
class GPTclient:
|
| 48 |
-
"""A client to interact with
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def __init__(
|
| 51 |
self,
|
|
@@ -82,6 +110,7 @@ class GPTclient:
|
|
| 82 |
stop=(stop_after_attempt(10) | stop_after_delay(30)),
|
| 83 |
)
|
| 84 |
def completion_with_backoff(self, **kwargs):
|
|
|
|
| 85 |
return self.client.chat.completions.create(**kwargs)
|
| 86 |
|
| 87 |
def query(
|
|
@@ -91,19 +120,16 @@ class GPTclient:
|
|
| 91 |
system_role: Optional[str] = None,
|
| 92 |
params: Optional[dict] = None,
|
| 93 |
) -> Optional[str]:
|
| 94 |
-
"""Queries the GPT model with
|
| 95 |
|
| 96 |
Args:
|
| 97 |
-
text_prompt (str):
|
| 98 |
-
image_base64 (Optional[
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
that specify the behavior of the assistant.
|
| 102 |
-
params (Optional[dict]): Additional parameters for GPT setting.
|
| 103 |
|
| 104 |
Returns:
|
| 105 |
-
Optional[str]:
|
| 106 |
-
the prompt. Returns `None` if an error occurs.
|
| 107 |
"""
|
| 108 |
if system_role is None:
|
| 109 |
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
|
|
@@ -177,7 +203,11 @@ class GPTclient:
|
|
| 177 |
return response
|
| 178 |
|
| 179 |
def check_connection(self) -> None:
|
| 180 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
try:
|
| 182 |
response = self.completion_with_backoff(
|
| 183 |
messages=[
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
class GPTclient:
|
| 48 |
+
"""A client to interact with GPT models via OpenAI or Azure API.
|
| 49 |
+
|
| 50 |
+
Supports text and image prompts, connection checking, and configurable parameters.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
endpoint (str): API endpoint URL.
|
| 54 |
+
api_key (str): API key for authentication.
|
| 55 |
+
model_name (str, optional): Model name to use.
|
| 56 |
+
api_version (str, optional): API version (for Azure).
|
| 57 |
+
check_connection (bool, optional): Whether to check API connection.
|
| 58 |
+
verbose (bool, optional): Enable verbose logging.
|
| 59 |
+
|
| 60 |
+
Example:
|
| 61 |
+
```sh
|
| 62 |
+
export ENDPOINT="https://yfb-openai-sweden.openai.azure.com"
|
| 63 |
+
export API_KEY="xxxxxx"
|
| 64 |
+
export API_VERSION="2025-03-01-preview"
|
| 65 |
+
export MODEL_NAME="yfb-gpt-4o-sweden"
|
| 66 |
+
```
|
| 67 |
+
```py
|
| 68 |
+
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 69 |
+
|
| 70 |
+
response = GPT_CLIENT.query("Describe the physics of a falling apple.")
|
| 71 |
+
response = GPT_CLIENT.query(
|
| 72 |
+
text_prompt="Describe the content in each image.",
|
| 73 |
+
image_base64=["path/to/image1.png", "path/to/image2.jpg"],
|
| 74 |
+
)
|
| 75 |
+
```
|
| 76 |
+
"""
|
| 77 |
|
| 78 |
def __init__(
|
| 79 |
self,
|
|
|
|
| 110 |
stop=(stop_after_attempt(10) | stop_after_delay(30)),
|
| 111 |
)
|
| 112 |
def completion_with_backoff(self, **kwargs):
|
| 113 |
+
"""Performs a chat completion request with retry/backoff."""
|
| 114 |
return self.client.chat.completions.create(**kwargs)
|
| 115 |
|
| 116 |
def query(
|
|
|
|
| 120 |
system_role: Optional[str] = None,
|
| 121 |
params: Optional[dict] = None,
|
| 122 |
) -> Optional[str]:
|
| 123 |
+
"""Queries the GPT model with text and optional image prompts.
|
| 124 |
|
| 125 |
Args:
|
| 126 |
+
text_prompt (str): Main text input.
|
| 127 |
+
image_base64 (Optional[list[str | Image.Image]], optional): List of image base64 strings, file paths, or PIL Images.
|
| 128 |
+
system_role (Optional[str], optional): System-level instructions.
|
| 129 |
+
params (Optional[dict], optional): Additional GPT parameters.
|
|
|
|
|
|
|
| 130 |
|
| 131 |
Returns:
|
| 132 |
+
Optional[str]: Model response content, or None if error.
|
|
|
|
| 133 |
"""
|
| 134 |
if system_role is None:
|
| 135 |
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
|
|
|
|
| 203 |
return response
|
| 204 |
|
| 205 |
def check_connection(self) -> None:
|
| 206 |
+
"""Checks whether the GPT API connection is working.
|
| 207 |
+
|
| 208 |
+
Raises:
|
| 209 |
+
ConnectionError: If connection fails.
|
| 210 |
+
"""
|
| 211 |
try:
|
| 212 |
response = self.completion_with_backoff(
|
| 213 |
messages=[
|
embodied_gen/utils/process_media.py
CHANGED
|
@@ -69,6 +69,40 @@ def render_asset3d(
|
|
| 69 |
no_index_file: bool = False,
|
| 70 |
with_mtl: bool = True,
|
| 71 |
) -> list[str]:
|
|
|
|
|
|
| 72 |
input_args = dict(
|
| 73 |
mesh_path=mesh_path,
|
| 74 |
output_root=output_root,
|
|
@@ -95,6 +129,13 @@ def render_asset3d(
|
|
| 95 |
|
| 96 |
|
| 97 |
def merge_images_video(color_images, normal_images, output_path) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
width = color_images[0].shape[1]
|
| 99 |
combined_video = [
|
| 100 |
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
|
|
@@ -108,7 +149,13 @@ def merge_images_video(color_images, normal_images, output_path) -> None:
|
|
| 108 |
def merge_video_video(
|
| 109 |
video_path1: str, video_path2: str, output_path: str
|
| 110 |
) -> None:
|
| 111 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
clip1 = VideoFileClip(video_path1)
|
| 113 |
clip2 = VideoFileClip(video_path2)
|
| 114 |
|
|
@@ -127,6 +174,16 @@ def filter_small_connected_components(
|
|
| 127 |
area_ratio: float,
|
| 128 |
connectivity: int = 8,
|
| 129 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
if isinstance(mask, Image.Image):
|
| 131 |
mask = np.array(mask)
|
| 132 |
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
|
@@ -152,6 +209,16 @@ def filter_image_small_connected_components(
|
|
| 152 |
area_ratio: float = 10,
|
| 153 |
connectivity: int = 8,
|
| 154 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
if isinstance(image, Image.Image):
|
| 156 |
image = image.convert("RGBA")
|
| 157 |
image = np.array(image)
|
|
@@ -169,6 +236,24 @@ def combine_images_to_grid(
|
|
| 169 |
target_wh: tuple[int, int] = (512, 512),
|
| 170 |
image_mode: str = "RGB",
|
| 171 |
) -> list[Image.Image]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
n_images = len(images)
|
| 173 |
if n_images == 1:
|
| 174 |
return images
|
|
@@ -196,6 +281,19 @@ def combine_images_to_grid(
|
|
| 196 |
|
| 197 |
|
| 198 |
class SceneTreeVisualizer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
def __init__(self, layout_info: LayoutInfo) -> None:
|
| 200 |
self.tree = layout_info.tree
|
| 201 |
self.relation = layout_info.relation
|
|
@@ -274,6 +372,14 @@ class SceneTreeVisualizer:
|
|
| 274 |
dpi=300,
|
| 275 |
title: str = "Scene 3D Hierarchy Tree",
|
| 276 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
node_colors = [
|
| 278 |
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
|
| 279 |
]
|
|
@@ -350,6 +456,14 @@ class SceneTreeVisualizer:
|
|
| 350 |
|
| 351 |
|
| 352 |
def load_scene_dict(file_path: str) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
scene_dict = {}
|
| 354 |
with open(file_path, "r", encoding='utf-8') as f:
|
| 355 |
for line in f:
|
|
@@ -363,12 +477,28 @@ def load_scene_dict(file_path: str) -> dict:
|
|
| 363 |
|
| 364 |
|
| 365 |
def is_image_file(filename: str) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
mime_type, _ = mimetypes.guess_type(filename)
|
| 367 |
|
| 368 |
return mime_type is not None and mime_type.startswith('image')
|
| 369 |
|
| 370 |
|
| 371 |
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
| 373 |
with open(prompts[0], "r") as f:
|
| 374 |
prompts = [
|
|
@@ -386,13 +516,18 @@ def alpha_blend_rgba(
|
|
| 386 |
"""Alpha blends a foreground RGBA image over a background RGBA image.
|
| 387 |
|
| 388 |
Args:
|
| 389 |
-
fg_image: Foreground image
|
| 390 |
-
|
| 391 |
-
bg_image: Background image. Can be a file path (str), a PIL Image,
|
| 392 |
-
or a NumPy ndarray.
|
| 393 |
|
| 394 |
Returns:
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
"""
|
| 397 |
if isinstance(fg_image, str):
|
| 398 |
fg_image = Image.open(fg_image)
|
|
@@ -421,13 +556,11 @@ def check_object_edge_truncated(
|
|
| 421 |
"""Checks if a binary object mask is truncated at the image edges.
|
| 422 |
|
| 423 |
Args:
|
| 424 |
-
mask:
|
| 425 |
-
edge_threshold
|
| 426 |
-
Defaults to 5.
|
| 427 |
|
| 428 |
Returns:
|
| 429 |
-
True if
|
| 430 |
-
False if the object touches or crosses any image boundary.
|
| 431 |
"""
|
| 432 |
top = mask[:edge_threshold, :].any()
|
| 433 |
bottom = mask[-edge_threshold:, :].any()
|
|
@@ -440,6 +573,22 @@ def check_object_edge_truncated(
|
|
| 440 |
def vcat_pil_images(
|
| 441 |
images: list[Image.Image], image_mode: str = "RGB"
|
| 442 |
) -> Image.Image:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
widths, heights = zip(*(img.size for img in images))
|
| 444 |
total_height = sum(heights)
|
| 445 |
max_width = max(widths)
|
|
|
|
| 69 |
no_index_file: bool = False,
|
| 70 |
with_mtl: bool = True,
|
| 71 |
) -> list[str]:
|
| 72 |
+
"""Renders a 3D mesh asset and returns output image paths.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
mesh_path (str): Path to the mesh file.
|
| 76 |
+
output_root (str): Directory to save outputs.
|
| 77 |
+
distance (float, optional): Camera distance.
|
| 78 |
+
num_images (int, optional): Number of views to render.
|
| 79 |
+
elevation (list[float], optional): Camera elevation angles.
|
| 80 |
+
pbr_light_factor (float, optional): PBR lighting factor.
|
| 81 |
+
return_key (str, optional): Glob pattern for output images.
|
| 82 |
+
output_subdir (str, optional): Subdirectory for outputs.
|
| 83 |
+
gen_color_mp4 (bool, optional): Generate color MP4 video.
|
| 84 |
+
gen_viewnormal_mp4 (bool, optional): Generate view normal MP4.
|
| 85 |
+
gen_glonormal_mp4 (bool, optional): Generate global normal MP4.
|
| 86 |
+
no_index_file (bool, optional): Skip index file saving.
|
| 87 |
+
with_mtl (bool, optional): Use mesh material.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
list[str]: List of output image file paths.
|
| 91 |
+
|
| 92 |
+
Example:
|
| 93 |
+
```py
|
| 94 |
+
from embodied_gen.utils.process_media import render_asset3d
|
| 95 |
+
|
| 96 |
+
image_paths = render_asset3d(
|
| 97 |
+
mesh_path="path_to_mesh.obj",
|
| 98 |
+
output_root="path_to_save_dir",
|
| 99 |
+
num_images=6,
|
| 100 |
+
elevation=(30, -30),
|
| 101 |
+
output_subdir="renders",
|
| 102 |
+
no_index_file=True,
|
| 103 |
+
)
|
| 104 |
+
```
|
| 105 |
+
"""
|
| 106 |
input_args = dict(
|
| 107 |
mesh_path=mesh_path,
|
| 108 |
output_root=output_root,
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
def merge_images_video(color_images, normal_images, output_path) -> None:
|
| 132 |
+
"""Merges color and normal images into a video.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
color_images (list[np.ndarray]): List of color images.
|
| 136 |
+
normal_images (list[np.ndarray]): List of normal images.
|
| 137 |
+
output_path (str): Path to save the output video.
|
| 138 |
+
"""
|
| 139 |
width = color_images[0].shape[1]
|
| 140 |
combined_video = [
|
| 141 |
np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
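Each merged frame above is the left half of the color render placed next to the right half of the normal render. A small sketch of the same compositing on synthetic frames; writing the resulting frames to a video can then be done with any frame-sequence writer (moviepy, imageio, etc.), which is omitted here:

```py
import numpy as np

# Synthetic stand-ins for rendered frames: 8 color and 8 normal images, 64x64 RGB.
color_images = [np.full((64, 64, 3), 200, dtype=np.uint8) for _ in range(8)]
normal_images = [np.full((64, 64, 3), 60, dtype=np.uint8) for _ in range(8)]

width = color_images[0].shape[1]
combined_video = [
    np.hstack([rgb[:, : width // 2], nrm[:, width // 2 :]])
    for rgb, nrm in zip(color_images, normal_images)
]
# Each merged frame keeps the original width: left half color, right half normal.
print(len(combined_video), combined_video[0].shape)  # 8 (64, 64, 3)
```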
|
|
|
|
| 149 |
def merge_video_video(
|
| 150 |
video_path1: str, video_path2: str, output_path: str
|
| 151 |
) -> None:
|
| 152 |
+
"""Merges two videos by combining their left and right halves.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
video_path1 (str): Path to first video.
|
| 156 |
+
video_path2 (str): Path to second video.
|
| 157 |
+
output_path (str): Path to save the merged video.
|
| 158 |
+
"""
|
| 159 |
clip1 = VideoFileClip(video_path1)
|
| 160 |
clip2 = VideoFileClip(video_path2)
|
| 161 |
|
|
|
|
| 174 |
area_ratio: float,
|
| 175 |
connectivity: int = 8,
|
| 176 |
) -> np.ndarray:
|
| 177 |
+
"""Removes small connected components from a binary mask.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
mask (Union[Image.Image, np.ndarray]): Input mask.
|
| 181 |
+
area_ratio (float): Minimum area ratio for components.
|
| 182 |
+
connectivity (int, optional): Connectivity for labeling.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
np.ndarray: Mask with small components removed.
|
| 186 |
+
"""
|
| 187 |
if isinstance(mask, Image.Image):
|
| 188 |
mask = np.array(mask)
|
| 189 |
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
|
|
|
| 209 |
area_ratio: float = 10,
|
| 210 |
connectivity: int = 8,
|
| 211 |
) -> np.ndarray:
|
| 212 |
+
"""Removes small connected components from the alpha channel of an image.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
image (Union[Image.Image, np.ndarray]): Input image.
|
| 216 |
+
area_ratio (float, optional): Minimum area ratio.
|
| 217 |
+
connectivity (int, optional): Connectivity for labeling.
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
np.ndarray: Image with filtered alpha channel.
|
| 221 |
+
"""
|
| 222 |
if isinstance(image, Image.Image):
|
| 223 |
image = image.convert("RGBA")
|
| 224 |
image = np.array(image)
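Both connected-component filters rely on `cv2.connectedComponentsWithStats`; a compact sketch of the core idea on a synthetic mask (the 1% area threshold below is illustrative, since the exact `area_ratio` semantics are not visible in this diff):

```py
import cv2
import numpy as np

# Synthetic mask: one large 20x20 blob and one tiny 2x2 blob.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[5:25, 5:25] = 255
mask[40:42, 40:42] = 255

num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
min_area = mask.size / 100  # Keep components covering at least 1% of the image.
filtered = np.zeros_like(mask)
for label in range(1, num_labels):  # Label 0 is the background.
    if stats[label, cv2.CC_STAT_AREA] >= min_area:
        filtered[labels == label] = 255

print(int(mask.sum() / 255), int(filtered.sum() / 255))  # 404 -> 400: tiny blob removed.
```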
|
|
|
|
| 236 |
target_wh: tuple[int, int] = (512, 512),
|
| 237 |
image_mode: str = "RGB",
|
| 238 |
) -> list[Image.Image]:
|
| 239 |
+
"""Combines multiple images into a grid.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
images (list[str | Image.Image]): List of image paths or PIL Images.
|
| 243 |
+
cat_row_col (tuple[int, int], optional): Grid rows and columns.
|
| 244 |
+
target_wh (tuple[int, int], optional): Target image size.
|
| 245 |
+
image_mode (str, optional): Image mode.
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
list[Image.Image]: List containing the grid image.
|
| 249 |
+
|
| 250 |
+
Example:
|
| 251 |
+
```py
|
| 252 |
+
from embodied_gen.utils.process_media import combine_images_to_grid
|
| 253 |
+
grid = combine_images_to_grid(["img1.png", "img2.png"])
|
| 254 |
+
grid[0].save("grid.png")
|
| 255 |
+
```
|
| 256 |
+
"""
|
| 257 |
n_images = len(images)
|
| 258 |
if n_images == 1:
|
| 259 |
return images
|
|
|
|
| 281 |
|
| 282 |
|
| 283 |
class SceneTreeVisualizer:
|
| 284 |
+
"""Visualizes a scene tree layout using networkx and matplotlib.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
layout_info (LayoutInfo): Layout information for the scene.
|
| 288 |
+
|
| 289 |
+
Example:
|
| 290 |
+
```py
|
| 291 |
+
from embodied_gen.utils.process_media import SceneTreeVisualizer
|
| 292 |
+
visualizer = SceneTreeVisualizer(layout_info)
|
| 293 |
+
visualizer.render(save_path="tree.png")
|
| 294 |
+
```
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
def __init__(self, layout_info: LayoutInfo) -> None:
|
| 298 |
self.tree = layout_info.tree
|
| 299 |
self.relation = layout_info.relation
|
|
|
|
| 372 |
dpi=300,
|
| 373 |
title: str = "Scene 3D Hierarchy Tree",
|
| 374 |
):
|
| 375 |
+
"""Renders the scene tree and saves to file.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
save_path (str): Path to save the rendered image.
|
| 379 |
+
figsize (tuple, optional): Figure size.
|
| 380 |
+
dpi (int, optional): Image DPI.
|
| 381 |
+
title (str, optional): Plot image title.
|
| 382 |
+
"""
|
| 383 |
node_colors = [
|
| 384 |
self.role_colors[self._get_node_role(n)] for n in self.G.nodes
|
| 385 |
]
|
|
|
|
| 456 |
|
| 457 |
|
| 458 |
def load_scene_dict(file_path: str) -> dict:
|
| 459 |
+
"""Loads a scene description dictionary from a file.
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
file_path (str): Path to the scene description file.
|
| 463 |
+
|
| 464 |
+
Returns:
|
| 465 |
+
dict: Mapping from scene ID to description.
|
| 466 |
+
"""
|
| 467 |
scene_dict = {}
|
| 468 |
with open(file_path, "r", encoding='utf-8') as f:
|
| 469 |
for line in f:
|
|
|
|
| 477 |
|
| 478 |
|
| 479 |
def is_image_file(filename: str) -> bool:
|
| 480 |
+
"""Checks if a filename is an image file.
|
| 481 |
+
|
| 482 |
+
Args:
|
| 483 |
+
filename (str): Filename to check.
|
| 484 |
+
|
| 485 |
+
Returns:
|
| 486 |
+
bool: True if image file, False otherwise.
|
| 487 |
+
"""
|
| 488 |
mime_type, _ = mimetypes.guess_type(filename)
|
| 489 |
|
| 490 |
return mime_type is not None and mime_type.startswith('image')
|
| 491 |
|
| 492 |
|
| 493 |
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
| 494 |
+
"""Parses text prompts from a list or file.
|
| 495 |
+
|
| 496 |
+
Args:
|
| 497 |
+
prompts (list[str]): List of prompts or a file path.
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
list[str]: List of parsed prompts.
|
| 501 |
+
"""
|
| 502 |
if len(prompts) == 1 and prompts[0].endswith(".txt"):
|
| 503 |
with open(prompts[0], "r") as f:
|
| 504 |
prompts = [
|
|
|
|
| 516 |
"""Alpha blends a foreground RGBA image over a background RGBA image.
|
| 517 |
|
| 518 |
Args:
|
| 519 |
+
fg_image: Foreground image (str, PIL Image, or ndarray).
|
| 520 |
+
bg_image: Background image (str, PIL Image, or ndarray).
|
|
|
|
|
|
|
| 521 |
|
| 522 |
Returns:
|
| 523 |
+
Image.Image: Alpha-blended RGBA image.
|
| 524 |
+
|
| 525 |
+
Example:
|
| 526 |
+
```py
|
| 527 |
+
from embodied_gen.utils.process_media import alpha_blend_rgba
|
| 528 |
+
result = alpha_blend_rgba("fg.png", "bg.png")
|
| 529 |
+
result.save("blended.png")
|
| 530 |
+
```
|
| 531 |
"""
|
| 532 |
if isinstance(fg_image, str):
|
| 533 |
fg_image = Image.open(fg_image)
|
|
|
|
| 556 |
"""Checks if a binary object mask is truncated at the image edges.
|
| 557 |
|
| 558 |
Args:
|
| 559 |
+
mask (np.ndarray): 2D binary mask.
|
| 560 |
+
edge_threshold (int, optional): Edge pixel threshold.
|
|
|
|
| 561 |
|
| 562 |
Returns:
|
| 563 |
+
bool: True if object is fully enclosed, False if truncated.
|
|
|
|
| 564 |
"""
|
| 565 |
top = mask[:edge_threshold, :].any()
|
| 566 |
bottom = mask[-edge_threshold:, :].any()
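A tiny self-contained check in the same spirit: the object counts as truncated if any foreground pixel falls inside the border band. Only the top/bottom tests are visible above; the left/right tests below are the natural completion:

```py
import numpy as np

def is_fully_enclosed(mask: np.ndarray, edge_threshold: int = 5) -> bool:
    """True if no foreground pixel touches the image border band."""
    top = mask[:edge_threshold, :].any()
    bottom = mask[-edge_threshold:, :].any()
    left = mask[:, :edge_threshold].any()
    right = mask[:, -edge_threshold:].any()
    return not (top or bottom or left or right)

mask = np.zeros((100, 100), dtype=bool)
mask[30:70, 30:70] = True          # Centered object: enclosed.
print(is_fully_enclosed(mask))     # True
mask[0:10, 40:60] = True           # Extend the object into the top edge band.
print(is_fully_enclosed(mask))     # False
```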
|
|
|
|
| 573 |
def vcat_pil_images(
|
| 574 |
images: list[Image.Image], image_mode: str = "RGB"
|
| 575 |
) -> Image.Image:
|
| 576 |
+
"""Vertically concatenates a list of PIL images.
|
| 577 |
+
|
| 578 |
+
Args:
|
| 579 |
+
images (list[Image.Image]): List of images.
|
| 580 |
+
image_mode (str, optional): Image mode.
|
| 581 |
+
|
| 582 |
+
Returns:
|
| 583 |
+
Image.Image: Vertically concatenated image.
|
| 584 |
+
|
| 585 |
+
Example:
|
| 586 |
+
```py
|
| 587 |
+
from embodied_gen.utils.process_media import vcat_pil_images
|
| 588 |
+
img = vcat_pil_images([Image.open("a.png"), Image.open("b.png")])
|
| 589 |
+
img.save("vcat.png")
|
| 590 |
+
```
|
| 591 |
+
"""
|
| 592 |
widths, heights = zip(*(img.size for img in images))
|
| 593 |
total_height = sum(heights)
|
| 594 |
max_width = max(widths)
|
embodied_gen/utils/simulation.py
CHANGED
|
@@ -69,6 +69,21 @@ def load_actor_from_urdf(
|
|
| 69 |
update_mass: bool = False,
|
| 70 |
scale: float | np.ndarray = 1.0,
|
| 71 |
) -> sapien.pysapien.Entity:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
|
| 73 |
local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
|
| 74 |
if origin_tag is not None:
|
|
@@ -154,14 +169,17 @@ def load_assets_from_layout_file(
|
|
| 154 |
init_quat: list[float] = [0, 0, 0, 1],
|
| 155 |
env_idx: int = None,
|
| 156 |
) -> dict[str, sapien.pysapien.Entity]:
|
| 157 |
-
"""Load assets from
|
| 158 |
|
| 159 |
Args:
|
| 160 |
-
scene (sapien.Scene
|
| 161 |
-
layout (str):
|
| 162 |
-
z_offset (float):
|
| 163 |
-
init_quat (
|
| 164 |
-
env_idx (int): Environment index
|
|
|
|
|
|
|
|
|
|
| 165 |
"""
|
| 166 |
asset_root = os.path.dirname(layout)
|
| 167 |
layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
|
|
@@ -206,6 +224,19 @@ def load_mani_skill_robot(
|
|
| 206 |
control_mode: str = "pd_joint_pos",
|
| 207 |
backend_str: tuple[str, str] = ("cpu", "gpu"),
|
| 208 |
) -> BaseAgent:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
from mani_skill.agents import REGISTERED_AGENTS
|
| 210 |
from mani_skill.envs.scene import ManiSkillScene
|
| 211 |
from mani_skill.envs.utils.system.backend import (
|
|
@@ -278,14 +309,14 @@ def render_images(
|
|
| 278 |
]
|
| 279 |
] = None,
|
| 280 |
) -> dict[str, Image.Image]:
|
| 281 |
-
"""Render images from a given
|
| 282 |
|
| 283 |
Args:
|
| 284 |
-
camera (sapien.render.RenderCameraComponent):
|
| 285 |
-
render_keys (
|
| 286 |
|
| 287 |
Returns:
|
| 288 |
-
|
| 289 |
"""
|
| 290 |
if render_keys is None:
|
| 291 |
render_keys = [
|
|
@@ -341,11 +372,33 @@ def render_images(
|
|
| 341 |
|
| 342 |
|
| 343 |
class SapienSceneManager:
|
| 344 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
def __init__(
|
| 347 |
self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
|
| 348 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
self.sim_freq = sim_freq
|
| 350 |
self.ray_tracing = ray_tracing
|
| 351 |
self.device = device
|
|
@@ -355,7 +408,11 @@ class SapienSceneManager:
|
|
| 355 |
self.actors: dict[str, sapien.pysapien.Entity] = {}
|
| 356 |
|
| 357 |
def _setup_scene(self) -> sapien.Scene:
|
| 358 |
-
"""Set up the SAPIEN scene with lighting and ground.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
# Ray tracing settings
|
| 360 |
if self.ray_tracing:
|
| 361 |
sapien.render.set_camera_shader_dir("rt")
|
|
@@ -397,6 +454,18 @@ class SapienSceneManager:
|
|
| 397 |
render_keys: list[str],
|
| 398 |
sim_steps_per_control: int = 1,
|
| 399 |
) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
agent.set_action(action)
|
| 401 |
frames = defaultdict(list)
|
| 402 |
for _ in range(sim_steps_per_control):
|
|
@@ -417,13 +486,13 @@ class SapienSceneManager:
|
|
| 417 |
image_hw: tuple[int, int],
|
| 418 |
fovy_deg: float,
|
| 419 |
) -> sapien.render.RenderCameraComponent:
|
| 420 |
-
"""Create a
|
| 421 |
|
| 422 |
Args:
|
| 423 |
-
cam_name (str):
|
| 424 |
-
pose (sapien.Pose): Camera pose
|
| 425 |
-
image_hw (
|
| 426 |
-
fovy_deg (float): Field of view in degrees
|
| 427 |
|
| 428 |
Returns:
|
| 429 |
sapien.render.RenderCameraComponent: The created camera.
|
|
@@ -456,15 +525,15 @@ class SapienSceneManager:
|
|
| 456 |
"""Initialize multiple cameras arranged in a circle.
|
| 457 |
|
| 458 |
Args:
|
| 459 |
-
num_cameras (int): Number of cameras
|
| 460 |
-
radius (float):
|
| 461 |
-
height (float):
|
| 462 |
-
target_pt (list[float]):
|
| 463 |
-
image_hw (
|
| 464 |
-
fovy_deg (float): Field of view in degrees
|
| 465 |
|
| 466 |
Returns:
|
| 467 |
-
|
| 468 |
"""
|
| 469 |
angle_step = 2 * np.pi / num_cameras
|
| 470 |
world_up_vec = np.array([0.0, 0.0, 1.0])
|
|
@@ -510,6 +579,19 @@ class SapienSceneManager:
|
|
| 510 |
|
| 511 |
|
| 512 |
class FrankaPandaGrasper(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
def __init__(
|
| 514 |
self,
|
| 515 |
agent: BaseAgent,
|
|
@@ -518,6 +600,7 @@ class FrankaPandaGrasper(object):
|
|
| 518 |
joint_acc_limits: float = 1.0,
|
| 519 |
finger_length: float = 0.025,
|
| 520 |
) -> None:
|
|
|
|
| 521 |
self.agent = agent
|
| 522 |
self.robot = agent.robot
|
| 523 |
self.control_freq = control_freq
|
|
@@ -553,6 +636,15 @@ class FrankaPandaGrasper(object):
|
|
| 553 |
gripper_state: Literal[-1, 1],
|
| 554 |
n_step: int = 10,
|
| 555 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
|
| 557 |
actions = []
|
| 558 |
for _ in range(n_step):
|
|
@@ -571,6 +663,20 @@ class FrankaPandaGrasper(object):
|
|
| 571 |
action_key: str = "position",
|
| 572 |
env_idx: int = 0,
|
| 573 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
result = self.planners[env_idx].plan_qpos_to_pose(
|
| 575 |
np.concatenate([pose.p, pose.q]),
|
| 576 |
self.robot.get_qpos().cpu().numpy()[0],
|
|
@@ -608,6 +714,17 @@ class FrankaPandaGrasper(object):
|
|
| 608 |
offset: tuple[float, float, float] = [0, 0, -0.05],
|
| 609 |
env_idx: int = 0,
|
| 610 |
) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
physx_rigid = actor.components[1]
|
| 612 |
mesh = get_component_mesh(physx_rigid, to_world_frame=True)
|
| 613 |
obb = mesh.bounding_box_oriented
|
|
|
|
| 69 |
update_mass: bool = False,
|
| 70 |
scale: float | np.ndarray = 1.0,
|
| 71 |
) -> sapien.pysapien.Entity:
|
| 72 |
+
"""Load an sapien actor from a URDF file and add it to the scene.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
scene (sapien.Scene | ManiSkillScene): The simulation scene.
|
| 76 |
+
file_path (str): Path to the URDF file.
|
| 77 |
+
pose (sapien.Pose | None): Initial pose of the actor.
|
| 78 |
+
env_idx (int): Environment index for multi-env setup.
|
| 79 |
+
use_static (bool): Whether the actor is static.
|
| 80 |
+
update_mass (bool): Whether to update the actor's mass from URDF.
|
| 81 |
+
scale (float | np.ndarray): Scale factor for the actor.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
sapien.pysapien.Entity: The created actor entity.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
|
| 88 |
local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
|
| 89 |
if origin_tag is not None:
|
|
|
|
| 169 |
init_quat: list[float] = [0, 0, 0, 1],
|
| 170 |
env_idx: int = None,
|
| 171 |
) -> dict[str, sapien.pysapien.Entity]:
|
| 172 |
+
"""Load assets from an EmbodiedGen layout file and create sapien actors in the scene.
|
| 173 |
|
| 174 |
Args:
|
| 175 |
+
scene (ManiSkillScene | sapien.Scene): The SAPIEN simulation scene.
|
| 176 |
+
layout (str): Path to the EmbodiedGen layout file.
|
| 177 |
+
z_offset (float): Z offset for non-context objects.
|
| 178 |
+
init_quat (list[float]): Initial quaternion for orientation.
|
| 179 |
+
env_idx (int): Environment index.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
dict[str, sapien.pysapien.Entity]: Mapping from object names to actor entities.
|
| 183 |
"""
|
| 184 |
asset_root = os.path.dirname(layout)
|
| 185 |
layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
|
|
|
|
| 224 |
control_mode: str = "pd_joint_pos",
|
| 225 |
backend_str: tuple[str, str] = ("cpu", "gpu"),
|
| 226 |
) -> BaseAgent:
|
| 227 |
+
"""Load a ManiSkill robot agent into the scene.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
scene (sapien.Scene | ManiSkillScene): The simulation scene.
|
| 231 |
+
layout (LayoutInfo | str): Layout info or path to layout file.
|
| 232 |
+
control_freq (int): Control frequency.
|
| 233 |
+
robot_init_qpos_noise (float): Noise for initial joint positions.
|
| 234 |
+
control_mode (str): Robot control mode.
|
| 235 |
+
backend_str (tuple[str, str]): Simulation/render backend.
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
BaseAgent: The loaded robot agent.
|
| 239 |
+
"""
|
| 240 |
from mani_skill.agents import REGISTERED_AGENTS
|
| 241 |
from mani_skill.envs.scene import ManiSkillScene
|
| 242 |
from mani_skill.envs.utils.system.backend import (
|
|
|
|
| 309 |
]
|
| 310 |
] = None,
|
| 311 |
) -> dict[str, Image.Image]:
|
| 312 |
+
"""Render images from a given SAPIEN camera.
|
| 313 |
|
| 314 |
Args:
|
| 315 |
+
camera (sapien.render.RenderCameraComponent): Camera to render from.
|
| 316 |
+
render_keys (list[str], optional): Types of images to render.
|
| 317 |
|
| 318 |
Returns:
|
| 319 |
+
dict[str, Image.Image]: Dictionary of rendered images.
|
| 320 |
"""
|
| 321 |
if render_keys is None:
|
| 322 |
render_keys = [
|
|
|
|
| 372 |
|
| 373 |
|
| 374 |
class SapienSceneManager:
|
| 375 |
+
"""Manages SAPIEN simulation scenes, cameras, and rendering.
|
| 376 |
+
|
| 377 |
+
This class provides utilities for setting up scenes, adding cameras,
|
| 378 |
+
stepping simulation, and rendering images.
|
| 379 |
+
|
| 380 |
+
Attributes:
|
| 381 |
+
sim_freq (int): Simulation frequency.
|
| 382 |
+
ray_tracing (bool): Whether to use ray tracing.
|
| 383 |
+
device (str): Device for simulation.
|
| 384 |
+
renderer (sapien.SapienRenderer): SAPIEN renderer.
|
| 385 |
+
scene (sapien.Scene): Simulation scene.
|
| 386 |
+
cameras (list): List of camera components.
|
| 387 |
+
actors (dict): Mapping of actor names to entities.
|
| 388 |
+
|
| 389 |
+
For an example, see `embodied_gen/scripts/simulate_sapien.py`.
|
| 390 |
+
"""
|
| 391 |
|
| 392 |
def __init__(
|
| 393 |
self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
|
| 394 |
) -> None:
|
| 395 |
+
"""Initialize the scene manager.
|
| 396 |
+
|
| 397 |
+
Args:
|
| 398 |
+
sim_freq (int): Simulation frequency.
|
| 399 |
+
ray_tracing (bool): Enable ray tracing.
|
| 400 |
+
device (str): Device for simulation.
|
| 401 |
+
"""
|
| 402 |
self.sim_freq = sim_freq
|
| 403 |
self.ray_tracing = ray_tracing
|
| 404 |
self.device = device
|
|
|
|
| 408 |
self.actors: dict[str, sapien.pysapien.Entity] = {}
|
| 409 |
|
| 410 |
def _setup_scene(self) -> sapien.Scene:
|
| 411 |
+
"""Set up the SAPIEN scene with lighting and ground.
|
| 412 |
+
|
| 413 |
+
Returns:
|
| 414 |
+
sapien.Scene: The initialized scene.
|
| 415 |
+
"""
|
| 416 |
# Ray tracing settings
|
| 417 |
if self.ray_tracing:
|
| 418 |
sapien.render.set_camera_shader_dir("rt")
|
|
|
|
| 454 |
render_keys: list[str],
|
| 455 |
sim_steps_per_control: int = 1,
|
| 456 |
) -> dict:
|
| 457 |
+
"""Step the simulation and render images from cameras.
|
| 458 |
+
|
| 459 |
+
Args:
|
| 460 |
+
agent (BaseAgent): The robot agent.
|
| 461 |
+
action (torch.Tensor): Action to apply.
|
| 462 |
+
cameras (list): List of camera components.
|
| 463 |
+
render_keys (list[str]): Types of images to render.
|
| 464 |
+
sim_steps_per_control (int): Simulation steps per control.
|
| 465 |
+
|
| 466 |
+
Returns:
|
| 467 |
+
dict: Dictionary of rendered frames per camera.
|
| 468 |
+
"""
|
| 469 |
agent.set_action(action)
|
| 470 |
frames = defaultdict(list)
|
| 471 |
for _ in range(sim_steps_per_control):
|
|
|
|
| 486 |
image_hw: tuple[int, int],
|
| 487 |
fovy_deg: float,
|
| 488 |
) -> sapien.render.RenderCameraComponent:
|
| 489 |
+
"""Create a camera in the scene.
|
| 490 |
|
| 491 |
Args:
|
| 492 |
+
cam_name (str): Camera name.
|
| 493 |
+
pose (sapien.Pose): Camera pose.
|
| 494 |
+
image_hw (tuple[int, int]): Image resolution (height, width).
|
| 495 |
+
fovy_deg (float): Field of view in degrees.
|
| 496 |
|
| 497 |
Returns:
|
| 498 |
sapien.render.RenderCameraComponent: The created camera.
|
|
|
|
| 525 |
"""Initialize multiple cameras arranged in a circle.
|
| 526 |
|
| 527 |
Args:
|
| 528 |
+
num_cameras (int): Number of cameras.
|
| 529 |
+
radius (float): Circle radius.
|
| 530 |
+
height (float): Camera height.
|
| 531 |
+
target_pt (list[float]): Target point to look at.
|
| 532 |
+
image_hw (tuple[int, int]): Image resolution.
|
| 533 |
+
fovy_deg (float): Field of view in degrees.
|
| 534 |
|
| 535 |
Returns:
|
| 536 |
+
list[sapien.render.RenderCameraComponent]: List of cameras.
|
| 537 |
"""
|
| 538 |
angle_step = 2 * np.pi / num_cameras
|
| 539 |
world_up_vec = np.array([0.0, 0.0, 1.0])
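Placing the circular camera rig reduces to sampling positions on a circle and building a look-at rotation toward the target point; a numpy-only sketch of that geometry (the forward/left/up column convention below is an assumption for illustration, not necessarily SAPIEN's exact pose convention):

```py
import numpy as np

def circle_camera_positions(num_cameras: int, radius: float, height: float) -> np.ndarray:
    """Camera centers evenly spaced on a circle at a fixed height."""
    angles = np.arange(num_cameras) * (2 * np.pi / num_cameras)
    return np.stack(
        [radius * np.cos(angles), radius * np.sin(angles), np.full_like(angles, height)],
        axis=1,
    )

def look_at_rotation(cam_pos: np.ndarray, target: np.ndarray) -> np.ndarray:
    """3x3 rotation whose columns are (forward, left, up), looking at the target."""
    world_up = np.array([0.0, 0.0, 1.0])
    forward = target - cam_pos
    forward = forward / np.linalg.norm(forward)
    left = np.cross(world_up, forward)
    left = left / np.linalg.norm(left)
    up = np.cross(forward, left)
    return np.stack([forward, left, up], axis=1)

centers = circle_camera_positions(num_cameras=4, radius=1.5, height=1.0)
R = look_at_rotation(centers[0], target=np.array([0.0, 0.0, 0.5]))
print(centers.round(2), R.round(2), sep="\n")
```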
|
|
|
|
| 579 |
|
| 580 |
|
| 581 |
class FrankaPandaGrasper(object):
|
| 582 |
+
"""Provides grasp planning and control for Franka Panda robot.
|
| 583 |
+
|
| 584 |
+
Attributes:
|
| 585 |
+
agent (BaseAgent): The robot agent.
|
| 586 |
+
robot: The robot instance.
|
| 587 |
+
control_freq (float): Control frequency.
|
| 588 |
+
control_timestep (float): Control timestep.
|
| 589 |
+
joint_vel_limits (float): Joint velocity limits.
|
| 590 |
+
joint_acc_limits (float): Joint acceleration limits.
|
| 591 |
+
finger_length (float): Length of gripper fingers.
|
| 592 |
+
planners: Motion planners for each environment.
|
| 593 |
+
"""
|
| 594 |
+
|
| 595 |
def __init__(
|
| 596 |
self,
|
| 597 |
agent: BaseAgent,
|
|
|
|
| 600 |
joint_acc_limits: float = 1.0,
|
| 601 |
finger_length: float = 0.025,
|
| 602 |
) -> None:
|
| 603 |
+
"""Initialize the grasper."""
|
| 604 |
self.agent = agent
|
| 605 |
self.robot = agent.robot
|
| 606 |
self.control_freq = control_freq
|
|
|
|
| 636 |
gripper_state: Literal[-1, 1],
|
| 637 |
n_step: int = 10,
|
| 638 |
) -> np.ndarray:
|
| 639 |
+
"""Generate gripper control actions.
|
| 640 |
+
|
| 641 |
+
Args:
|
| 642 |
+
gripper_state (Literal[-1, 1]): Desired gripper state.
|
| 643 |
+
n_step (int): Number of steps.
|
| 644 |
+
|
| 645 |
+
Returns:
|
| 646 |
+
np.ndarray: Array of gripper actions.
|
| 647 |
+
"""
|
| 648 |
qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
|
| 649 |
actions = []
|
| 650 |
for _ in range(n_step):
|
|
|
|
| 663 |
action_key: str = "position",
|
| 664 |
env_idx: int = 0,
|
| 665 |
) -> np.ndarray:
|
| 666 |
+
"""Plan and execute motion to a target pose.
|
| 667 |
+
|
| 668 |
+
Args:
|
| 669 |
+
pose (sapien.Pose): Target pose.
|
| 670 |
+
control_timestep (float): Control timestep.
|
| 671 |
+
gripper_state (Literal[-1, 1]): Desired gripper state.
|
| 672 |
+
use_point_cloud (bool): Use point cloud for planning.
|
| 673 |
+
n_max_step (int): Max number of steps.
|
| 674 |
+
action_key (str): Key for action in result.
|
| 675 |
+
env_idx (int): Environment index.
|
| 676 |
+
|
| 677 |
+
Returns:
|
| 678 |
+
np.ndarray: Array of actions to reach the pose.
|
| 679 |
+
"""
|
| 680 |
result = self.planners[env_idx].plan_qpos_to_pose(
|
| 681 |
np.concatenate([pose.p, pose.q]),
|
| 682 |
self.robot.get_qpos().cpu().numpy()[0],
|
|
|
|
| 714 |
offset: tuple[float, float, float] = [0, 0, -0.05],
|
| 715 |
env_idx: int = 0,
|
| 716 |
) -> np.ndarray:
|
| 717 |
+
"""Compute grasp actions for a target actor.
|
| 718 |
+
|
| 719 |
+
Args:
|
| 720 |
+
actor (sapien.pysapien.Entity): Target actor to grasp.
|
| 721 |
+
reach_target_only (bool): Only reach the target pose if True.
|
| 722 |
+
offset (tuple[float, float, float]): Offset for reach pose.
|
| 723 |
+
env_idx (int): Environment index.
|
| 724 |
+
|
| 725 |
+
Returns:
|
| 726 |
+
np.ndarray: Array of grasp actions.
|
| 727 |
+
"""
|
| 728 |
physx_rigid = actor.components[1]
|
| 729 |
mesh = get_component_mesh(physx_rigid, to_world_frame=True)
|
| 730 |
obb = mesh.bounding_box_oriented
|
embodied_gen/utils/tags.py
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
VERSION = "v0.1.
|
|
|
|
| 1 |
+
VERSION = "v0.1.6"
|
embodied_gen/validators/aesthetic_predictor.py
CHANGED
|
@@ -27,14 +27,22 @@ from PIL import Image
|
|
| 27 |
|
| 28 |
|
| 29 |
class AestheticPredictor:
|
| 30 |
-
"""Aesthetic Score Predictor.
|
| 31 |
|
| 32 |
-
Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main
|
| 33 |
|
| 34 |
Args:
|
| 35 |
-
clip_model_dir (str): Path to
|
| 36 |
-
sac_model_path (str): Path to
|
| 37 |
-
device (str): Device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"""
|
| 39 |
|
| 40 |
def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
|
|
@@ -109,7 +117,7 @@ class AestheticPredictor:
|
|
| 109 |
return model
|
| 110 |
|
| 111 |
def predict(self, image_path):
|
| 112 |
-
"""
|
| 113 |
|
| 114 |
Args:
|
| 115 |
image_path (str): Path to the image file.
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class AestheticPredictor:
|
| 30 |
+
"""Aesthetic Score Predictor using CLIP and a pre-trained MLP.
|
| 31 |
|
| 32 |
+
Checkpoints from `https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main`.
|
| 33 |
|
| 34 |
Args:
|
| 35 |
+
clip_model_dir (str, optional): Path to CLIP model directory.
|
| 36 |
+
sac_model_path (str, optional): Path to SAC model weights.
|
| 37 |
+
device (str, optional): Device for computation ("cuda" or "cpu").
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
```py
|
| 41 |
+
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
|
| 42 |
+
predictor = AestheticPredictor(device="cuda")
|
| 43 |
+
score = predictor.predict("image.png")
|
| 44 |
+
print("Aesthetic score:", score)
|
| 45 |
+
```
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
|
|
|
|
| 117 |
return model
|
| 118 |
|
| 119 |
def predict(self, image_path):
|
| 120 |
+
"""Predicts the aesthetic score for a given image.
|
| 121 |
|
| 122 |
Args:
|
| 123 |
image_path (str): Path to the image file.
|
embodied_gen/validators/quality_checkers.py
CHANGED
|
@@ -40,6 +40,16 @@ __all__ = [
|
|
| 40 |
|
| 41 |
|
| 42 |
class BaseChecker:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
| 44 |
self.prompt = prompt
|
| 45 |
self.verbose = verbose
|
|
@@ -70,6 +80,15 @@ class BaseChecker:
|
|
| 70 |
def validate(
|
| 71 |
checkers: list["BaseChecker"], images_list: list[list[str]]
|
| 72 |
) -> list:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
assert len(checkers) == len(images_list)
|
| 74 |
results = []
|
| 75 |
overall_result = True
|
|
@@ -192,7 +211,7 @@ class ImageSegChecker(BaseChecker):
|
|
| 192 |
|
| 193 |
|
| 194 |
class ImageAestheticChecker(BaseChecker):
|
| 195 |
-
"""
|
| 196 |
|
| 197 |
Attributes:
|
| 198 |
clip_model_dir (str): Path to the CLIP model directory.
|
|
@@ -200,6 +219,14 @@ class ImageAestheticChecker(BaseChecker):
|
|
| 200 |
thresh (float): Threshold above which images are considered aesthetically acceptable.
|
| 201 |
verbose (bool): Whether to print detailed log messages.
|
| 202 |
predictor (AestheticPredictor): The model used to predict aesthetic scores.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
"""
|
| 204 |
|
| 205 |
def __init__(
|
|
@@ -227,6 +254,16 @@ class ImageAestheticChecker(BaseChecker):
|
|
| 227 |
|
| 228 |
|
| 229 |
class SemanticConsistChecker(BaseChecker):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
def __init__(
|
| 231 |
self,
|
| 232 |
gpt_client: GPTclient,
|
|
@@ -276,6 +313,16 @@ class SemanticConsistChecker(BaseChecker):
|
|
| 276 |
|
| 277 |
|
| 278 |
class TextGenAlignChecker(BaseChecker):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
def __init__(
|
| 280 |
self,
|
| 281 |
gpt_client: GPTclient,
|
|
@@ -489,6 +536,17 @@ class PanoHeightEstimator(object):
|
|
| 489 |
|
| 490 |
|
| 491 |
class SemanticMatcher(BaseChecker):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
def __init__(
|
| 493 |
self,
|
| 494 |
gpt_client: GPTclient,
|
|
@@ -543,6 +601,17 @@ class SemanticMatcher(BaseChecker):
|
|
| 543 |
def query(
|
| 544 |
self, text: str, context: dict, rand: bool = True, params: dict = None
|
| 545 |
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
match_list = self.gpt_client.query(
|
| 547 |
self.prompt.format(context=context, text=text),
|
| 548 |
params=params,
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
class BaseChecker:
|
| 43 |
+
"""Base class for quality checkers using GPT clients.
|
| 44 |
+
|
| 45 |
+
Provides a common interface for querying and validating responses.
|
| 46 |
+
Subclasses must implement the `query` method.
|
| 47 |
+
|
| 48 |
+
Attributes:
|
| 49 |
+
prompt (str): The prompt used for queries.
|
| 50 |
+
verbose (bool): Whether to enable verbose logging.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
| 54 |
self.prompt = prompt
|
| 55 |
self.verbose = verbose
|
|
|
|
| 80 |
def validate(
|
| 81 |
checkers: list["BaseChecker"], images_list: list[list[str]]
|
| 82 |
) -> list:
|
| 83 |
+
"""Validates a list of checkers against corresponding image lists.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
checkers (list[BaseChecker]): List of checker instances.
|
| 87 |
+
images_list (list[list[str]]): List of image path lists.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
list: Validation results with overall outcome.
|
| 91 |
+
"""
|
| 92 |
assert len(checkers) == len(images_list)
|
| 93 |
results = []
|
| 94 |
overall_result = True
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
class ImageAestheticChecker(BaseChecker):
|
| 214 |
+
"""Evaluates the aesthetic quality of images using a CLIP-based predictor.
|
| 215 |
|
| 216 |
Attributes:
|
| 217 |
clip_model_dir (str): Path to the CLIP model directory.
|
|
|
|
| 219 |
thresh (float): Threshold above which images are considered aesthetically acceptable.
|
| 220 |
verbose (bool): Whether to print detailed log messages.
|
| 221 |
predictor (AestheticPredictor): The model used to predict aesthetic scores.
|
| 222 |
+
|
| 223 |
+
Example:
|
| 224 |
+
```py
|
| 225 |
+
from embodied_gen.validators.quality_checkers import ImageAestheticChecker
|
| 226 |
+
checker = ImageAestheticChecker(thresh=4.5)
|
| 227 |
+
flag, score = checker(["image1.png", "image2.png"])
|
| 228 |
+
print("Aesthetic OK:", flag, "Score:", score)
|
| 229 |
+
```
|
| 230 |
"""
|
| 231 |
|
| 232 |
def __init__(
|
|
|
|
| 254 |
|
| 255 |
|
| 256 |
class SemanticConsistChecker(BaseChecker):
|
| 257 |
+
"""Checks semantic consistency between text descriptions and segmented images.
|
| 258 |
+
|
| 259 |
+
Uses GPT to evaluate if the image matches the text in object type, geometry, and color.
|
| 260 |
+
|
| 261 |
+
Attributes:
|
| 262 |
+
gpt_client (GPTclient): GPT client for queries.
|
| 263 |
+
prompt (str): Prompt for consistency evaluation.
|
| 264 |
+
verbose (bool): Whether to enable verbose logging.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
def __init__(
|
| 268 |
self,
|
| 269 |
gpt_client: GPTclient,
|
|
|
|
| 313 |
|
| 314 |
|
| 315 |
class TextGenAlignChecker(BaseChecker):
|
| 316 |
+
"""Evaluates alignment between text prompts and generated 3D asset images.
|
| 317 |
+
|
| 318 |
+
Assesses if the rendered images match the text description in category and geometry.
|
| 319 |
+
|
| 320 |
+
Attributes:
|
| 321 |
+
gpt_client (GPTclient): GPT client for queries.
|
| 322 |
+
prompt (str): Prompt for alignment evaluation.
|
| 323 |
+
verbose (bool): Whether to enable verbose logging.
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
def __init__(
|
| 327 |
self,
|
| 328 |
gpt_client: GPTclient,
|
|
|
|
| 536 |
|
| 537 |
|
| 538 |
class SemanticMatcher(BaseChecker):
|
| 539 |
+
"""Matches query text to semantically similar scene descriptions.
|
| 540 |
+
|
| 541 |
+
Uses GPT to find the most similar scene IDs from a dictionary.
|
| 542 |
+
|
| 543 |
+
Attributes:
|
| 544 |
+
gpt_client (GPTclient): GPT client for queries.
|
| 545 |
+
prompt (str): Prompt for semantic matching.
|
| 546 |
+
verbose (bool): Whether to enable verbose logging.
|
| 547 |
+
seed (int): Random seed for selection.
|
| 548 |
+
"""
|
| 549 |
+
|
| 550 |
def __init__(
|
| 551 |
self,
|
| 552 |
gpt_client: GPTclient,
|
|
|
|
| 601 |
def query(
|
| 602 |
self, text: str, context: dict, rand: bool = True, params: dict = None
|
| 603 |
) -> str:
|
| 604 |
+
"""Queries for semantically similar scene IDs.
|
| 605 |
+
|
| 606 |
+
Args:
|
| 607 |
+
text (str): Query text.
|
| 608 |
+
context (dict): Dictionary of scene descriptions.
|
| 609 |
+
rand (bool, optional): Whether to randomly select from top matches.
|
| 610 |
+
params (dict, optional): Additional GPT parameters.
|
| 611 |
+
|
| 612 |
+
Returns:
|
| 613 |
+
str: Matched scene ID.
|
| 614 |
+
"""
|
| 615 |
match_list = self.gpt_client.query(
|
| 616 |
self.prompt.format(context=context, text=text),
|
| 617 |
params=params,
|
embodied_gen/validators/urdf_convertor.py
CHANGED
|
@@ -80,6 +80,31 @@ URDF_TEMPLATE = """
|
|
| 80 |
|
| 81 |
|
| 82 |
class URDFGenerator(object):
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
def __init__(
|
| 84 |
self,
|
| 85 |
gpt_client: GPTclient,
|
|
@@ -168,6 +193,14 @@ class URDFGenerator(object):
|
|
| 168 |
self.rotate_xyzw = rotate_xyzw
|
| 169 |
|
| 170 |
def parse_response(self, response: str) -> dict[str, any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
lines = response.split("\n")
|
| 172 |
lines = [line.strip() for line in lines if line]
|
| 173 |
category = lines[0].split(": ")[1]
|
|
@@ -207,11 +240,9 @@ class URDFGenerator(object):
|
|
| 207 |
|
| 208 |
Args:
|
| 209 |
input_mesh (str): Path to the input mesh file.
|
| 210 |
-
output_dir (str): Directory to store the generated URDF
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
mass, and friction coefficients.
|
| 214 |
-
output_name (str, optional): Name for the generated URDF and robot.
|
| 215 |
|
| 216 |
Returns:
|
| 217 |
str: Path to the generated URDF file.
|
|
@@ -336,6 +367,16 @@ class URDFGenerator(object):
|
|
| 336 |
attr_root: str = ".//link/extra_info",
|
| 337 |
attr_name: str = "scale",
|
| 338 |
) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
if not os.path.exists(urdf_path):
|
| 340 |
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
|
| 341 |
|
|
@@ -358,6 +399,13 @@ class URDFGenerator(object):
|
|
| 358 |
def add_quality_tag(
|
| 359 |
urdf_path: str, results: list, output_path: str = None
|
| 360 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
if output_path is None:
|
| 362 |
output_path = urdf_path
|
| 363 |
|
|
@@ -382,6 +430,14 @@ class URDFGenerator(object):
|
|
| 382 |
logger.info(f"URDF files saved to {output_path}")
|
| 383 |
|
| 384 |
def get_estimated_attributes(self, asset_attrs: dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
estimated_attrs = {
|
| 386 |
"height": round(
|
| 387 |
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
|
|
@@ -403,6 +459,18 @@ class URDFGenerator(object):
|
|
| 403 |
category: str = "unknown",
|
| 404 |
**kwargs,
|
| 405 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
if text_prompt is None or len(text_prompt) == 0:
|
| 407 |
text_prompt = self.prompt_template
|
| 408 |
text_prompt = text_prompt.format(category=category.lower())
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
class URDFGenerator(object):
|
| 83 |
+
"""Generates URDF files for 3D assets with physical and semantic attributes.
|
| 84 |
+
|
| 85 |
+
Uses GPT to estimate object properties and generates a URDF file with mesh, friction, mass, and metadata.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
gpt_client (GPTclient): GPT client for attribute estimation.
|
| 89 |
+
mesh_file_list (list[str], optional): Additional mesh files to copy.
|
| 90 |
+
prompt_template (str, optional): Prompt template for GPT queries.
|
| 91 |
+
attrs_name (list[str], optional): List of attribute names to include.
|
| 92 |
+
render_dir (str, optional): Directory for rendered images.
|
| 93 |
+
render_view_num (int, optional): Number of views to render.
|
| 94 |
+
decompose_convex (bool, optional): Whether to decompose mesh for collision.
|
| 95 |
+
rotate_xyzw (list[float], optional): Quaternion for mesh rotation.
|
| 96 |
+
|
| 97 |
+
Example:
|
| 98 |
+
```py
|
| 99 |
+
from embodied_gen.validators.urdf_convertor import URDFGenerator
|
| 100 |
+
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 101 |
+
|
| 102 |
+
urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
|
| 103 |
+
urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir")
|
| 104 |
+
print("Generated URDF:", urdf_path)
|
| 105 |
+
```
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
def __init__(
|
| 109 |
self,
|
| 110 |
gpt_client: GPTclient,
|
|
|
|
| 193 |
self.rotate_xyzw = rotate_xyzw
|
| 194 |
|
| 195 |
def parse_response(self, response: str) -> dict[str, any]:
|
| 196 |
+
"""Parses GPT response to extract asset attributes.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
response (str): GPT response string.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
dict[str, any]: Parsed attributes.
|
| 203 |
+
"""
|
| 204 |
lines = response.split("\n")
|
| 205 |
lines = [line.strip() for line in lines if line]
|
| 206 |
category = lines[0].split(": ")[1]
|
|
|
|
| 240 |
|
| 241 |
Args:
|
| 242 |
input_mesh (str): Path to the input mesh file.
|
| 243 |
+
output_dir (str): Directory to store the generated URDF and mesh.
|
| 244 |
+
attr_dict (dict): Dictionary of asset attributes.
|
| 245 |
+
output_name (str, optional): Name for the URDF and robot.
|
|
|
|
|
|
|
| 246 |
|
| 247 |
Returns:
|
| 248 |
str: Path to the generated URDF file.
|
|
|
|
| 367 |
attr_root: str = ".//link/extra_info",
|
| 368 |
attr_name: str = "scale",
|
| 369 |
) -> float:
|
| 370 |
+
"""Extracts an attribute value from a URDF file.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
urdf_path (str): Path to the URDF file.
|
| 374 |
+
attr_root (str, optional): XML path to attribute root.
|
| 375 |
+
attr_name (str, optional): Attribute name.
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
float: Attribute value, or None if not found.
|
| 379 |
+
"""
|
| 380 |
if not os.path.exists(urdf_path):
|
| 381 |
raise FileNotFoundError(f"URDF file not found: {urdf_path}")
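Reading such an attribute is plain ElementTree work; a hedged sketch against a made-up URDF snippet (the `.//link/extra_info` path and `scale` name come from the signature above, everything else below is hypothetical):

```py
import xml.etree.ElementTree as ET

URDF_SNIPPET = """<robot name="demo">
  <link name="base">
    <extra_info scale="0.75"/>
  </link>
</robot>"""

def read_attr(urdf_xml: str, attr_root: str = ".//link/extra_info", attr_name: str = "scale"):
    """Return the named attribute as a float, or None if the node/attribute is missing."""
    root = ET.fromstring(urdf_xml)
    node = root.find(attr_root)
    if node is None or attr_name not in node.attrib:
        return None
    return float(node.attrib[attr_name])

print(read_attr(URDF_SNIPPET))  # 0.75
```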
|
| 382 |
|
|
|
|
| 399 |
def add_quality_tag(
|
| 400 |
urdf_path: str, results: list, output_path: str = None
|
| 401 |
) -> None:
|
| 402 |
+
"""Adds a quality tag to a URDF file.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
urdf_path (str): Path to the URDF file.
|
| 406 |
+
results (list): List of [checker_name, result] pairs.
|
| 407 |
+
output_path (str, optional): Output file path.
|
| 408 |
+
"""
|
| 409 |
if output_path is None:
|
| 410 |
output_path = urdf_path
|
| 411 |
|
|
|
|
| 430 |
logger.info(f"URDF files saved to {output_path}")
|
| 431 |
|
| 432 |
def get_estimated_attributes(self, asset_attrs: dict):
|
| 433 |
+
"""Calculates estimated attributes from asset properties.
|
| 434 |
+
|
| 435 |
+
Args:
|
| 436 |
+
asset_attrs (dict): Asset attributes.
|
| 437 |
+
|
| 438 |
+
Returns:
|
| 439 |
+
dict: Estimated attributes (height, mass, mu, category).
|
| 440 |
+
"""
|
| 441 |
estimated_attrs = {
|
| 442 |
"height": round(
|
| 443 |
(asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
|
|
|
|
| 459 |
category: str = "unknown",
|
| 460 |
**kwargs,
|
| 461 |
):
|
| 462 |
+
"""Generates a URDF file for a mesh asset.
|
| 463 |
+
|
| 464 |
+
Args:
|
| 465 |
+
mesh_path (str): Path to mesh file.
|
| 466 |
+
output_root (str): Directory for outputs.
|
| 467 |
+
text_prompt (str, optional): Prompt for GPT.
|
| 468 |
+
category (str, optional): Asset category.
|
| 469 |
+
**kwargs: Additional attributes.
|
| 470 |
+
|
| 471 |
+
Returns:
|
| 472 |
+
str: Path to generated URDF file.
|
| 473 |
+
"""
|
| 474 |
if text_prompt is None or len(text_prompt) == 0:
|
| 475 |
text_prompt = self.prompt_template
|
| 476 |
text_prompt = text_prompt.format(category=category.lower())
|
thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py
CHANGED
|
@@ -440,7 +440,7 @@ def to_glb(
|
|
| 440 |
vertices, faces, uvs = parametrize_mesh(vertices, faces)
|
| 441 |
|
| 442 |
# bake texture
|
| 443 |
-
observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=
|
| 444 |
masks = [np.any(observation > 0, axis=-1) for observation in observations]
|
| 445 |
extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
|
| 446 |
intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
|
|
|
|
| 440 |
vertices, faces, uvs = parametrize_mesh(vertices, faces)
|
| 441 |
|
| 442 |
# bake texture
|
| 443 |
+
observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=200)
|
| 444 |
masks = [np.any(observation > 0, axis=-1) for observation in observations]
|
| 445 |
extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
|
| 446 |
intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
|