xinjie.wang committed
Commit b800513 · 1 Parent(s): 0533bc0
app.py CHANGED
@@ -27,7 +27,7 @@ from common import (
27
  VERSION,
28
  active_btn_by_content,
29
  end_session,
30
- extract_3d_representations_v2,
31
  extract_urdf,
32
  get_seed,
33
  image_to_3d,
@@ -179,17 +179,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
179
  )
180
 
181
  generate_btn = gr.Button(
182
- "🚀 1. Generate(~0.5 mins)",
183
  variant="primary",
184
  interactive=False,
185
  )
186
  model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
187
- with gr.Row():
188
- extract_rep3d_btn = gr.Button(
189
- "🔍 2. Extract 3D Representation(~2 mins)",
190
- variant="primary",
191
- interactive=False,
192
- )
193
  with gr.Accordion(
194
  label="Enter Asset Attributes(optional)", open=False
195
  ):
@@ -207,7 +207,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
207
  )
208
  with gr.Row():
209
  extract_urdf_btn = gr.Button(
210
- "🧩 3. Extract URDF with physics(~1 mins)",
211
  variant="primary",
212
  interactive=False,
213
  )
@@ -230,7 +230,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
230
  )
231
  with gr.Row():
232
  download_urdf = gr.DownloadButton(
233
- label="⬇️ 4. Download URDF",
234
  variant="primary",
235
  interactive=False,
236
  )
@@ -326,7 +326,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
326
  image_prompt.change(
327
  lambda: tuple(
328
  [
329
- gr.Button(interactive=False),
330
  gr.Button(interactive=False),
331
  gr.Button(interactive=False),
332
  None,
@@ -344,7 +344,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
344
  ]
345
  ),
346
  outputs=[
347
- extract_rep3d_btn,
348
  extract_urdf_btn,
349
  download_urdf,
350
  model_output_gs,
@@ -375,7 +375,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
375
  image_prompt_sam.change(
376
  lambda: tuple(
377
  [
378
- gr.Button(interactive=False),
379
  gr.Button(interactive=False),
380
  gr.Button(interactive=False),
381
  None,
@@ -394,7 +394,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
394
  ]
395
  ),
396
  outputs=[
397
- extract_rep3d_btn,
398
  extract_urdf_btn,
399
  download_urdf,
400
  model_output_gs,
@@ -447,12 +447,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
447
  ],
448
  outputs=[output_buf, video_output],
449
  ).success(
450
- lambda: gr.Button(interactive=True),
451
- outputs=[extract_rep3d_btn],
452
- )
453
-
454
- extract_rep3d_btn.click(
455
- extract_3d_representations_v2,
456
  inputs=[
457
  output_buf,
458
  project_delight,
 
27
  VERSION,
28
  active_btn_by_content,
29
  end_session,
30
+ extract_3d_representations_v3,
31
  extract_urdf,
32
  get_seed,
33
  image_to_3d,
 
179
  )
180
 
181
  generate_btn = gr.Button(
182
+ "🚀 1. Generate(~2 mins)",
183
  variant="primary",
184
  interactive=False,
185
  )
186
  model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
187
+ # with gr.Row():
188
+ # extract_rep3d_btn = gr.Button(
189
+ # "🔍 2. Extract 3D Representation(~2 mins)",
190
+ # variant="primary",
191
+ # interactive=False,
192
+ # )
193
  with gr.Accordion(
194
  label="Enter Asset Attributes(optional)", open=False
195
  ):
 
207
  )
208
  with gr.Row():
209
  extract_urdf_btn = gr.Button(
210
+ "🧩 2. Extract URDF with physics(~1 mins)",
211
  variant="primary",
212
  interactive=False,
213
  )
 
230
  )
231
  with gr.Row():
232
  download_urdf = gr.DownloadButton(
233
+ label="⬇️ 3. Download URDF",
234
  variant="primary",
235
  interactive=False,
236
  )
 
326
  image_prompt.change(
327
  lambda: tuple(
328
  [
329
+ # gr.Button(interactive=False),
330
  gr.Button(interactive=False),
331
  gr.Button(interactive=False),
332
  None,
 
344
  ]
345
  ),
346
  outputs=[
347
+ # extract_rep3d_btn,
348
  extract_urdf_btn,
349
  download_urdf,
350
  model_output_gs,
 
375
  image_prompt_sam.change(
376
  lambda: tuple(
377
  [
378
+ # gr.Button(interactive=False),
379
  gr.Button(interactive=False),
380
  gr.Button(interactive=False),
381
  None,
 
394
  ]
395
  ),
396
  outputs=[
397
+ # extract_rep3d_btn,
398
  extract_urdf_btn,
399
  download_urdf,
400
  model_output_gs,
 
447
  ],
448
  outputs=[output_buf, video_output],
449
  ).success(
450
+ extract_3d_representations_v3,
451
  inputs=[
452
  output_buf,
453
  project_delight,
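
The app.py change above removes the separate "Extract 3D Representation" button and instead chains `extract_3d_representations_v3` onto the generation event via `.success()`. A minimal, self-contained sketch of this Gradio chaining pattern, using illustrative component and function names rather than the app's real ones:

```python
# Minimal sketch of the one-click pipeline pattern adopted above (assumed names).
import gradio as gr

def generate(seed: float) -> str:
    # Stand-in for the generation stage: produce an intermediate result.
    return f"raw mesh for seed {int(seed)}"

def extract_representation(raw: str) -> str:
    # Stand-in for the extraction stage: refine the intermediate result.
    return f"textured GLB built from: {raw}"

with gr.Blocks() as demo:
    seed = gr.Number(value=0, label="seed")
    raw_out = gr.Textbox(label="raw mesh")
    glb_out = gr.Textbox(label="textured mesh")
    generate_btn = gr.Button("🚀 Generate")
    # One click runs both stages; the second fires only if the first succeeds.
    generate_btn.click(
        generate, inputs=[seed], outputs=[raw_out]
    ).success(
        extract_representation, inputs=[raw_out], outputs=[glb_out]
    )

if __name__ == "__main__":
    demo.launch()
```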
app_style.py CHANGED
@@ -4,7 +4,7 @@ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
4
  lighting_css = """
5
  <style>
6
  #lighter_mesh canvas {
7
- filter: brightness(1.9) !important;
8
  }
9
  </style>
10
  """
 
4
  lighting_css = """
5
  <style>
6
  #lighter_mesh canvas {
7
+ filter: brightness(2.0) !important;
8
  }
9
  </style>
10
  """
common.py CHANGED
@@ -32,6 +32,7 @@ import trimesh
32
  from easydict import EasyDict as edict
33
  from PIL import Image
34
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
 
35
  from embodied_gen.data.differentiable_render import entrypoint as render_api
36
  from embodied_gen.data.utils import trellis_preprocess, zip_files
37
  from embodied_gen.models.delight_model import DelightingModel
@@ -131,8 +132,8 @@ def patched_setup_functions(self):
131
  Gaussian.setup_functions = patched_setup_functions
132
 
133
 
134
- DELIGHT = DelightingModel()
135
- IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
136
  # IMAGESR_MODEL = ImageStableSR()
137
  if os.getenv("GRADIO_APP") == "imageto3d":
138
  RBG_REMOVER = RembgRemover()
@@ -169,6 +170,8 @@ elif os.getenv("GRADIO_APP") == "textto3d":
169
  )
170
  os.makedirs(TMP_DIR, exist_ok=True)
171
  elif os.getenv("GRADIO_APP") == "texture_edit":
 
 
172
  PIPELINE_IP = build_texture_gen_pipe(
173
  base_ckpt_dir="./weights",
174
  ip_adapt_scale=0.7,
@@ -512,6 +515,60 @@ def extract_3d_representations_v2(
512
  return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
513
 
514
 
515
  def extract_urdf(
516
  gs_path: str,
517
  mesh_obj_path: str,
 
32
  from easydict import EasyDict as edict
33
  from PIL import Image
34
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
35
+ from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
36
  from embodied_gen.data.differentiable_render import entrypoint as render_api
37
  from embodied_gen.data.utils import trellis_preprocess, zip_files
38
  from embodied_gen.models.delight_model import DelightingModel
 
132
  Gaussian.setup_functions = patched_setup_functions
133
 
134
 
135
+ # DELIGHT = DelightingModel()
136
+ # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
137
  # IMAGESR_MODEL = ImageStableSR()
138
  if os.getenv("GRADIO_APP") == "imageto3d":
139
  RBG_REMOVER = RembgRemover()
 
170
  )
171
  os.makedirs(TMP_DIR, exist_ok=True)
172
  elif os.getenv("GRADIO_APP") == "texture_edit":
173
+ DELIGHT = DelightingModel()
174
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
175
  PIPELINE_IP = build_texture_gen_pipe(
176
  base_ckpt_dir="./weights",
177
  ip_adapt_scale=0.7,
 
515
  return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
516
 
517
 
518
+ def extract_3d_representations_v3(
519
+ state: dict,
520
+ enable_delight: bool,
521
+ texture_size: int,
522
+ req: gr.Request,
523
+ ):
524
+ output_root = TMP_DIR
525
+ user_dir = os.path.join(output_root, str(req.session_hash))
526
+ gs_model, mesh_model = unpack_state(state, device="cpu")
527
+
528
+ filename = "sample"
529
+ gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
530
+ gs_model.save_ply(gs_path)
531
+
532
+ # Rotate mesh and GS by 90 degrees around Z-axis.
533
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
534
+ gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
535
+ mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
536
+
537
+ # Additional rotation for the GS to align it with the mesh.
538
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
539
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
540
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
541
+ GaussianOperator.resave_ply(
542
+ in_ply=gs_path,
543
+ out_ply=aligned_gs_path,
544
+ instance_pose=pose,
545
+ device="cpu",
546
+ )
547
+
548
+ mesh = trimesh.Trimesh(
549
+ vertices=mesh_model.vertices.cpu().numpy(),
550
+ faces=mesh_model.faces.cpu().numpy(),
551
+ )
552
+ mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
553
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
554
+
555
+ mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
556
+ mesh.export(mesh_obj_path)
557
+
558
+ mesh = backproject_api_v3(
559
+ gs_path=aligned_gs_path,
560
+ mesh_path=mesh_obj_path,
561
+ output_path=mesh_obj_path,
562
+ skip_fix_mesh=False,
563
+ texture_size=texture_size,
564
+ )
565
+
566
+ mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
567
+ mesh.export(mesh_glb_path)
568
+
569
+ return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
570
+
571
+
572
  def extract_urdf(
573
  gs_path: str,
574
  mesh_obj_path: str,
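
The new `extract_3d_representations_v3` composes two hard-coded rotation matrices (`gs_add_rot @ rot_matrix`) before converting the result into a quaternion pose for the Gaussian splat. A standalone sanity-check sketch, reusing only the constants shown above, that the composition is a proper rotation (orthonormal, determinant +1):

```python
# Sanity check of the rotation composition used in extract_3d_representations_v3.
import numpy as np

rot_matrix = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=float)
gs_add_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=float)

gs_rot = gs_add_rot @ rot_matrix  # rotation applied to the Gaussian splat

assert np.allclose(gs_rot @ gs_rot.T, np.eye(3))   # orthonormal
assert np.isclose(np.linalg.det(gs_rot), 1.0)      # proper rotation, no reflection
print(gs_rot)
```

In the function above, this is the matrix that `GaussianOperator.trans_to_quatpose` turns into the aligned-splat pose.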
embodied_gen/data/asset_converter.py CHANGED
@@ -4,12 +4,12 @@ import logging
4
  import os
5
  import xml.etree.ElementTree as ET
6
  from abc import ABC, abstractmethod
7
- from dataclasses import dataclass
8
  from glob import glob
9
- from shutil import copy
10
 
11
  import trimesh
12
  from scipy.spatial.transform import Rotation
 
13
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
@@ -17,54 +17,157 @@ logger = logging.getLogger(__name__)
17
 
18
  __all__ = [
19
  "AssetConverterFactory",
20
- "AssetType",
21
  "MeshtoMJCFConverter",
22
  "MeshtoUSDConverter",
23
  "URDFtoUSDConverter",
 
 
24
  ]
25
 
26
 
27
- @dataclass
28
- class AssetType(str):
29
- """Asset type enumeration."""
30
 
31
- MJCF = "mjcf"
32
- USD = "usd"
33
- URDF = "urdf"
34
- MESH = "mesh"
35
 
36
 
37
  class AssetConverterBase(ABC):
38
- """Converter abstract base class."""
39
 
40
  @abstractmethod
41
  def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
42
  pass
43
 
44
  def transform_mesh(
45
  self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
46
  ) -> None:
47
- """Apply transform to the mesh based on the origin element in URDF."""
48
- mesh = trimesh.load(input_mesh)
49
  rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
50
  rotation = Rotation.from_euler("xyz", rpy, degrees=False)
51
  offset = list(map(float, mesh_origin.get("xyz").split(" ")))
52
- mesh.vertices = (mesh.vertices @ rotation.as_matrix().T) + offset
53
-
54
  os.makedirs(os.path.dirname(output_mesh), exist_ok=True)
55
- _ = mesh.export(output_mesh)
56
 
57
  return
58
 
59
  def __enter__(self):
 
60
  return self
61
 
62
  def __exit__(self, exc_type, exc_val, exc_tb):
 
63
  return False
64
 
65
 
66
  class MeshtoMJCFConverter(AssetConverterBase):
67
- """Convert URDF files into MJCF format."""
68
 
69
  def __init__(
70
  self,
@@ -73,6 +176,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
73
  self.kwargs = kwargs
74
 
75
  def _copy_asset_file(self, src: str, dst: str) -> None:
76
  if os.path.exists(dst):
77
  return
78
  os.makedirs(os.path.dirname(dst), exist_ok=True)
@@ -90,37 +199,66 @@ class MeshtoMJCFConverter(AssetConverterBase):
90
  material: ET.Element | None = None,
91
  is_collision: bool = False,
92
  ) -> None:
93
- """Add geometry to the MJCF body from the URDF link."""
94
  element = link.find(tag)
95
  geometry = element.find("geometry")
96
  mesh = geometry.find("mesh")
97
  filename = mesh.get("filename")
98
  scale = mesh.get("scale", "1.0 1.0 1.0")
99
-
100
- mesh_asset = ET.SubElement(
101
- mujoco_element, "mesh", name=mesh_name, file=filename, scale=scale
102
- )
103
- geom = ET.SubElement(body, "geom", type="mesh", mesh=mesh_name)
104
-
105
- self._copy_asset_file(
106
- f"{input_dir}/{filename}",
107
- f"{output_dir}/{filename}",
108
- )
109
-
110
- # Preprocess the mesh by applying rotation.
111
  input_mesh = f"{input_dir}/{filename}"
112
  output_mesh = f"{output_dir}/{filename}"
 
 
113
  mesh_origin = element.find("origin")
114
  if mesh_origin is not None:
115
  self.transform_mesh(input_mesh, output_mesh, mesh_origin)
116
 
117
- if material is not None:
118
- geom.set("material", material.get("name"))
119
-
120
  if is_collision:
121
- geom.set("contype", "1")
122
- geom.set("conaffinity", "1")
123
- geom.set("rgba", "1 1 1 0")
124
 
125
  def add_materials(
126
  self,
@@ -132,31 +270,52 @@ class MeshtoMJCFConverter(AssetConverterBase):
132
  name: str,
133
  reflectance: float = 0.2,
134
  ) -> ET.Element:
135
- """Add materials to the MJCF asset from the URDF link."""
136
  element = link.find(tag)
137
  geometry = element.find("geometry")
138
  mesh = geometry.find("mesh")
139
  filename = mesh.get("filename")
140
  dirname = os.path.dirname(filename)
141
-
142
- material = ET.SubElement(
143
- mujoco_element,
144
- "material",
145
- name=f"material_{name}",
146
- texture=f"texture_{name}",
147
- reflectance=str(reflectance),
148
- )
149
-
150
  for path in glob(f"{input_dir}/{dirname}/*.png"):
151
  file_name = os.path.basename(path)
152
  self._copy_asset_file(
153
  path,
154
  f"{output_dir}/{dirname}/{file_name}",
155
  )
156
  ET.SubElement(
157
  mujoco_element,
158
  "texture",
159
- name=f"texture_{name}_{os.path.splitext(file_name)[0]}",
160
  type="2d",
161
  file=f"{dirname}/{file_name}",
162
  )
@@ -164,7 +323,12 @@ class MeshtoMJCFConverter(AssetConverterBase):
164
  return material
165
 
166
  def convert(self, urdf_path: str, mjcf_path: str):
167
- """Convert a URDF file to MJCF format."""
168
  tree = ET.parse(urdf_path)
169
  root = tree.getroot()
170
 
@@ -188,6 +352,7 @@ class MeshtoMJCFConverter(AssetConverterBase):
188
  output_dir,
189
  name=str(idx),
190
  )
 
191
  self.add_geometry(
192
  mujoco_asset,
193
  link,
@@ -217,58 +382,22 @@ class MeshtoMJCFConverter(AssetConverterBase):
217
 
218
 
219
  class URDFtoMJCFConverter(MeshtoMJCFConverter):
220
- """Convert URDF files with joints to MJCF format, handling transformations from joints."""
221
 
222
- def add_materials(
223
- self,
224
- mujoco_element: ET.Element,
225
- link: ET.Element,
226
- tag: str,
227
- input_dir: str,
228
- output_dir: str,
229
- name: str,
230
- reflectance: float = 0.2,
231
- ) -> ET.Element:
232
- """Add materials to the MJCF asset from the URDF link."""
233
- element = link.find(tag)
234
- geometry = element.find("geometry")
235
- mesh = geometry.find("mesh")
236
- filename = mesh.get("filename")
237
- dirname = os.path.dirname(filename)
238
 
239
- diffuse_texture = None
240
- for path in glob(f"{input_dir}/{dirname}/*.png"):
241
- file_name = os.path.basename(path)
242
- self._copy_asset_file(
243
- path,
244
- f"{output_dir}/{dirname}/{file_name}",
245
- )
246
- texture_name = f"texture_{name}_{os.path.splitext(file_name)[0]}"
247
- ET.SubElement(
248
- mujoco_element,
249
- "texture",
250
- name=texture_name,
251
- type="2d",
252
- file=f"{dirname}/{file_name}",
253
- )
254
- if "diffuse" in file_name.lower():
255
- diffuse_texture = texture_name
256
-
257
- if diffuse_texture is None:
258
- return None
259
-
260
- material = ET.SubElement(
261
- mujoco_element,
262
- "material",
263
- name=f"material_{name}",
264
- texture=diffuse_texture,
265
- reflectance=str(reflectance),
266
- )
267
 
268
- return material
269
 
270
- def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
271
- """Convert a URDF file with joints to MJCF format."""
 
272
  tree = ET.parse(urdf_path)
273
  root = tree.getroot()
274
 
@@ -281,18 +410,12 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
281
  output_dir = os.path.dirname(mjcf_path)
282
  os.makedirs(output_dir, exist_ok=True)
283
 
284
- # Create a dictionary to store body elements for each link
285
  body_dict = {}
286
-
287
- # Process all links first
288
  for idx, link in enumerate(root.findall("link")):
289
  link_name = link.get("name", f"unnamed_link_{idx}")
290
  body = ET.SubElement(mujoco_worldbody, "body", name=link_name)
291
  body_dict[link_name] = body
292
-
293
- # Add materials and geometry
294
- visual_element = link.find("visual")
295
- if visual_element is not None:
296
  material = self.add_materials(
297
  mujoco_asset,
298
  link,
@@ -311,9 +434,7 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
311
  f"visual_mesh_{idx}",
312
  material,
313
  )
314
-
315
- collision_element = link.find("collision")
316
- if collision_element is not None:
317
  self.add_geometry(
318
  mujoco_asset,
319
  link,
@@ -329,41 +450,27 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
329
  for joint in root.findall("joint"):
330
  joint_type = joint.get("type")
331
  if joint_type != "fixed":
332
- logger.warning(
333
- f"Skipping non-fixed joint: {joint.get('name')}"
334
- )
335
  continue
336
 
337
  parent_link = joint.find("parent").get("link")
338
  child_link = joint.find("child").get("link")
339
  origin = joint.find("origin")
340
-
341
  if parent_link not in body_dict or child_link not in body_dict:
342
  logger.warning(
343
  f"Parent or child link not found for joint: {joint.get('name')}"
344
  )
345
  continue
346
 
347
- # Move child body under parent body in MJCF hierarchy
348
  child_body = body_dict[child_link]
349
  mujoco_worldbody.remove(child_body)
350
  parent_body = body_dict[parent_link]
351
  parent_body.append(child_body)
352
-
353
- # Apply joint origin transformation to child body
354
  if origin is not None:
355
  xyz = origin.get("xyz", "0 0 0")
356
  rpy = origin.get("rpy", "0 0 0")
357
  child_body.set("pos", xyz)
358
- # Convert rpy to MJCF euler format (degrees)
359
- rpy_floats = list(map(float, rpy.split()))
360
- rotation = Rotation.from_euler(
361
- "xyz", rpy_floats, degrees=False
362
- )
363
- euler_deg = rotation.as_euler("xyz", degrees=True)
364
- child_body.set(
365
- "euler", f"{euler_deg[0]} {euler_deg[1]} {euler_deg[2]}"
366
- )
367
 
368
  tree = ET.ElementTree(mujoco_struct)
369
  ET.indent(tree, space=" ", level=0)
@@ -374,11 +481,15 @@ class URDFtoMJCFConverter(MeshtoMJCFConverter):
374
 
375
 
376
  class MeshtoUSDConverter(AssetConverterBase):
377
- """Convert Mesh file from URDF into USD format."""
378
 
379
  DEFAULT_BIND_APIS = [
380
  "MaterialBindingAPI",
381
  "PhysicsMeshCollisionAPI",
 
382
  "PhysicsCollisionAPI",
383
  "PhysxCollisionAPI",
384
  "PhysicsMassAPI",
@@ -393,41 +504,65 @@ class MeshtoUSDConverter(AssetConverterBase):
393
  simulation_app=None,
394
  **kwargs,
395
  ):
396
  self.usd_parms = dict(
397
  force_usd_conversion=force_usd_conversion,
398
  make_instanceable=make_instanceable,
399
  **kwargs,
400
  )
401
- if simulation_app is not None:
402
- self.simulation_app = simulation_app
403
 
404
  def __enter__(self):
 
405
  from isaaclab.app import AppLauncher
406
 
407
  if not hasattr(self, "simulation_app"):
408
- launch_args = dict(
409
- headless=True,
410
- no_splash=True,
411
- fast_shutdown=True,
412
- disable_gpu=True,
413
- )
414
  self.app_launcher = AppLauncher(launch_args)
415
  self.simulation_app = self.app_launcher.app
416
 
417
  return self
418
 
419
  def __exit__(self, exc_type, exc_val, exc_tb):
 
420
  # Close the simulation app if it was created here
421
- if hasattr(self, "app_launcher"):
422
- self.simulation_app.close()
423
-
424
  if exc_val is not None:
425
  logger.error(f"Exception occurred: {exc_val}.")
426
 
 
 
 
427
  return False
428
 
429
  def convert(self, urdf_path: str, output_file: str):
430
- """Convert a URDF file to USD and post-process collision meshes."""
431
  from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
432
  from pxr import PhysxSchema, Sdf, Usd, UsdShade
433
 
@@ -449,10 +584,13 @@ class MeshtoUSDConverter(AssetConverterBase):
449
  )
450
  urdf_converter = MeshConverter(cfg)
451
  usd_path = urdf_converter.usd_path
 
452
 
453
  stage = Usd.Stage.Open(usd_path)
454
  layer = stage.GetRootLayer()
455
  with Usd.EditContext(stage, layer):
 
 
456
  for prim in stage.Traverse():
457
  # Change texture path to relative path.
458
  if prim.GetName() == "material_0":
@@ -465,11 +603,9 @@ class MeshtoUSDConverter(AssetConverterBase):
465
 
466
  # Add convex decomposition collision and set ShrinkWrap.
467
  elif prim.GetName() == "mesh":
468
- approx_attr = prim.GetAttribute("physics:approximation")
469
- if not approx_attr:
470
- approx_attr = prim.CreateAttribute(
471
- "physics:approximation", Sdf.ValueTypeNames.Token
472
- )
473
  approx_attr.Set("convexDecomposition")
474
 
475
  physx_conv_api = (
@@ -477,6 +613,15 @@ class MeshtoUSDConverter(AssetConverterBase):
477
  prim
478
  )
479
  )
480
  physx_conv_api.GetShrinkWrapAttr().Set(True)
481
 
482
  api_schemas = prim.GetMetadata("apiSchemas")
@@ -495,15 +640,105 @@ class MeshtoUSDConverter(AssetConverterBase):
495
  logger.info(f"Successfully converted {urdf_path} → {usd_path}")
496
 
497
 
  class URDFtoUSDConverter(MeshtoUSDConverter):
499
- """Convert URDF files into USD format.
500
 
501
  Args:
502
- fix_base (bool): Whether to fix the base link.
503
- merge_fixed_joints (bool): Whether to merge fixed joints.
504
- make_instanceable (bool): Whether to make prims instanceable.
505
- force_usd_conversion (bool): Force conversion to USD.
506
- collision_from_visuals (bool): Generate collisions from visuals if not provided.
507
  """
508
 
509
  def __init__(
@@ -518,6 +753,19 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
518
  simulation_app=None,
519
  **kwargs,
520
  ):
521
  self.usd_parms = dict(
522
  fix_base=fix_base,
523
  merge_fixed_joints=merge_fixed_joints,
@@ -532,7 +780,12 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
532
  self.simulation_app = simulation_app
533
 
534
  def convert(self, urdf_path: str, output_file: str):
535
- """Convert a URDF file to USD and post-process collision meshes."""
536
  from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
537
  from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
538
 
@@ -551,11 +804,9 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
551
  with Usd.EditContext(stage, layer):
552
  for prim in stage.Traverse():
553
  if prim.GetName() == "collisions":
554
- approx_attr = prim.GetAttribute("physics:approximation")
555
- if not approx_attr:
556
- approx_attr = prim.CreateAttribute(
557
- "physics:approximation", Sdf.ValueTypeNames.Token
558
- )
559
  approx_attr.Set("convexDecomposition")
560
 
561
  physx_conv_api = (
@@ -563,6 +814,9 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
563
  prim
564
  )
565
  )
566
  physx_conv_api.GetShrinkWrapAttr().Set(True)
567
 
568
  api_schemas = prim.GetMetadata("apiSchemas")
@@ -593,19 +847,44 @@ class URDFtoUSDConverter(MeshtoUSDConverter):
593
 
594
 
595
  class AssetConverterFactory:
596
- """Factory class for creating asset converters based on target and source types."""
597
 
598
  @staticmethod
599
  def create(
600
  target_type: AssetType, source_type: AssetType = "urdf", **kwargs
601
  ) -> AssetConverterBase:
602
- """Create an asset converter instance based on target and source types."""
603
- if target_type == AssetType.MJCF and source_type == AssetType.URDF:
604
  converter = MeshtoMJCFConverter(**kwargs)
605
- elif target_type == AssetType.USD and source_type == AssetType.URDF:
606
- converter = URDFtoUSDConverter(**kwargs)
607
  elif target_type == AssetType.USD and source_type == AssetType.MESH:
608
  converter = MeshtoUSDConverter(**kwargs)
 
 
609
  else:
610
  raise ValueError(
611
  f"Unsupported converter type: {source_type} -> {target_type}."
@@ -615,34 +894,48 @@ class AssetConverterFactory:
615
 
616
 
617
  if __name__ == "__main__":
618
- # # target_asset_type = AssetType.MJCF
619
  # target_asset_type = AssetType.USD
620
 
621
- # urdf_paths = [
622
- # "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
623
- # ]
624
-
625
- # if target_asset_type == AssetType.MJCF:
626
- # output_files = [
627
- # "outputs/embodiedgen_assets/demo_assets/remote_control/mjcf/remote_control.mjcf",
628
- # ]
629
- # asset_converter = AssetConverterFactory.create(
630
- # target_type=AssetType.MJCF,
631
- # source_type=AssetType.URDF,
632
- # )
633
 
634
- # elif target_asset_type == AssetType.USD:
635
- # output_files = [
636
- # "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
637
- # ]
638
- # asset_converter = AssetConverterFactory.create(
639
- # target_type=AssetType.USD,
640
- # source_type=AssetType.MESH,
641
- # )
642
 
643
- # with asset_converter:
644
- # for urdf_path, output_file in zip(urdf_paths, output_files):
645
- # asset_converter.convert(urdf_path, output_file)
646
 
647
  # urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
648
  # output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
@@ -656,8 +949,21 @@ if __name__ == "__main__":
656
  # with asset_converter:
657
  # asset_converter.convert(urdf_path, output_file)
658
 
659
- urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_simple_solve_nos_i_urdf/export_scene/scene.urdf"
660
- output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_simple_solve_nos_i_urdf/mjcf/scene.urdf"
661
- asset_converter = URDFtoMJCFConverter()
662
- with asset_converter:
663
- asset_converter.convert(urdf_path, output_file)
4
  import os
5
  import xml.etree.ElementTree as ET
6
  from abc import ABC, abstractmethod
 
7
  from glob import glob
8
+ from shutil import copy, copytree, rmtree
9
 
10
  import trimesh
11
  from scipy.spatial.transform import Rotation
12
+ from embodied_gen.utils.enum import AssetType
13
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
 
17
 
18
  __all__ = [
19
  "AssetConverterFactory",
 
20
  "MeshtoMJCFConverter",
21
  "MeshtoUSDConverter",
22
  "URDFtoUSDConverter",
23
+ "cvt_embodiedgen_asset_to_anysim",
24
+ "PhysicsUSDAdder",
25
  ]
26
 
27
 
28
+ def cvt_embodiedgen_asset_to_anysim(
29
+ urdf_files: list[str],
30
+ target_dirs: list[str],
31
+ target_type: AssetType,
32
+ source_type: AssetType,
33
+ overwrite: bool = False,
34
+ **kwargs,
35
+ ) -> dict[str, str]:
36
+ """Convert URDF files generated by EmbodiedGen into formats required by simulators.
37
+
38
+ Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet.
39
+ Converting to the `USD` format requires `isaacsim` to be installed.
40
+
41
+ Example:
42
+ ```py
43
+ from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
44
+ from embodied_gen.utils.enum import AssetType
45
+
46
+ dst_asset_path = cvt_embodiedgen_asset_to_anysim(
47
+ urdf_files=[
48
+ "path1_to_embodiedgen_asset/asset.urdf",
49
+ "path2_to_embodiedgen_asset/asset.urdf",
50
+ ],
51
+ target_dirs=[
52
+ "path1_to_target_dir/asset.usd",
53
+ "path2_to_target_dir/asset.usd",
54
+ ],
55
+ target_type=AssetType.USD,
56
+ source_type=AssetType.MESH,
57
+ )
58
+ ```
59
+
60
+ Args:
61
+ urdf_files (list[str]): List of URDF file paths.
62
+ target_dirs (list[str]): List of target directories.
63
+ target_type (AssetType): Target asset type.
64
+ source_type (AssetType): Source asset type.
65
+ overwrite (bool, optional): Overwrite existing files.
66
+ **kwargs: Additional converter arguments.
67
+
68
+ Returns:
69
+ dict[str, str]: Mapping from URDF file to converted asset file.
70
+ """
71
+
72
+ if isinstance(urdf_files, str):
73
+ urdf_files = [urdf_files]
74
+ if isinstance(target_dirs, str):
75
+ target_dirs = [target_dirs]
76
+
77
+ # If the target type is URDF, no conversion is needed.
78
+ if target_type == AssetType.URDF:
79
+ return {key: key for key in urdf_files}
80
+
81
+ asset_converter = AssetConverterFactory.create(
82
+ target_type=target_type,
83
+ source_type=source_type,
84
+ **kwargs,
85
+ )
86
+ asset_paths = dict()
87
+
88
+ with asset_converter:
89
+ for urdf_file, target_dir in zip(urdf_files, target_dirs):
90
+ filename = os.path.basename(urdf_file).replace(".urdf", "")
91
+ if target_type == AssetType.MJCF:
92
+ target_file = f"{target_dir}/{filename}.xml"
93
+ elif target_type == AssetType.USD:
94
+ target_file = f"{target_dir}/{filename}.usd"
95
+ else:
96
+ raise NotImplementedError(
97
+ f"Target type {target_type} not supported."
98
+ )
99
+ if not os.path.exists(target_file) or overwrite:
100
+ asset_converter.convert(urdf_file, target_file)
101
+
102
+ asset_paths[urdf_file] = target_file
103
 
104
+ return asset_paths
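
The docstring example for `cvt_embodiedgen_asset_to_anysim` targets USD; a hedged MJCF variant under the same assumptions (placeholder paths, `AssetType` imported from `embodied_gen.utils.enum`):

```python
# Hypothetical MJCF conversion using the helper defined above (paths are placeholders).
from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
from embodied_gen.utils.enum import AssetType

mjcf_paths = cvt_embodiedgen_asset_to_anysim(
    urdf_files=["path_to_embodiedgen_asset/asset.urdf"],
    target_dirs=["path_to_target_dir"],
    target_type=AssetType.MJCF,  # each asset is written as <target_dir>/<name>.xml
    source_type=AssetType.MESH,
)
print(mjcf_paths)  # mapping: input URDF -> converted MJCF file
```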
105
 
106
 
107
  class AssetConverterBase(ABC):
108
+ """Abstract base class for asset converters.
109
+
110
+ Provides context management and mesh transformation utilities.
111
+ """
112
 
113
  @abstractmethod
114
  def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
115
+ """Convert an asset file.
116
+
117
+ Args:
118
+ urdf_path (str): Path to input URDF file.
119
+ output_path (str): Path to output file.
120
+ **kwargs: Additional arguments.
121
+
122
+ Returns:
123
+ str: Path to converted asset.
124
+ """
125
  pass
126
 
127
  def transform_mesh(
128
  self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
129
  ) -> None:
130
+ """Apply transform to mesh based on URDF origin element.
131
+
132
+ Args:
133
+ input_mesh (str): Path to input mesh.
134
+ output_mesh (str): Path to output mesh.
135
+ mesh_origin (ET.Element): Origin element from URDF.
136
+ """
137
+ mesh = trimesh.load(input_mesh, group_material=False)
138
  rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
139
  rotation = Rotation.from_euler("xyz", rpy, degrees=False)
140
  offset = list(map(float, mesh_origin.get("xyz").split(" ")))
 
 
141
  os.makedirs(os.path.dirname(output_mesh), exist_ok=True)
142
+
143
+ if isinstance(mesh, trimesh.Scene):
144
+ combined = trimesh.Scene()
145
+ for mesh_part in mesh.geometry.values():
146
+ mesh_part.vertices = (
147
+ mesh_part.vertices @ rotation.as_matrix().T
148
+ ) + offset
149
+ combined.add_geometry(mesh_part)
150
+ _ = combined.export(output_mesh)
151
+ else:
152
+ mesh.vertices = (mesh.vertices @ rotation.as_matrix().T) + offset
153
+ _ = mesh.export(output_mesh)
154
 
155
  return
156
 
157
  def __enter__(self):
158
+ """Context manager entry."""
159
  return self
160
 
161
  def __exit__(self, exc_type, exc_val, exc_tb):
162
+ """Context manager exit."""
163
  return False
164
 
165
 
166
  class MeshtoMJCFConverter(AssetConverterBase):
167
+ """Converts mesh-based URDF files to MJCF format.
168
+
169
+ Handles geometry, materials, and asset copying.
170
+ """
171
 
172
  def __init__(
173
  self,
 
176
  self.kwargs = kwargs
177
 
178
  def _copy_asset_file(self, src: str, dst: str) -> None:
179
+ """Copies asset file if not already present.
180
+
181
+ Args:
182
+ src (str): Source file path.
183
+ dst (str): Destination file path.
184
+ """
185
  if os.path.exists(dst):
186
  return
187
  os.makedirs(os.path.dirname(dst), exist_ok=True)
 
199
  material: ET.Element | None = None,
200
  is_collision: bool = False,
201
  ) -> None:
202
+ """Adds geometry to MJCF body from URDF link.
203
+
204
+ Args:
205
+ mujoco_element (ET.Element): MJCF asset element.
206
+ link (ET.Element): URDF link element.
207
+ body (ET.Element): MJCF body element.
208
+ tag (str): Tag name ("visual" or "collision").
209
+ input_dir (str): Input directory.
210
+ output_dir (str): Output directory.
211
+ mesh_name (str): Mesh name.
212
+ material (ET.Element, optional): Material element.
213
+ is_collision (bool, optional): If True, treat as collision geometry.
214
+ """
215
  element = link.find(tag)
216
  geometry = element.find("geometry")
217
  mesh = geometry.find("mesh")
218
  filename = mesh.get("filename")
219
  scale = mesh.get("scale", "1.0 1.0 1.0")
220
  input_mesh = f"{input_dir}/{filename}"
221
  output_mesh = f"{output_dir}/{filename}"
222
+ self._copy_asset_file(input_mesh, output_mesh)
223
+
224
  mesh_origin = element.find("origin")
225
  if mesh_origin is not None:
226
  self.transform_mesh(input_mesh, output_mesh, mesh_origin)
227
 
 
 
 
228
  if is_collision:
229
+ mesh_parts = trimesh.load(
230
+ output_mesh, group_material=False, force="scene"
231
+ )
232
+ mesh_parts = mesh_parts.geometry.values()
233
+ else:
234
+ mesh_parts = [trimesh.load(output_mesh, force="mesh")]
235
+ for idx, mesh_part in enumerate(mesh_parts):
236
+ if is_collision:
237
+ idx_mesh_name = f"{mesh_name}_{idx}"
238
+ base, ext = os.path.splitext(filename)
239
+ idx_filename = f"{base}_{idx}{ext}"
240
+ base_outdir = os.path.dirname(output_mesh)
241
+ mesh_part.export(os.path.join(base_outdir, '..', idx_filename))
242
+ geom_attrs = {
243
+ "contype": "1",
244
+ "conaffinity": "1",
245
+ "rgba": "1 1 1 0",
246
+ }
247
+ else:
248
+ idx_mesh_name, idx_filename = mesh_name, filename
249
+ geom_attrs = {"contype": "0", "conaffinity": "0"}
250
+
251
+ ET.SubElement(
252
+ mujoco_element,
253
+ "mesh",
254
+ name=idx_mesh_name,
255
+ file=idx_filename,
256
+ scale=scale,
257
+ )
258
+ geom = ET.SubElement(body, "geom", type="mesh", mesh=idx_mesh_name)
259
+ geom.attrib.update(geom_attrs)
260
+ if material is not None:
261
+ geom.set("material", material.get("name"))
262
 
263
  def add_materials(
264
  self,
 
270
  name: str,
271
  reflectance: float = 0.2,
272
  ) -> ET.Element:
273
+ """Adds materials to MJCF asset from URDF link.
274
+
275
+ Args:
276
+ mujoco_element (ET.Element): MJCF asset element.
277
+ link (ET.Element): URDF link element.
278
+ tag (str): Tag name.
279
+ input_dir (str): Input directory.
280
+ output_dir (str): Output directory.
281
+ name (str): Material name.
282
+ reflectance (float, optional): Reflectance value.
283
+
284
+ Returns:
285
+ ET.Element: Material element.
286
+ """
287
  element = link.find(tag)
288
  geometry = element.find("geometry")
289
  mesh = geometry.find("mesh")
290
  filename = mesh.get("filename")
291
  dirname = os.path.dirname(filename)
292
+ material = None
293
  for path in glob(f"{input_dir}/{dirname}/*.png"):
294
  file_name = os.path.basename(path)
295
+ if "keep_materials" in self.kwargs:
296
+ find_flag = False
297
+ for keep_key in self.kwargs["keep_materials"]:
298
+ if keep_key in file_name.lower():
299
+ find_flag = True
300
+ if find_flag is False:
301
+ continue
302
+
303
  self._copy_asset_file(
304
  path,
305
  f"{output_dir}/{dirname}/{file_name}",
306
  )
307
+ texture_name = f"texture_{name}_{os.path.splitext(file_name)[0]}"
308
+ material = ET.SubElement(
309
+ mujoco_element,
310
+ "material",
311
+ name=f"material_{name}",
312
+ texture=texture_name,
313
+ reflectance=str(reflectance),
314
+ )
315
  ET.SubElement(
316
  mujoco_element,
317
  "texture",
318
+ name=texture_name,
319
  type="2d",
320
  file=f"{dirname}/{file_name}",
321
  )
 
323
  return material
324
 
325
  def convert(self, urdf_path: str, mjcf_path: str):
326
+ """Converts a URDF file to MJCF format.
327
+
328
+ Args:
329
+ urdf_path (str): Path to URDF file.
330
+ mjcf_path (str): Path to output MJCF file.
331
+ """
332
  tree = ET.parse(urdf_path)
333
  root = tree.getroot()
334
 
 
352
  output_dir,
353
  name=str(idx),
354
  )
355
+ joint = ET.SubElement(body, "joint", attrib={"type": "free"})
356
  self.add_geometry(
357
  mujoco_asset,
358
  link,
 
382
 
383
 
384
  class URDFtoMJCFConverter(MeshtoMJCFConverter):
385
+ """Converts URDF files with joints to MJCF format, handling joint transformations.
386
 
387
+ Handles fixed joints and hierarchical body structure.
388
+ """
 
389
 
390
+ def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
391
+ """Converts a URDF file with joints to MJCF format.
392
 
393
+ Args:
394
+ urdf_path (str): Path to URDF file.
395
+ mjcf_path (str): Path to output MJCF file.
396
+ **kwargs: Additional arguments.
397
 
398
+ Returns:
399
+ str: Path to converted MJCF file.
400
+ """
401
  tree = ET.parse(urdf_path)
402
  root = tree.getroot()
403
 
 
410
  output_dir = os.path.dirname(mjcf_path)
411
  os.makedirs(output_dir, exist_ok=True)
412
 
 
413
  body_dict = {}
 
 
414
  for idx, link in enumerate(root.findall("link")):
415
  link_name = link.get("name", f"unnamed_link_{idx}")
416
  body = ET.SubElement(mujoco_worldbody, "body", name=link_name)
417
  body_dict[link_name] = body
418
+ if link.find("visual") is not None:
419
  material = self.add_materials(
420
  mujoco_asset,
421
  link,
 
434
  f"visual_mesh_{idx}",
435
  material,
436
  )
437
+ if link.find("collision") is not None:
 
 
438
  self.add_geometry(
439
  mujoco_asset,
440
  link,
 
450
  for joint in root.findall("joint"):
451
  joint_type = joint.get("type")
452
  if joint_type != "fixed":
453
+ logger.warning("Only support fixed joints in conversion now.")
 
 
454
  continue
455
 
456
  parent_link = joint.find("parent").get("link")
457
  child_link = joint.find("child").get("link")
458
  origin = joint.find("origin")
 
459
  if parent_link not in body_dict or child_link not in body_dict:
460
  logger.warning(
461
  f"Parent or child link not found for joint: {joint.get('name')}"
462
  )
463
  continue
464
 
 
465
  child_body = body_dict[child_link]
466
  mujoco_worldbody.remove(child_body)
467
  parent_body = body_dict[parent_link]
468
  parent_body.append(child_body)
 
 
469
  if origin is not None:
470
  xyz = origin.get("xyz", "0 0 0")
471
  rpy = origin.get("rpy", "0 0 0")
472
  child_body.set("pos", xyz)
473
+ child_body.set("euler", rpy)
474
 
475
  tree = ET.ElementTree(mujoco_struct)
476
  ET.indent(tree, space=" ", level=0)
 
481
 
482
 
483
  class MeshtoUSDConverter(AssetConverterBase):
484
+ """Converts mesh-based URDF files to USD format.
485
+
486
+ Adds physics APIs and post-processes collision meshes.
487
+ """
488
 
489
  DEFAULT_BIND_APIS = [
490
  "MaterialBindingAPI",
491
  "PhysicsMeshCollisionAPI",
492
+ "PhysxConvexDecompositionCollisionAPI",
493
  "PhysicsCollisionAPI",
494
  "PhysxCollisionAPI",
495
  "PhysicsMassAPI",
 
504
  simulation_app=None,
505
  **kwargs,
506
  ):
507
+ """Initializes the converter.
508
+
509
+ Args:
510
+ force_usd_conversion (bool, optional): Force USD conversion.
511
+ make_instanceable (bool, optional): Make prims instanceable.
512
+ simulation_app (optional): Simulation app instance.
513
+ **kwargs: Additional arguments.
514
+ """
515
+ if simulation_app is not None:
516
+ self.simulation_app = simulation_app
517
+
518
+ self.exit_close = kwargs.pop("exit_close", True)
519
+ self.physx_max_convex_hulls = kwargs.pop("physx_max_convex_hulls", 32)
520
+ self.physx_max_vertices = kwargs.pop("physx_max_vertices", 16)
521
+ self.physx_max_voxel_res = kwargs.pop("physx_max_voxel_res", 10000)
522
+
523
  self.usd_parms = dict(
524
  force_usd_conversion=force_usd_conversion,
525
  make_instanceable=make_instanceable,
526
  **kwargs,
527
  )
 
 
528
 
529
  def __enter__(self):
530
+ """Context manager entry, launches simulation app if needed."""
531
  from isaaclab.app import AppLauncher
532
 
533
  if not hasattr(self, "simulation_app"):
534
+ if "launch_args" not in self.usd_parms:
535
+ launch_args = dict(
536
+ headless=True,
537
+ no_splash=True,
538
+ fast_shutdown=True,
539
+ disable_gpu=True,
540
+ )
541
+ else:
542
+ launch_args = self.usd_parms.pop("launch_args")
543
  self.app_launcher = AppLauncher(launch_args)
544
  self.simulation_app = self.app_launcher.app
545
 
546
  return self
547
 
548
  def __exit__(self, exc_type, exc_val, exc_tb):
549
+ """Context manager exit, closes simulation app if created."""
550
  # Close the simulation app if it was created here
551
  if exc_val is not None:
552
  logger.error(f"Exception occurred: {exc_val}.")
553
 
554
+ if hasattr(self, "app_launcher") and self.exit_close:
555
+ self.simulation_app.close()
556
+
557
  return False
558
 
559
  def convert(self, urdf_path: str, output_file: str):
560
+ """Converts a URDF file to USD and post-processes collision meshes.
561
+
562
+ Args:
563
+ urdf_path (str): Path to URDF file.
564
+ output_file (str): Path to output USD file.
565
+ """
566
  from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
567
  from pxr import PhysxSchema, Sdf, Usd, UsdShade
568
 
 
584
  )
585
  urdf_converter = MeshConverter(cfg)
586
  usd_path = urdf_converter.usd_path
587
+ rmtree(os.path.dirname(output_mesh))
588
 
589
  stage = Usd.Stage.Open(usd_path)
590
  layer = stage.GetRootLayer()
591
  with Usd.EditContext(stage, layer):
592
+ base_prim = stage.GetPseudoRoot().GetChildren()[0]
593
+ base_prim.SetMetadata("kind", "component")
594
  for prim in stage.Traverse():
595
  # Change texture path to relative path.
596
  if prim.GetName() == "material_0":
 
603
 
604
  # Add convex decomposition collision and set ShrinkWrap.
605
  elif prim.GetName() == "mesh":
606
+ approx_attr = prim.CreateAttribute(
607
+ "physics:approximation", Sdf.ValueTypeNames.Token
608
+ )
 
 
609
  approx_attr.Set("convexDecomposition")
610
 
611
  physx_conv_api = (
 
613
  prim
614
  )
615
  )
616
+ physx_conv_api.GetMaxConvexHullsAttr().Set(
617
+ self.physx_max_convex_hulls
618
+ )
619
+ physx_conv_api.GetHullVertexLimitAttr().Set(
620
+ self.physx_max_vertices
621
+ )
622
+ physx_conv_api.GetVoxelResolutionAttr().Set(
623
+ self.physx_max_voxel_res
624
+ )
625
  physx_conv_api.GetShrinkWrapAttr().Set(True)
626
 
627
  api_schemas = prim.GetMetadata("apiSchemas")
 
640
  logger.info(f"Successfully converted {urdf_path} → {usd_path}")
641
 
642
 
643
+ class PhysicsUSDAdder(MeshtoUSDConverter):
644
+ """Adds physics APIs and collision properties to USD assets.
645
+
646
+ Useful for post-processing USD files for simulation.
647
+ """
648
+
649
+ DEFAULT_BIND_APIS = [
650
+ "MaterialBindingAPI",
651
+ "PhysicsMeshCollisionAPI",
652
+ "PhysxConvexDecompositionCollisionAPI",
653
+ "PhysicsCollisionAPI",
654
+ "PhysxCollisionAPI",
655
+ "PhysicsRigidBodyAPI",
656
+ ]
657
+
658
+ def convert(self, usd_path: str, output_file: str = None):
659
+ """Adds physics APIs and collision properties to a USD file.
660
+
661
+ Args:
662
+ usd_path (str): Path to input USD file.
663
+ output_file (str, optional): Path to output USD file.
664
+ """
665
+ from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics
666
+
667
+ if output_file is None:
668
+ output_file = usd_path
669
+ else:
670
+ dst_dir = os.path.dirname(output_file)
671
+ src_dir = os.path.dirname(usd_path)
672
+ copytree(src_dir, dst_dir, dirs_exist_ok=True)
673
+
674
+ stage = Usd.Stage.Open(output_file)
675
+ layer = stage.GetRootLayer()
676
+ with Usd.EditContext(stage, layer):
677
+ for prim in stage.Traverse():
678
+ if prim.IsA(UsdGeom.Xform):
679
+ for child in prim.GetChildren():
680
+ if not child.IsA(UsdGeom.Mesh):
681
+ continue
682
+
683
+ # Skip the lightfactory in Infinigen
684
+ if "lightfactory" in prim.GetName().lower():
685
+ continue
686
+
687
+ approx_attr = prim.CreateAttribute(
688
+ "physics:approximation", Sdf.ValueTypeNames.Token
689
+ )
690
+ approx_attr.Set("convexDecomposition")
691
+
692
+ physx_conv_api = PhysxSchema.PhysxConvexDecompositionCollisionAPI.Apply(
693
+ prim
694
+ )
695
+ physx_conv_api.GetMaxConvexHullsAttr().Set(
696
+ self.physx_max_convex_hulls
697
+ )
698
+ physx_conv_api.GetHullVertexLimitAttr().Set(
699
+ self.physx_max_vertices
700
+ )
701
+ physx_conv_api.GetVoxelResolutionAttr().Set(
702
+ self.physx_max_voxel_res
703
+ )
704
+ physx_conv_api.GetShrinkWrapAttr().Set(True)
705
+
706
+ rigid_body_api = UsdPhysics.RigidBodyAPI.Apply(prim)
707
+ rigid_body_api.CreateKinematicEnabledAttr().Set(True)
708
+ if prim.GetAttribute("physics:mass"):
709
+ prim.RemoveProperty("physics:mass")
710
+ if prim.GetAttribute("physics:velocity"):
711
+ prim.RemoveProperty("physics:velocity")
712
+
713
+ api_schemas = prim.GetMetadata("apiSchemas")
714
+ if api_schemas is None:
715
+ api_schemas = Sdf.TokenListOp()
716
+
717
+ api_list = list(api_schemas.GetAddedOrExplicitItems())
718
+ for api in self.DEFAULT_BIND_APIS:
719
+ if api not in api_list:
720
+ api_list.append(api)
721
+
722
+ api_schemas.appendedItems = api_list
723
+ prim.SetMetadata("apiSchemas", api_schemas)
724
+
725
+ layer.Save()
726
+ logger.info(f"Successfully converted {usd_path} to {output_file}")
727
+
728
+
729
  class URDFtoUSDConverter(MeshtoUSDConverter):
730
+ """Converts URDF files to USD format.
731
 
732
  Args:
733
+ fix_base (bool, optional): Fix the base link.
734
+ merge_fixed_joints (bool, optional): Merge fixed joints.
735
+ make_instanceable (bool, optional): Make prims instanceable.
736
+ force_usd_conversion (bool, optional): Force conversion to USD.
737
+ collision_from_visuals (bool, optional): Generate collisions from visuals.
738
+ joint_drive (optional): Joint drive configuration.
739
+ rotate_wxyz (tuple[float], optional): Quaternion for rotation.
740
+ simulation_app (optional): Simulation app instance.
741
+ **kwargs: Additional arguments.
742
  """
743
 
744
  def __init__(
 
753
  simulation_app=None,
754
  **kwargs,
755
  ):
756
+ """Initializes the converter.
757
+
758
+ Args:
759
+ fix_base (bool, optional): Fix the base link.
760
+ merge_fixed_joints (bool, optional): Merge fixed joints.
761
+ make_instanceable (bool, optional): Make prims instanceable.
762
+ force_usd_conversion (bool, optional): Force conversion to USD.
763
+ collision_from_visuals (bool, optional): Generate collisions from visuals.
764
+ joint_drive (optional): Joint drive configuration.
765
+ rotate_wxyz (tuple[float], optional): Quaternion for rotation.
766
+ simulation_app (optional): Simulation app instance.
767
+ **kwargs: Additional arguments.
768
+ """
769
  self.usd_parms = dict(
770
  fix_base=fix_base,
771
  merge_fixed_joints=merge_fixed_joints,
 
780
  self.simulation_app = simulation_app
781
 
782
  def convert(self, urdf_path: str, output_file: str):
783
+ """Converts a URDF file to USD and post-processes collision meshes.
784
+
785
+ Args:
786
+ urdf_path (str): Path to URDF file.
787
+ output_file (str): Path to output USD file.
788
+ """
789
  from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
790
  from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
791
 
 
804
  with Usd.EditContext(stage, layer):
805
  for prim in stage.Traverse():
806
  if prim.GetName() == "collisions":
807
+ approx_attr = prim.CreateAttribute(
808
+ "physics:approximation", Sdf.ValueTypeNames.Token
809
+ )
 
 
810
  approx_attr.Set("convexDecomposition")
811
 
812
  physx_conv_api = (
 
814
  prim
815
  )
816
  )
817
+ physx_conv_api.GetMaxConvexHullsAttr().Set(32)
818
+ physx_conv_api.GetHullVertexLimitAttr().Set(16)
819
+ physx_conv_api.GetVoxelResolutionAttr().Set(10000)
820
  physx_conv_api.GetShrinkWrapAttr().Set(True)
821
 
822
  api_schemas = prim.GetMetadata("apiSchemas")
 
847
 
848
 
849
  class AssetConverterFactory:
850
+ """Factory for creating asset converters based on target and source types.
851
+
852
+ Example:
853
+ ```py
854
+ from embodied_gen.data.asset_converter import AssetConverterFactory
855
+ from embodied_gen.utils.enum import AssetType
856
+
857
+ converter = AssetConverterFactory.create(
858
+ target_type=AssetType.USD, source_type=AssetType.MESH
859
+ )
860
+ with converter:
861
+ for urdf_path, output_file in zip(urdf_paths, output_files):
862
+ converter.convert(urdf_path, output_file)
863
+ ```
864
+ """
865
 
866
  @staticmethod
867
  def create(
868
  target_type: AssetType, source_type: AssetType = "urdf", **kwargs
869
  ) -> AssetConverterBase:
870
+ """Creates an asset converter instance.
871
+
872
+ Args:
873
+ target_type (AssetType): Target asset type.
874
+ source_type (AssetType, optional): Source asset type.
875
+ **kwargs: Additional arguments.
876
+
877
+ Returns:
878
+ AssetConverterBase: Converter instance.
879
+ """
880
+ if target_type == AssetType.MJCF and source_type == AssetType.MESH:
881
  converter = MeshtoMJCFConverter(**kwargs)
882
+ elif target_type == AssetType.MJCF and source_type == AssetType.URDF:
883
+ converter = URDFtoMJCFConverter(**kwargs)
884
  elif target_type == AssetType.USD and source_type == AssetType.MESH:
885
  converter = MeshtoUSDConverter(**kwargs)
886
+ elif target_type == AssetType.USD and source_type == AssetType.URDF:
887
+ converter = URDFtoUSDConverter(**kwargs)
888
  else:
889
  raise ValueError(
890
  f"Unsupported converter type: {source_type} -> {target_type}."
 
894
 
895
 
896
  if __name__ == "__main__":
897
+ target_asset_type = AssetType.MJCF
898
  # target_asset_type = AssetType.USD
899
 
900
+ urdf_paths = [
901
+ 'outputs/EmbodiedGenData/demo_assets/banana/result/banana.urdf',
902
+ 'outputs/EmbodiedGenData/demo_assets/book/result/book.urdf',
903
+ 'outputs/EmbodiedGenData/demo_assets/lamp/result/lamp.urdf',
904
+ 'outputs/EmbodiedGenData/demo_assets/mug/result/mug.urdf',
905
+ 'outputs/EmbodiedGenData/demo_assets/remote_control/result/remote_control.urdf',
906
+ "outputs/EmbodiedGenData/demo_assets/rubik's_cube/result/rubik's_cube.urdf",
907
+ 'outputs/EmbodiedGenData/demo_assets/table/result/table.urdf',
908
+ 'outputs/EmbodiedGenData/demo_assets/vase/result/vase.urdf',
909
+ ]
 
 
910
 
911
+ if target_asset_type == AssetType.MJCF:
912
+ output_files = [
913
+ "outputs/embodiedgen_assets/demo_assets/demo_assets/remote_control/mjcf/remote_control.xml",
914
+ ]
915
+ asset_converter = AssetConverterFactory.create(
916
+ target_type=AssetType.MJCF,
917
+ source_type=AssetType.MESH,
918
+ )
919
 
920
+ elif target_asset_type == AssetType.USD:
921
+ output_files = [
922
+ 'outputs/EmbodiedGenData/demo_assets/banana/usd/banana.usd',
923
+ 'outputs/EmbodiedGenData/demo_assets/book/usd/book.usd',
924
+ 'outputs/EmbodiedGenData/demo_assets/lamp/usd/lamp.usd',
925
+ 'outputs/EmbodiedGenData/demo_assets/mug/usd/mug.usd',
926
+ 'outputs/EmbodiedGenData/demo_assets/remote_control/usd/remote_control.usd',
927
+ "outputs/EmbodiedGenData/demo_assets/rubik's_cube/usd/rubik's_cube.usd",
928
+ 'outputs/EmbodiedGenData/demo_assets/table/usd/table.usd',
929
+ 'outputs/EmbodiedGenData/demo_assets/vase/usd/vase.usd',
930
+ ]
931
+ asset_converter = AssetConverterFactory.create(
932
+ target_type=AssetType.USD,
933
+ source_type=AssetType.MESH,
934
+ )
935
+
936
+ with asset_converter:
937
+ for urdf_path, output_file in zip(urdf_paths, output_files):
938
+ asset_converter.convert(urdf_path, output_file)
939
 
940
  # urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
941
  # output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
 
949
  # with asset_converter:
950
  # asset_converter.convert(urdf_path, output_file)
951
 
952
+ # # Convert infinigen urdf to mjcf
953
+ # urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf"
954
+ # output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml"
955
+ # asset_converter = AssetConverterFactory.create(
956
+ # target_type=AssetType.MJCF,
957
+ # source_type=AssetType.URDF,
958
+ # keep_materials=["diffuse"],
959
+ # )
960
+ # with asset_converter:
961
+ # asset_converter.convert(urdf_path, output_file)
962
+
963
+ # # Convert infinigen usdc to physics usdc
964
+ # converter = PhysicsUSDAdder()
965
+ # with converter:
966
+ # converter.convert(
967
+ # usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc",
968
+ # output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc",
969
+ # )
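
The commented-out block above shows how the new `PhysicsUSDAdder` is intended to be driven; a minimal usage sketch with placeholder paths (it assumes an Isaac Sim / `pxr` environment, since the inherited `__enter__` launches a headless simulation app):

```python
# Hypothetical PhysicsUSDAdder usage; paths are placeholders, Isaac Sim is required.
from embodied_gen.data.asset_converter import PhysicsUSDAdder

converter = PhysicsUSDAdder()
with converter:  # launches a (headless by default) simulation app on enter
    converter.convert(
        usd_path="outputs/usdc/export_scene/export_scene.usdc",
        output_file="outputs/usdc_physics/export_scene/export_scene.usdc",
    )
```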
embodied_gen/data/backproject.py CHANGED
@@ -34,6 +34,7 @@ from embodied_gen.data.utils import (
34
  CameraSetting,
35
  get_images_from_grid,
36
  init_kal_camera,
 
37
  normalize_vertices_array,
38
  post_process_texture,
39
  save_mesh_with_mtl,
@@ -306,28 +307,6 @@ class TextureBaker(object):
306
  raise ValueError(f"Unknown mode: {mode}")
307
 
308
 
309
- def kaolin_to_opencv_view(raw_matrix):
310
- R_orig = raw_matrix[:, :3, :3]
311
- t_orig = raw_matrix[:, :3, 3]
312
-
313
- R_target = torch.zeros_like(R_orig)
314
- R_target[:, :, 0] = R_orig[:, :, 2]
315
- R_target[:, :, 1] = R_orig[:, :, 0]
316
- R_target[:, :, 2] = R_orig[:, :, 1]
317
-
318
- t_target = t_orig
319
-
320
- target_matrix = (
321
- torch.eye(4, device=raw_matrix.device)
322
- .unsqueeze(0)
323
- .repeat(raw_matrix.size(0), 1, 1)
324
- )
325
- target_matrix[:, :3, :3] = R_target
326
- target_matrix[:, :3, 3] = t_target
327
-
328
- return target_matrix
329
-
330
-
331
  def parse_args():
332
  parser = argparse.ArgumentParser(description="Render settings")
333
 
 
34
  CameraSetting,
35
  get_images_from_grid,
36
  init_kal_camera,
37
+ kaolin_to_opencv_view,
38
  normalize_vertices_array,
39
  post_process_texture,
40
  save_mesh_with_mtl,
 
307
  raise ValueError(f"Unknown mode: {mode}")
308
 
309
 
310
  def parse_args():
311
  parser = argparse.ArgumentParser(description="Render settings")
312
 
embodied_gen/data/backproject_v2.py CHANGED
@@ -58,7 +58,16 @@ __all__ = [
58
  def _transform_vertices(
59
  mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
60
  ) -> torch.Tensor:
61
- """Transform 3D vertices using a projection matrix."""
62
  t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
63
  if pos.size(-1) == 3:
64
  pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
@@ -71,7 +80,17 @@ def _transform_vertices(
71
  def _bilinear_interpolation_scattering(
72
  image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
73
  ) -> torch.Tensor:
74
- """Bilinear interpolation scattering for grid-based value accumulation."""
75
  device = values.device
76
  dtype = values.dtype
77
  C = values.shape[-1]
@@ -135,7 +154,18 @@ def _texture_inpaint_smooth(
135
  faces: np.ndarray,
136
  uv_map: np.ndarray,
137
  ) -> tuple[np.ndarray, np.ndarray]:
138
- """Perform texture inpainting using vertex-based color propagation."""
139
  image_h, image_w, C = texture.shape
140
  N = vertices.shape[0]
141
 
@@ -231,29 +261,41 @@ def _texture_inpaint_smooth(
231
  class TextureBacker:
232
  """Texture baking pipeline for multi-view projection and fusion.
233
 
234
- This class performs UV-based texture generation for a 3D mesh using
235
- multi-view color images, depth, and normal information. The pipeline
236
- includes mesh normalization and UV unwrapping, visibility-aware
237
- back-projection, confidence-weighted texture fusion, and inpainting
238
- of missing texture regions.
239
 
240
  Args:
241
- camera_params (CameraSetting): Camera intrinsics and extrinsics used
242
- for rendering each view.
243
- view_weights (list[float]): A list of weights for each view, used
244
- to blend confidence maps during texture fusion.
245
- render_wh (tuple[int, int], optional): Resolution (width, height) for
246
- intermediate rendering passes. Defaults to (2048, 2048).
247
- texture_wh (tuple[int, int], optional): Output texture resolution
248
- (width, height). Defaults to (2048, 2048).
249
- bake_angle_thresh (int, optional): Maximum angle (in degrees) between
250
- view direction and surface normal for projection to be considered valid.
251
- Defaults to 75.
252
- mask_thresh (float, optional): Threshold applied to visibility masks
253
- during rendering. Defaults to 0.5.
254
- smooth_texture (bool, optional): If True, apply post-processing (e.g.,
255
- blurring) to the final texture. Defaults to True.
256
- inpaint_smooth (bool, optional): If True, apply inpainting to smooth.
257
  """
258
 
259
  def __init__(
@@ -283,6 +325,12 @@ class TextureBacker:
283
  )
284
 
285
  def _lazy_init_render(self, camera_params, mask_thresh):
286
  if self.renderer is None:
287
  camera = init_kal_camera(camera_params)
288
  mv = camera.view_matrix() # (n 4 4) world2cam
@@ -301,6 +349,14 @@ class TextureBacker:
301
  )
302
 
303
  def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
304
  mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
305
  self.scale, self.center = scale, center
306
 
@@ -318,6 +374,16 @@ class TextureBacker:
318
  scale: float = None,
319
  center: np.ndarray = None,
320
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
321
  vertices = mesh.vertices.copy()
322
  faces = mesh.faces.copy()
323
  uv_map = mesh.visual.uv.copy()
@@ -331,6 +397,14 @@ class TextureBacker:
331
  return vertices, faces, uv_map
332
 
333
  def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
334
  depth_image_np = depth_image.cpu().numpy()
335
  depth_image_np = (depth_image_np * 255).astype(np.uint8)
336
  depth_edges = cv2.Canny(depth_image_np, 30, 80)
@@ -344,6 +418,16 @@ class TextureBacker:
344
  def compute_enhanced_viewnormal(
345
  self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
346
  ) -> torch.Tensor:
347
  rast, _ = self.renderer.compute_dr_raster(vertices, faces)
348
  rendered_view_normals = []
349
  for idx in range(len(mv_mtx)):
@@ -376,6 +460,18 @@ class TextureBacker:
376
  def back_project(
377
  self, image, vis_mask, depth, normal, uv
378
  ) -> tuple[torch.Tensor, torch.Tensor]:
379
  image = np.array(image)
380
  image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
381
  if image.ndim == 2:
@@ -418,6 +514,17 @@ class TextureBacker:
418
  )
419
 
420
  def _scatter_texture(self, uv, data, mask):
421
  def __filter_data(data, mask):
422
  return data.view(-1, data.shape[-1])[mask]
423
 
@@ -432,6 +539,15 @@ class TextureBacker:
432
  def fast_bake_texture(
433
  self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
434
  ) -> tuple[torch.Tensor, torch.Tensor]:
435
  channel = textures[0].shape[-1]
436
  texture_merge = torch.zeros(self.texture_wh + [channel]).to(
437
  self.device
@@ -451,6 +567,16 @@ class TextureBacker:
451
  def uv_inpaint(
452
  self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
453
  ) -> np.ndarray:
454
  if self.inpaint_smooth:
455
  vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
456
  texture, mask = _texture_inpaint_smooth(
@@ -473,6 +599,15 @@ class TextureBacker:
473
  colors: list[Image.Image],
474
  mesh: trimesh.Trimesh,
475
  ) -> trimesh.Trimesh:
476
  self._lazy_init_render(self.camera_params, self.mask_thresh)
477
 
478
  vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
@@ -517,7 +652,7 @@ class TextureBacker:
517
  Args:
518
  colors (list[Image.Image]): List of input view images.
519
  mesh (trimesh.Trimesh): Input mesh to be textured.
520
- output_path (str): Path to save the output textured mesh (.obj or .glb).
521
 
522
  Returns:
523
  trimesh.Trimesh: The textured mesh with UV and texture image.
@@ -540,6 +675,11 @@ class TextureBacker:
540
 
541
 
542
  def parse_args():
543
  parser = argparse.ArgumentParser(description="Backproject texture")
544
  parser.add_argument(
545
  "--color_path",
@@ -636,6 +776,16 @@ def entrypoint(
636
  imagesr_model: ImageRealESRGAN = None,
637
  **kwargs,
638
  ) -> trimesh.Trimesh:
639
  args = parse_args()
640
  for k, v in kwargs.items():
641
  if hasattr(args, k) and v is not None:
 
58
  def _transform_vertices(
59
  mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
60
  ) -> torch.Tensor:
61
+ """Transforms 3D vertices using a projection matrix.
62
+
63
+ Args:
64
+ mtx (torch.Tensor): Projection matrix.
65
+ pos (torch.Tensor): Vertex positions.
66
+ keepdim (bool, optional): If True, keeps the batch dimension.
67
+
68
+ Returns:
69
+ torch.Tensor: Transformed vertices.
70
+ """
71
  t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
72
  if pos.size(-1) == 3:
73
  pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
 
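The new `_transform_vertices` docstring above is terse, so here is a minimal, self-contained sketch of the homogeneous-coordinate convention it documents. The row-vector multiplication order is an assumption for illustration; the hunk does not show the actual matmul inside `_transform_vertices`.

```python
import torch

def transform_vertices_sketch(mtx: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Apply a 4x4 projection/view matrix to (N, 3) vertices (illustrative only)."""
    # Promote (N, 3) positions to homogeneous (N, 4) by appending a 1, as above.
    if pos.size(-1) == 3:
        pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
    # Assumed row-vector convention: transformed = pos @ mtx^T.
    return pos @ mtx.T

verts = torch.rand(10, 3)
proj = torch.eye(4)  # identity stand-in for a real projection matrix
print(transform_vertices_sketch(proj, verts).shape)  # torch.Size([10, 4])
```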
80
  def _bilinear_interpolation_scattering(
81
  image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
82
  ) -> torch.Tensor:
83
+ """Performs bilinear interpolation scattering for grid-based value accumulation.
84
+
85
+ Args:
86
+ image_h (int): Image height.
87
+ image_w (int): Image width.
88
+ coords (torch.Tensor): Normalized coordinates.
89
+ values (torch.Tensor): Values to scatter.
90
+
91
+ Returns:
92
+ torch.Tensor: Interpolated grid.
93
+ """
94
  device = values.device
95
  dtype = values.dtype
96
  C = values.shape[-1]
 
154
  faces: np.ndarray,
155
  uv_map: np.ndarray,
156
  ) -> tuple[np.ndarray, np.ndarray]:
157
+ """Performs texture inpainting using vertex-based color propagation.
158
+
159
+ Args:
160
+ texture (np.ndarray): Texture image.
161
+ mask (np.ndarray): Mask image.
162
+ vertices (np.ndarray): Mesh vertices.
163
+ faces (np.ndarray): Mesh faces.
164
+ uv_map (np.ndarray): UV coordinates.
165
+
166
+ Returns:
167
+ tuple[np.ndarray, np.ndarray]: Inpainted texture and updated mask.
168
+ """
169
  image_h, image_w, C = texture.shape
170
  N = vertices.shape[0]
171
 
 
261
  class TextureBacker:
262
  """Texture baking pipeline for multi-view projection and fusion.
263
 
264
+ This class generates UV-based textures for a 3D mesh using multi-view images,
265
+ depth, and normal information. It includes mesh normalization, UV unwrapping,
266
+ visibility-aware back-projection, confidence-weighted fusion, and inpainting.
 
 
267
 
268
  Args:
269
+ camera_params (CameraSetting): Camera intrinsics and extrinsics.
270
+ view_weights (list[float]): Weights for each view in texture fusion.
271
+ render_wh (tuple[int, int], optional): Intermediate rendering resolution.
272
+ texture_wh (tuple[int, int], optional): Output texture resolution.
273
+ bake_angle_thresh (int, optional): Max angle for valid projection.
274
+ mask_thresh (float, optional): Threshold for visibility masks.
275
+ smooth_texture (bool, optional): Apply post-processing to texture.
276
+ inpaint_smooth (bool, optional): Apply inpainting smoothing.
277
+
278
+ Example:
279
+ ```py
280
+ from embodied_gen.data.backproject_v2 import TextureBacker
281
+ from embodied_gen.data.utils import CameraSetting
282
+ import trimesh
283
+ from PIL import Image
284
+
285
+ camera_params = CameraSetting(
286
+ num_images=6,
287
+ elevation=[20, -10],
288
+ distance=5,
289
+ resolution_hw=(2048,2048),
290
+ fov=math.radians(30),
291
+ device='cuda',
292
+ )
293
+ view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
294
+ mesh = trimesh.load('mesh.obj')
295
+ images = [Image.open(f'view_{i}.png') for i in range(6)]
296
+ texture_backer = TextureBacker(camera_params, view_weights)
297
+ textured_mesh = texture_backer(images, mesh, 'output.obj')
298
+ ```
299
  """
300
 
301
  def __init__(
 
325
  )
326
 
327
  def _lazy_init_render(self, camera_params, mask_thresh):
328
+ """Lazily initializes the renderer.
329
+
330
+ Args:
331
+ camera_params (CameraSetting): Camera settings.
332
+ mask_thresh (float): Mask threshold.
333
+ """
334
  if self.renderer is None:
335
  camera = init_kal_camera(camera_params)
336
  mv = camera.view_matrix() # (n 4 4) world2cam
 
349
  )
350
 
351
  def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
352
+ """Normalizes mesh and unwraps UVs.
353
+
354
+ Args:
355
+ mesh (trimesh.Trimesh): Input mesh.
356
+
357
+ Returns:
358
+ trimesh.Trimesh: Mesh with normalized vertices and UVs.
359
+ """
360
  mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
361
  self.scale, self.center = scale, center
362
 
 
374
  scale: float = None,
375
  center: np.ndarray = None,
376
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
377
+ """Gets mesh attributes as numpy arrays.
378
+
379
+ Args:
380
+ mesh (trimesh.Trimesh): Input mesh.
381
+ scale (float, optional): Scale factor.
382
+ center (np.ndarray, optional): Center offset.
383
+
384
+ Returns:
385
+ tuple: (vertices, faces, uv_map)
386
+ """
387
  vertices = mesh.vertices.copy()
388
  faces = mesh.faces.copy()
389
  uv_map = mesh.visual.uv.copy()
 
397
  return vertices, faces, uv_map
398
 
399
  def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
400
+ """Computes edge image from depth map.
401
+
402
+ Args:
403
+ depth_image (torch.Tensor): Depth map.
404
+
405
+ Returns:
406
+ torch.Tensor: Edge image.
407
+ """
408
  depth_image_np = depth_image.cpu().numpy()
409
  depth_image_np = (depth_image_np * 255).astype(np.uint8)
410
  depth_edges = cv2.Canny(depth_image_np, 30, 80)
 
418
  def compute_enhanced_viewnormal(
419
  self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
420
  ) -> torch.Tensor:
421
+ """Computes enhanced view normals for mesh faces.
422
+
423
+ Args:
424
+ mv_mtx (torch.Tensor): View matrices.
425
+ vertices (torch.Tensor): Mesh vertices.
426
+ faces (torch.Tensor): Mesh faces.
427
+
428
+ Returns:
429
+ torch.Tensor: View normals.
430
+ """
431
  rast, _ = self.renderer.compute_dr_raster(vertices, faces)
432
  rendered_view_normals = []
433
  for idx in range(len(mv_mtx)):
 
460
  def back_project(
461
  self, image, vis_mask, depth, normal, uv
462
  ) -> tuple[torch.Tensor, torch.Tensor]:
463
+ """Back-projects image and confidence to UV texture space.
464
+
465
+ Args:
466
+ image (PIL.Image or np.ndarray): Input image.
467
+ vis_mask (torch.Tensor): Visibility mask.
468
+ depth (torch.Tensor): Depth map.
469
+ normal (torch.Tensor): Normal map.
470
+ uv (torch.Tensor): UV coordinates.
471
+
472
+ Returns:
473
+ tuple[torch.Tensor, torch.Tensor]: Texture and confidence map.
474
+ """
475
  image = np.array(image)
476
  image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
477
  if image.ndim == 2:
 
514
  )
515
 
516
  def _scatter_texture(self, uv, data, mask):
517
+ """Scatters data to texture using UV coordinates and mask.
518
+
519
+ Args:
520
+ uv (torch.Tensor): UV coordinates.
521
+ data (torch.Tensor): Data to scatter.
522
+ mask (torch.Tensor): Mask for valid pixels.
523
+
524
+ Returns:
525
+ torch.Tensor: Scattered texture.
526
+ """
527
+
528
  def __filter_data(data, mask):
529
  return data.view(-1, data.shape[-1])[mask]
530
 
 
539
  def fast_bake_texture(
540
  self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
541
  ) -> tuple[torch.Tensor, torch.Tensor]:
542
+ """Fuses multiple textures and confidence maps.
543
+
544
+ Args:
545
+ textures (list[torch.Tensor]): List of textures.
546
+ confidence_maps (list[torch.Tensor]): List of confidence maps.
547
+
548
+ Returns:
549
+ tuple[torch.Tensor, torch.Tensor]: Fused texture and mask.
550
+ """
551
  channel = textures[0].shape[-1]
552
  texture_merge = torch.zeros(self.texture_wh + [channel]).to(
553
  self.device
 
567
  def uv_inpaint(
568
  self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
569
  ) -> np.ndarray:
570
+ """Inpaints missing regions in the UV texture.
571
+
572
+ Args:
573
+ mesh (trimesh.Trimesh): Mesh.
574
+ texture (np.ndarray): Texture image.
575
+ mask (np.ndarray): Mask image.
576
+
577
+ Returns:
578
+ np.ndarray: Inpainted texture.
579
+ """
580
  if self.inpaint_smooth:
581
  vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
582
  texture, mask = _texture_inpaint_smooth(
 
599
  colors: list[Image.Image],
600
  mesh: trimesh.Trimesh,
601
  ) -> trimesh.Trimesh:
602
+ """Computes the fused texture for the mesh from multi-view images.
603
+
604
+ Args:
605
+ colors (list[Image.Image]): List of view images.
606
+ mesh (trimesh.Trimesh): Mesh to texture.
607
+
608
+ Returns:
609
+ tuple[np.ndarray, np.ndarray]: Texture and mask.
610
+ """
611
  self._lazy_init_render(self.camera_params, self.mask_thresh)
612
 
613
  vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
 
652
  Args:
653
  colors (list[Image.Image]): List of input view images.
654
  mesh (trimesh.Trimesh): Input mesh to be textured.
655
+ output_path (str): Path to save the output textured mesh.
656
 
657
  Returns:
658
  trimesh.Trimesh: The textured mesh with UV and texture image.
 
675
 
676
 
677
  def parse_args():
678
+ """Parses command-line arguments for texture backprojection.
679
+
680
+ Returns:
681
+ argparse.Namespace: Parsed arguments.
682
+ """
683
  parser = argparse.ArgumentParser(description="Backproject texture")
684
  parser.add_argument(
685
  "--color_path",
 
776
  imagesr_model: ImageRealESRGAN = None,
777
  **kwargs,
778
  ) -> trimesh.Trimesh:
779
+ """Entrypoint for texture backprojection from multi-view images.
780
+
781
+ Args:
782
+ delight_model (DelightingModel, optional): Delighting model.
783
+ imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
784
+ **kwargs: Additional arguments to override CLI.
785
+
786
+ Returns:
787
+ trimesh.Trimesh: Textured mesh.
788
+ """
789
  args = parse_args()
790
  for k, v in kwargs.items():
791
  if hasattr(args, k) and v is not None:
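For readers of this file's diff, `entrypoint` above overrides parsed CLI defaults with any matching keyword arguments. A minimal sketch of that pattern, assuming `backproject_v2` also exposes `--mesh_path` and `--output_path` flags alongside the `--color_path` flag visible in this hunk (those two names are assumptions):

```python
from embodied_gen.data.backproject_v2 import entrypoint

# Keywords whose names match CLI arguments replace the parsed defaults,
# because entrypoint() runs setattr(args, k, v) for every known key.
textured_mesh = entrypoint(
    color_path="outputs/color_grid.png",      # multi-view color grid (flag shown in this diff)
    mesh_path="outputs/raw_mesh.obj",         # assumed flag name
    output_path="outputs/textured_mesh.obj",  # assumed flag name
)
print(textured_mesh)
```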
embodied_gen/data/backproject_v3.py ADDED
@@ -0,0 +1,557 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import logging
20
+ import math
+ import os
21
+ from typing import Literal, Union
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import nvdiffrast.torch as dr
26
+ import spaces
27
+ import torch
28
+ import trimesh
29
+ import utils3d
30
+ import xatlas
31
+ from PIL import Image
32
+ from tqdm import tqdm
33
+ from embodied_gen.data.mesh_operator import MeshFixer
34
+ from embodied_gen.data.utils import (
35
+ CameraSetting,
36
+ init_kal_camera,
37
+ kaolin_to_opencv_view,
38
+ normalize_vertices_array,
39
+ post_process_texture,
40
+ save_mesh_with_mtl,
41
+ )
42
+ from embodied_gen.models.delight_model import DelightingModel
43
+ from embodied_gen.models.gs_model import load_gs_model
44
+ from embodied_gen.models.sr_model import ImageRealESRGAN
45
+
46
+ logging.basicConfig(
47
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
48
+ )
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ __all__ = [
53
+ "TextureBaker",
54
+ ]
55
+
56
+
57
+ class TextureBaker(object):
58
+ """Baking textures onto a mesh from multiple observations.
59
+
60
+ This class take 3D mesh data, camera settings and texture baking parameters
61
+ to generate texture map by projecting images to the mesh from diff views.
62
+ It supports both a fast texture baking approach and a more optimized method
63
+ with total variation regularization.
64
+
65
+ Attributes:
66
+ vertices (torch.Tensor): The vertices of the mesh.
67
+ faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
68
+ uvs (torch.Tensor): The UV coordinates of the mesh.
69
+ camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
70
+ device (str): The device to run computations on ("cpu" or "cuda").
71
+ w2cs (torch.Tensor): World-to-camera transformation matrices.
72
+ projections (torch.Tensor): Camera projection matrices.
73
+
74
+ Example:
75
+ >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa
76
+ >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
77
+ >>> images = get_images_from_grid(args.color_path, image_size)
78
+ >>> texture = texture_backer.bake_texture(
79
+ ... images, texture_size=args.texture_size, mode=args.baker_mode
80
+ ... )
81
+ >>> texture = post_process_texture(texture)
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ vertices: np.ndarray,
87
+ faces: np.ndarray,
88
+ uvs: np.ndarray,
89
+ camera_params: CameraSetting,
90
+ device: str = "cuda",
91
+ ) -> None:
92
+ self.vertices = (
93
+ torch.tensor(vertices, device=device)
94
+ if isinstance(vertices, np.ndarray)
95
+ else vertices.to(device)
96
+ )
97
+ self.faces = (
98
+ torch.tensor(faces.astype(np.int32), device=device)
99
+ if isinstance(faces, np.ndarray)
100
+ else faces.to(device)
101
+ )
102
+ self.uvs = (
103
+ torch.tensor(uvs, device=device)
104
+ if isinstance(uvs, np.ndarray)
105
+ else uvs.to(device)
106
+ )
107
+ self.camera_params = camera_params
108
+ self.device = device
109
+
110
+ camera = init_kal_camera(camera_params)
111
+ matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
112
+ matrix_mv = kaolin_to_opencv_view(matrix_mv)
113
+ matrix_p = (
114
+ camera.intrinsics.projection_matrix()
115
+ ) # (n_cam 4 4) cam2pixel
116
+ self.w2cs = matrix_mv.to(self.device)
117
+ self.projections = matrix_p.to(self.device)
118
+
119
+ @staticmethod
120
+ def parametrize_mesh(
121
+ vertices: np.array, faces: np.array
122
+ ) -> Union[np.array, np.array, np.array]:
123
+ vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
124
+
125
+ vertices = vertices[vmapping]
126
+ faces = indices
127
+
128
+ return vertices, faces, uvs
129
+
130
+ def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
131
+ texture = torch.zeros(
132
+ (texture_size * texture_size, 3), dtype=torch.float32
133
+ ).cuda()
134
+ texture_weights = torch.zeros(
135
+ (texture_size * texture_size), dtype=torch.float32
136
+ ).cuda()
137
+ rastctx = utils3d.torch.RastContext(backend="cuda")
138
+ for observation, w2c, projection in tqdm(
139
+ zip(observations, w2cs, projections),
140
+ total=len(observations),
141
+ desc="Texture baking (fast)",
142
+ ):
143
+ with torch.no_grad():
144
+ rast = utils3d.torch.rasterize_triangle_faces(
145
+ rastctx,
146
+ self.vertices[None],
147
+ self.faces,
148
+ observation.shape[1],
149
+ observation.shape[0],
150
+ uv=self.uvs[None],
151
+ view=w2c,
152
+ projection=projection,
153
+ )
154
+ uv_map = rast["uv"][0].detach().flip(0)
155
+ mask = rast["mask"][0].detach().bool() & masks[0]
156
+
157
+ # nearest neighbor interpolation
158
+ uv_map = (uv_map * texture_size).floor().long()
159
+ obs = observation[mask]
160
+ uv_map = uv_map[mask]
161
+ idx = (
162
+ uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
163
+ )
164
+ texture = texture.scatter_add(
165
+ 0, idx.view(-1, 1).expand(-1, 3), obs
166
+ )
167
+ texture_weights = texture_weights.scatter_add(
168
+ 0,
169
+ idx,
170
+ torch.ones(
171
+ (obs.shape[0]), dtype=torch.float32, device=texture.device
172
+ ),
173
+ )
174
+
175
+ mask = texture_weights > 0
176
+ texture[mask] /= texture_weights[mask][:, None]
177
+ texture = np.clip(
178
+ texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
179
+ 0,
180
+ 255,
181
+ ).astype(np.uint8)
182
+
183
+ # inpaint
184
+ mask = (
185
+ (texture_weights == 0)
186
+ .cpu()
187
+ .numpy()
188
+ .astype(np.uint8)
189
+ .reshape(texture_size, texture_size)
190
+ )
191
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
192
+
193
+ return texture
194
+
195
+ def _bake_opt(
196
+ self,
197
+ observations,
198
+ w2cs,
199
+ projections,
200
+ texture_size,
201
+ lambda_tv,
202
+ masks,
203
+ total_steps,
204
+ ):
205
+ rastctx = utils3d.torch.RastContext(backend="cuda")
206
+ observations = [observations.flip(0) for observations in observations]
207
+ masks = [m.flip(0) for m in masks]
208
+ _uv = []
209
+ _uv_dr = []
210
+ for observation, w2c, projection in tqdm(
211
+ zip(observations, w2cs, projections),
212
+ total=len(w2cs),
213
+ ):
214
+ with torch.no_grad():
215
+ rast = utils3d.torch.rasterize_triangle_faces(
216
+ rastctx,
217
+ self.vertices[None],
218
+ self.faces,
219
+ observation.shape[1],
220
+ observation.shape[0],
221
+ uv=self.uvs[None],
222
+ view=w2c,
223
+ projection=projection,
224
+ )
225
+ _uv.append(rast["uv"].detach())
226
+ _uv_dr.append(rast["uv_dr"].detach())
227
+
228
+ texture = torch.nn.Parameter(
229
+ torch.zeros(
230
+ (1, texture_size, texture_size, 3), dtype=torch.float32
231
+ ).cuda()
232
+ )
233
+ optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
234
+
235
+ def cosine_anealing(step, total_steps, start_lr, end_lr):
236
+ return end_lr + 0.5 * (start_lr - end_lr) * (
237
+ 1 + np.cos(np.pi * step / total_steps)
238
+ )
239
+
240
+ def tv_loss(texture):
241
+ return torch.nn.functional.l1_loss(
242
+ texture[:, :-1, :, :], texture[:, 1:, :, :]
243
+ ) + torch.nn.functional.l1_loss(
244
+ texture[:, :, :-1, :], texture[:, :, 1:, :]
245
+ )
246
+
247
+ with tqdm(total=total_steps, desc="Texture baking") as pbar:
248
+ for step in range(total_steps):
249
+ optimizer.zero_grad()
250
+ selected = np.random.randint(0, len(w2cs))
251
+ uv, uv_dr, observation, mask = (
252
+ _uv[selected],
253
+ _uv_dr[selected],
254
+ observations[selected],
255
+ masks[selected],
256
+ )
257
+ render = dr.texture(texture, uv, uv_dr)[0]
258
+ loss = torch.nn.functional.l1_loss(
259
+ render[mask], observation[mask]
260
+ )
261
+ if lambda_tv > 0:
262
+ loss += lambda_tv * tv_loss(texture)
263
+ loss.backward()
264
+ optimizer.step()
265
+
266
+ optimizer.param_groups[0]["lr"] = cosine_anealing(
267
+ step, total_steps, 1e-2, 1e-5
268
+ )
269
+ pbar.set_postfix({"loss": loss.item()})
270
+ pbar.update()
271
+
272
+ texture = np.clip(
273
+ texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
274
+ ).astype(np.uint8)
275
+ mask = 1 - utils3d.torch.rasterize_triangle_faces(
276
+ rastctx,
277
+ (self.uvs * 2 - 1)[None],
278
+ self.faces,
279
+ texture_size,
280
+ texture_size,
281
+ )["mask"][0].detach().cpu().numpy().astype(np.uint8)
282
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
283
+
284
+ return texture
285
+
286
+ def bake_texture(
287
+ self,
288
+ images: list[np.array],
289
+ texture_size: int = 1024,
290
+ mode: Literal["fast", "opt"] = "opt",
291
+ lambda_tv: float = 1e-2,
292
+ opt_step: int = 2000,
293
+ ):
294
+ masks = [np.any(img > 0, axis=-1) for img in images]
295
+ masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
296
+ images = [
297
+ torch.tensor(obs / 255.0).float().to(self.device) for obs in images
298
+ ]
299
+
300
+ if mode == "fast":
301
+ return self._bake_fast(
302
+ images, self.w2cs, self.projections, texture_size, masks
303
+ )
304
+ elif mode == "opt":
305
+ return self._bake_opt(
306
+ images,
307
+ self.w2cs,
308
+ self.projections,
309
+ texture_size,
310
+ lambda_tv,
311
+ masks,
312
+ opt_step,
313
+ )
314
+ else:
315
+ raise ValueError(f"Unknown mode: {mode}")
316
+
317
+
318
+ def parse_args():
319
+ """Parses command-line arguments for texture backprojection.
320
+
321
+ Returns:
322
+ argparse.Namespace: Parsed arguments.
323
+ """
324
+ parser = argparse.ArgumentParser(description="Backproject texture")
325
+ parser.add_argument(
326
+ "--gs_path",
327
+ type=str,
328
+ help="Path to the GS.ply gaussian splatting model",
329
+ )
330
+ parser.add_argument(
331
+ "--mesh_path",
332
+ type=str,
333
+ help="Mesh path, .obj, .glb or .ply",
334
+ )
335
+ parser.add_argument(
336
+ "--output_path",
337
+ type=str,
338
+ help="Output mesh path with suffix",
339
+ )
340
+ parser.add_argument(
341
+ "--num_images",
342
+ type=int,
343
+ default=180,
344
+ help="Number of images to render.",
345
+ )
346
+ parser.add_argument(
347
+ "--elevation",
348
+ nargs="+",
349
+ type=float,
350
+ default=list(range(85, -90, -10)),
351
+ help="Elevation angles for the camera",
352
+ )
353
+ parser.add_argument(
354
+ "--distance",
355
+ type=float,
356
+ default=5,
357
+ help="Camera distance (default: 5)",
358
+ )
359
+ parser.add_argument(
360
+ "--resolution_hw",
361
+ type=int,
362
+ nargs=2,
363
+ default=(512, 512),
364
+ help="Resolution of the render images (default: (512, 512))",
365
+ )
366
+ parser.add_argument(
367
+ "--fov",
368
+ type=float,
369
+ default=30,
370
+ help="Field of view in degrees (default: 30)",
371
+ )
372
+ parser.add_argument(
373
+ "--device",
374
+ type=str,
375
+ choices=["cpu", "cuda"],
376
+ default="cuda",
377
+ help="Device to run on (default: `cuda`)",
378
+ )
379
+ parser.add_argument(
381
+ "--skip_fix_mesh", action="store_true", help="Skip fixing mesh geometry."
381
+ )
382
+ parser.add_argument(
383
+ "--texture_size",
384
+ type=int,
385
+ default=2048,
386
+ help="Texture size for texture baking (default: 1024)",
387
+ )
388
+ parser.add_argument(
389
+ "--baker_mode",
390
+ type=str,
391
+ default="opt",
392
+ help="Texture baking mode, `fast` or `opt` (default: opt)",
393
+ )
394
+ parser.add_argument(
395
+ "--opt_step",
396
+ type=int,
397
+ default=3000,
398
+ help="Optimization steps for texture baking (default: 3000)",
399
+ )
400
+ parser.add_argument(
401
+ "--mesh_sipmlify_ratio",
402
+ type=float,
403
+ default=0.9,
404
+ help="Mesh simplification ratio (default: 0.9)",
405
+ )
406
+ parser.add_argument(
407
+ "--delight", action="store_true", help="Use delighting model."
408
+ )
409
+ parser.add_argument(
410
+ "--no_smooth_texture",
411
+ action="store_true",
412
+ help="Do not smooth the texture.",
413
+ )
414
+ parser.add_argument(
415
+ "--no_coor_trans",
416
+ action="store_true",
417
+ help="Do not transform the asset coordinate system.",
418
+ )
419
+ parser.add_argument(
420
+ "--save_glb_path", type=str, default=None, help="Save glb path."
421
+ )
422
+ parser.add_argument("--n_max_faces", type=int, default=30000)
423
+ args, unknown = parser.parse_known_args()
424
+
425
+ return args
426
+
427
+
428
+ def entrypoint(
429
+ delight_model: DelightingModel = None,
430
+ imagesr_model: ImageRealESRGAN = None,
431
+ **kwargs,
432
+ ) -> trimesh.Trimesh:
433
+ """Entrypoint for texture backprojection from multi-view images.
434
+
435
+ Args:
436
+ delight_model (DelightingModel, optional): Delighting model.
437
+ imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
438
+ **kwargs: Additional arguments to override CLI.
439
+
440
+ Returns:
441
+ trimesh.Trimesh: Textured mesh.
442
+ """
443
+ args = parse_args()
444
+ for k, v in kwargs.items():
445
+ if hasattr(args, k) and v is not None:
446
+ setattr(args, k, v)
447
+
448
+ # Setup camera parameters.
449
+ camera_params = CameraSetting(
450
+ num_images=args.num_images,
451
+ elevation=args.elevation,
452
+ distance=args.distance,
453
+ resolution_hw=args.resolution_hw,
454
+ fov=math.radians(args.fov),
455
+ device=args.device,
456
+ )
457
+
458
+ # GS render.
459
+ camera = init_kal_camera(camera_params, flip_az=True)
460
+ matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
461
+ matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
462
+ w2cs = matrix_mv.to(camera_params.device)
463
+ c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
464
+ Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
465
+ gs_model = load_gs_model(args.gs_path, pre_quat=[0.0, 0.0, 1.0, 0.0])
466
+ multiviews = []
467
+ for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
468
+ result = gs_model.render(
469
+ c2ws[idx],
470
+ Ks=Ks,
471
+ image_width=camera_params.resolution_hw[1],
472
+ image_height=camera_params.resolution_hw[0],
473
+ )
474
+ color = cv2.cvtColor(result.rgba, cv2.COLOR_BGRA2RGBA)
475
+ multiviews.append(Image.fromarray(color))
476
+
477
+ if args.delight and delight_model is None:
478
+ delight_model = DelightingModel()
479
+
480
+ if args.delight:
481
+ for idx in range(len(multiviews)):
482
+ multiviews[idx] = delight_model(multiviews[idx])
483
+
484
+ multiviews = [img.convert("RGB") for img in multiviews]
485
+
486
+ mesh = trimesh.load(args.mesh_path)
487
+ if isinstance(mesh, trimesh.Scene):
488
+ mesh = mesh.dump(concatenate=True)
489
+
490
+ vertices, scale, center = normalize_vertices_array(mesh.vertices)
491
+
492
+ # Transform mesh coordinate system by default.
493
+ if not args.no_coor_trans:
494
+ x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
495
+ z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
496
+ vertices = vertices @ x_rot
497
+ vertices = vertices @ z_rot
498
+
499
+ faces = mesh.faces.astype(np.int32)
500
+ vertices = vertices.astype(np.float32)
501
+
502
+ if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
503
+ mesh_fixer = MeshFixer(vertices, faces, args.device)
504
+ vertices, faces = mesh_fixer(
505
+ filter_ratio=args.mesh_sipmlify_ratio,
506
+ max_hole_size=0.04,
507
+ resolution=1024,
508
+ num_views=1000,
509
+ norm_mesh_ratio=0.5,
510
+ )
511
+ if len(faces) > args.n_max_faces:
512
+ mesh_fixer = MeshFixer(vertices, faces, args.device)
513
+ vertices, faces = mesh_fixer(
514
+ filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
515
+ max_hole_size=0.04,
516
+ resolution=1024,
517
+ num_views=1000,
518
+ norm_mesh_ratio=0.5,
519
+ )
520
+
521
+ vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
522
+ texture_backer = TextureBaker(
523
+ vertices,
524
+ faces,
525
+ uvs,
526
+ camera_params,
527
+ )
528
+
529
+ multiviews = [np.array(img) for img in multiviews]
530
+ texture = texture_backer.bake_texture(
531
+ images=[img[..., :3] for img in multiviews],
532
+ texture_size=args.texture_size,
533
+ mode=args.baker_mode,
534
+ opt_step=args.opt_step,
535
+ )
536
+ if not args.no_smooth_texture:
537
+ texture = post_process_texture(texture)
538
+
539
+ # Recover mesh original orientation, scale and center.
540
+ if not args.no_coor_trans:
541
+ vertices = vertices @ np.linalg.inv(z_rot)
542
+ vertices = vertices @ np.linalg.inv(x_rot)
543
+ vertices = vertices / scale
544
+ vertices = vertices + center
545
+
546
+ textured_mesh = save_mesh_with_mtl(
547
+ vertices, faces, uvs, texture, args.output_path
548
+ )
549
+ if args.save_glb_path is not None:
550
+ os.makedirs(os.path.dirname(args.save_glb_path), exist_ok=True)
551
+ textured_mesh.export(args.save_glb_path)
552
+
553
+ return textured_mesh
554
+
555
+
556
+ if __name__ == "__main__":
557
+ entrypoint()
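As a complement to the new `backproject_v3` module, the sketch below mirrors the usage in the `TextureBaker` docstring and `entrypoint()`: UV-parametrize a mesh, bake a texture from pre-rendered views, and save the result. File names are placeholders, and the 180 input views are assumed to follow the same camera trajectory defined by `camera_params` (as produced by the GS rendering loop in `entrypoint()`).

```python
import math

import numpy as np
import trimesh
from PIL import Image

from embodied_gen.data.backproject_v3 import TextureBaker
from embodied_gen.data.utils import (
    CameraSetting,
    post_process_texture,
    save_mesh_with_mtl,
)

# Camera trajectory matching the CLI defaults of backproject_v3.
camera_params = CameraSetting(
    num_images=180,
    elevation=list(range(85, -90, -10)),
    distance=5,
    resolution_hw=(512, 512),
    fov=math.radians(30),
    device="cuda",
)

mesh = trimesh.load("mesh.obj", force="mesh")
vertices, faces, uvs = TextureBaker.parametrize_mesh(
    mesh.vertices.astype(np.float32), mesh.faces.astype(np.int32)
)
texture_baker = TextureBaker(vertices, faces, uvs, camera_params)

# Placeholder view images; they must match the camera trajectory above.
images = [
    np.array(Image.open(f"view_{i}.png").convert("RGB"))
    for i in range(camera_params.num_images)
]
texture = texture_baker.bake_texture(
    images, texture_size=2048, mode="opt", opt_step=3000
)
texture = post_process_texture(texture)
save_mesh_with_mtl(vertices, faces, uvs, texture, "textured_mesh.obj")
```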
embodied_gen/data/convex_decomposer.py CHANGED
@@ -39,6 +39,22 @@ def decompose_convex_coacd(
39
  auto_scale: bool = True,
40
  scale_factor: float = 1.0,
41
  ) -> None:
42
  coacd.set_log_level("info" if verbose else "warn")
43
 
44
  mesh = trimesh.load(filename, force="mesh")
@@ -83,7 +99,38 @@ def decompose_convex_mesh(
83
  scale_factor: float = 1.005,
84
  verbose: bool = False,
85
  ) -> str:
86
- """Decompose a mesh into convex parts using the CoACD algorithm."""
87
  coacd.set_log_level("info" if verbose else "warn")
88
 
89
  if os.path.exists(outfile):
@@ -148,9 +195,37 @@ def decompose_convex_mp(
148
  verbose: bool = False,
149
  auto_scale: bool = True,
150
  ) -> str:
151
- """Decompose a mesh into convex parts using the CoACD algorithm in a separate process.
152
 
153
  See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
154
  """
155
  params = dict(
156
  threshold=threshold,
 
39
  auto_scale: bool = True,
40
  scale_factor: float = 1.0,
41
  ) -> None:
42
+ """Decomposes a mesh using CoACD and saves the result.
43
+
44
+ This function loads a mesh from a file, runs the CoACD algorithm with the
45
+ given parameters, optionally scales the resulting convex hulls to match the
46
+ original mesh's bounding box, and exports the combined result to a file.
47
+
48
+ Args:
49
+ filename: Path to the input mesh file.
50
+ outfile: Path to save the decomposed output mesh.
51
+ params: A dictionary of parameters for the CoACD algorithm.
52
+ verbose: If True, sets the CoACD log level to 'info'.
53
+ auto_scale: If True, automatically computes a scale factor to match the
54
+ decomposed mesh's bounding box to the visual mesh's bounding box.
55
+ scale_factor: An additional scaling factor applied to the vertices of
56
+ the decomposed mesh parts.
57
+ """
58
  coacd.set_log_level("info" if verbose else "warn")
59
 
60
  mesh = trimesh.load(filename, force="mesh")
 
99
  scale_factor: float = 1.005,
100
  verbose: bool = False,
101
  ) -> str:
102
+ """Decomposes a mesh into convex parts with retry logic.
103
+
104
+ This function serves as a wrapper for `decompose_convex_coacd`, providing
105
+ explicit parameters for the CoACD algorithm and implementing a retry
106
+ mechanism. If the initial decomposition fails, it attempts again with
107
+ `preprocess_mode` set to 'on'.
108
+
109
+ Args:
110
+ filename: Path to the input mesh file.
111
+ outfile: Path to save the decomposed output mesh.
112
+ threshold: CoACD parameter. See CoACD documentation for details.
113
+ max_convex_hull: CoACD parameter. See CoACD documentation for details.
114
+ preprocess_mode: CoACD parameter. See CoACD documentation for details.
115
+ preprocess_resolution: CoACD parameter. See CoACD documentation for details.
116
+ resolution: CoACD parameter. See CoACD documentation for details.
117
+ mcts_nodes: CoACD parameter. See CoACD documentation for details.
118
+ mcts_iterations: CoACD parameter. See CoACD documentation for details.
119
+ mcts_max_depth: CoACD parameter. See CoACD documentation for details.
120
+ pca: CoACD parameter. See CoACD documentation for details.
121
+ merge: CoACD parameter. See CoACD documentation for details.
122
+ seed: CoACD parameter. See CoACD documentation for details.
123
+ auto_scale: If True, automatically scale the output to match the input
124
+ bounding box.
125
+ scale_factor: Additional scaling factor to apply.
126
+ verbose: If True, enables detailed logging.
127
+
128
+ Returns:
129
+ The path to the output file if decomposition is successful.
130
+
131
+ Raises:
132
+ RuntimeError: If convex decomposition fails after all attempts.
133
+ """
134
  coacd.set_log_level("info" if verbose else "warn")
135
 
136
  if os.path.exists(outfile):
 
195
  verbose: bool = False,
196
  auto_scale: bool = True,
197
  ) -> str:
198
+ """Decomposes a mesh into convex parts in a separate process.
199
+
200
+ This function uses the `multiprocessing` module to run the CoACD algorithm
201
+ in a spawned subprocess. This is useful for isolating the decomposition
202
+ process to prevent potential memory leaks or crashes in the main process.
203
+ It includes a retry mechanism similar to `decompose_convex_mesh`.
204
 
205
  See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
206
+
207
+ Args:
208
+ filename: Path to the input mesh file.
209
+ outfile: Path to save the decomposed output mesh.
210
+ threshold: CoACD parameter.
211
+ max_convex_hull: CoACD parameter.
212
+ preprocess_mode: CoACD parameter.
213
+ preprocess_resolution: CoACD parameter.
214
+ resolution: CoACD parameter.
215
+ mcts_nodes: CoACD parameter.
216
+ mcts_iterations: CoACD parameter.
217
+ mcts_max_depth: CoACD parameter.
218
+ pca: CoACD parameter.
219
+ merge: CoACD parameter.
220
+ seed: CoACD parameter.
221
+ verbose: If True, enables detailed logging in the subprocess.
222
+ auto_scale: If True, automatically scale the output.
223
+
224
+ Returns:
225
+ The path to the output file if decomposition is successful.
226
+
227
+ Raises:
228
+ RuntimeError: If convex decomposition fails after all attempts.
229
  """
230
  params = dict(
231
  threshold=threshold,
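The new docstrings describe the decomposition wrappers, but no call site appears in this diff. A minimal sketch of the single-process path, assuming the remaining CoACD parameters keep their defaults and that the input mesh exists:

```python
from embodied_gen.data.convex_decomposer import decompose_convex_mesh

# Decompose a visual mesh into convex parts for physics simulation.
# The wrapper retries with preprocess_mode="on" if the first attempt fails.
out_path = decompose_convex_mesh(
    filename="mesh.obj",        # placeholder input path
    outfile="mesh_convex.obj",  # placeholder output path
)
print(f"Convex decomposition saved to {out_path}")
```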
embodied_gen/data/differentiable_render.py CHANGED
@@ -66,6 +66,14 @@ def create_mp4_from_images(
66
  fps: int = 10,
67
  prompt: str = None,
68
  ):
 
  font = cv2.FONT_HERSHEY_SIMPLEX
70
  font_scale = 0.5
71
  font_thickness = 1
@@ -96,6 +104,13 @@ def create_mp4_from_images(
96
  def create_gif_from_images(
97
  images: list[np.ndarray], output_path: str, fps: int = 10
98
  ) -> None:
 
  pil_images = []
100
  for image in images:
101
  image = image.clip(min=0, max=1)
@@ -116,32 +131,47 @@ def create_gif_from_images(
116
 
117
 
118
  class ImageRender(object):
119
- """A differentiable mesh renderer supporting multi-view rendering.
120
 
121
- This class wraps a differentiable rasterization using `nvdiffrast` to
122
- render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).
 
123
 
124
  Args:
125
- render_items (list[RenderItems]): A list of rendering targets to
126
- generate (e.g., IMAGE, DEPTH, NORMAL, etc.).
127
- camera_params (CameraSetting): The camera parameters for rendering,
128
- including intrinsic and extrinsic matrices.
129
- recompute_vtx_normal (bool, optional): If True, recomputes
130
- vertex normals from the mesh geometry. Defaults to True.
131
- with_mtl (bool, optional): Whether to load `.mtl` material files
132
- for meshes. Defaults to False.
133
- gen_color_gif (bool, optional): Generate a GIF of rendered
134
- color images. Defaults to False.
135
- gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
136
- color images. Defaults to False.
137
- gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
138
- view-space normals. Defaults to False.
139
- gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
140
- global-space normals. Defaults to False.
141
- no_index_file (bool, optional): If True, skip saving the `index.json`
142
- summary file. Defaults to False.
143
- light_factor (float, optional): A scalar multiplier for
144
- PBR light intensity. Defaults to 1.0.
  """
146
 
147
  def __init__(
@@ -198,6 +228,14 @@ class ImageRender(object):
198
  uuid: Union[str, List[str]] = None,
199
  prompts: List[str] = None,
200
  ) -> None:
201
  mesh_path = as_list(mesh_path)
202
  if uuid is None:
203
  uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
@@ -227,18 +265,15 @@ class ImageRender(object):
227
  def __call__(
228
  self, mesh_path: str, output_dir: str, prompt: str = None
229
  ) -> dict[str, str]:
230
- """Render a single mesh and return paths to the rendered outputs.
231
-
232
- Processes the input mesh, renders multiple modalities (e.g., normals,
233
- depth, albedo), and optionally saves video or image sequences.
234
 
235
  Args:
236
- mesh_path (str): Path to the mesh file (.obj/.glb).
237
- output_dir (str): Directory to save rendered outputs.
238
- prompt (str, optional): Optional caption prompt for MP4 metadata.
239
 
240
  Returns:
241
- dict[str, str]: A mapping render types to the saved image paths.
242
  """
243
  try:
244
  mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
 
66
  fps: int = 10,
67
  prompt: str = None,
68
  ):
69
+ """Creates an MP4 video from a list of images.
70
+
71
+ Args:
72
+ images (list[np.ndarray]): List of images as numpy arrays.
73
+ output_path (str): Path to save the MP4 file.
74
+ fps (int, optional): Frames per second. Defaults to 10.
75
+ prompt (str, optional): Optional text prompt overlay.
76
+ """
77
  font = cv2.FONT_HERSHEY_SIMPLEX
78
  font_scale = 0.5
79
  font_thickness = 1
 
104
  def create_gif_from_images(
105
  images: list[np.ndarray], output_path: str, fps: int = 10
106
  ) -> None:
107
+ """Creates a GIF animation from a list of images.
108
+
109
+ Args:
110
+ images (list[np.ndarray]): List of images as numpy arrays.
111
+ output_path (str): Path to save the GIF file.
112
+ fps (int, optional): Frames per second. Defaults to 10.
113
+ """
114
  pil_images = []
115
  for image in images:
116
  image = image.clip(min=0, max=1)
 
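The two helper docstrings above now state the expected inputs; the sketch below writes a short MP4 and GIF from synthetic frames. It assumes frames are float arrays in [0, 1], as `create_gif_from_images` implies with its clip-to-[0, 1] step; the value scaling inside `create_mp4_from_images` is not shown in this hunk.

```python
import numpy as np

from embodied_gen.data.differentiable_render import (
    create_gif_from_images,
    create_mp4_from_images,
)

# 30 synthetic 256x256 RGB frames with values in [0, 1].
frames = [np.random.rand(256, 256, 3).astype(np.float32) for _ in range(30)]

create_mp4_from_images(frames, "preview.mp4", fps=10, prompt="debug render")
create_gif_from_images(frames, "preview.gif", fps=10)
```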
131
 
132
 
133
  class ImageRender(object):
134
+ """Differentiable mesh renderer supporting multi-view rendering.
135
 
136
+ This class wraps differentiable rasterization using `nvdiffrast` to render mesh
137
+ geometry to various maps (normal, depth, alpha, albedo, etc.) and supports
138
+ saving images and videos.
139
 
140
  Args:
141
+ render_items (list[RenderItems]): List of rendering targets.
142
+ camera_params (CameraSetting): Camera parameters for rendering.
143
+ recompute_vtx_normal (bool, optional): Recompute vertex normals. Defaults to True.
144
+ with_mtl (bool, optional): Load mesh material files. Defaults to False.
145
+ gen_color_gif (bool, optional): Generate GIF of color images. Defaults to False.
146
+ gen_color_mp4 (bool, optional): Generate MP4 of color images. Defaults to False.
147
+ gen_viewnormal_mp4 (bool, optional): Generate MP4 of view-space normals. Defaults to False.
148
+ gen_glonormal_mp4 (bool, optional): Generate MP4 of global-space normals. Defaults to False.
149
+ no_index_file (bool, optional): Skip saving index file. Defaults to False.
150
+ light_factor (float, optional): PBR light intensity multiplier. Defaults to 1.0.
151
+
152
+ Example:
153
+ ```py
154
+ from embodied_gen.data.differentiable_render import ImageRender
155
+ from embodied_gen.data.utils import CameraSetting
156
+ from embodied_gen.utils.enum import RenderItems
157
+
158
+ camera_params = CameraSetting(
159
+ num_images=6,
160
+ elevation=[20, -10],
161
+ distance=5,
162
+ resolution_hw=(512,512),
163
+ fov=math.radians(30),
164
+ device='cuda',
165
+ )
166
+ render_items = [RenderItems.IMAGE.value, RenderItems.DEPTH.value]
167
+ renderer = ImageRender(
168
+ render_items,
169
+ camera_params,
170
+ with_mtl=True,
171
+ gen_color_mp4=True,
172
+ )
173
+ renderer.render_mesh(mesh_path='mesh.obj', output_root='./renders')
174
+ ```
175
  """
176
 
177
  def __init__(
 
228
  uuid: Union[str, List[str]] = None,
229
  prompts: List[str] = None,
230
  ) -> None:
231
+ """Renders one or more meshes and saves outputs.
232
+
233
+ Args:
234
+ mesh_path (Union[str, List[str]]): Path(s) to mesh files.
235
+ output_root (str): Directory to save outputs.
236
+ uuid (Union[str, List[str]], optional): Unique IDs for outputs.
237
+ prompts (List[str], optional): Text prompts for videos.
238
+ """
239
  mesh_path = as_list(mesh_path)
240
  if uuid is None:
241
  uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
 
265
  def __call__(
266
  self, mesh_path: str, output_dir: str, prompt: str = None
267
  ) -> dict[str, str]:
268
+ """Renders a single mesh and returns output paths.
 
 
 
269
 
270
  Args:
271
+ mesh_path (str): Path to mesh file.
272
+ output_dir (str): Directory to save outputs.
273
+ prompt (str, optional): Caption prompt for MP4 metadata.
274
 
275
  Returns:
276
+ dict[str, str]: Mapping of render types to saved image paths.
277
  """
278
  try:
279
  mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
embodied_gen/data/mesh_operator.py CHANGED
@@ -16,17 +16,13 @@
16
 
17
 
18
  import logging
19
- import multiprocessing as mp
20
- import os
21
  from typing import Tuple, Union
22
 
23
- import coacd
24
  import igraph
25
  import numpy as np
26
  import pyvista as pv
27
  import spaces
28
  import torch
29
- import trimesh
30
  import utils3d
31
  from pymeshfix import _meshfix
32
  from tqdm import tqdm
 
16
 
17
 
18
  import logging
 
 
19
  from typing import Tuple, Union
20
 
 
21
  import igraph
22
  import numpy as np
23
  import pyvista as pv
24
  import spaces
25
  import torch
 
26
  import utils3d
27
  from pymeshfix import _meshfix
28
  from tqdm import tqdm
embodied_gen/data/utils.py CHANGED
@@ -66,6 +66,7 @@ __all__ = [
66
  "resize_pil",
67
  "trellis_preprocess",
68
  "delete_dir",
 
69
  ]
70
 
71
 
@@ -373,10 +374,18 @@ def _compute_az_el_by_views(
373
  def _compute_cam_pts_by_az_el(
374
  azs: np.ndarray,
375
  els: np.ndarray,
376
- distance: float,
377
  extra_pts: np.ndarray = None,
378
  ) -> np.ndarray:
379
- distances = np.array([distance for _ in range(len(azs))])
 
  cam_pts = _az_el_to_points(azs, els) * distances[:, None]
381
 
382
  if extra_pts is not None:
@@ -710,7 +719,7 @@ class CameraSetting:
710
 
711
  num_images: int
712
  elevation: list[float]
713
- distance: float
714
  resolution_hw: tuple[int, int]
715
  fov: float
716
  at: tuple[float, float, float] = field(
@@ -824,6 +833,28 @@ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
824
  return mesh
825
 
826
 
827
  def save_mesh_with_mtl(
828
  vertices: np.ndarray,
829
  faces: np.ndarray,
 
66
  "resize_pil",
67
  "trellis_preprocess",
68
  "delete_dir",
69
+ "kaolin_to_opencv_view",
70
  ]
71
 
72
 
 
374
  def _compute_cam_pts_by_az_el(
375
  azs: np.ndarray,
376
  els: np.ndarray,
377
+ distance: float | list[float] | np.ndarray,
378
  extra_pts: np.ndarray = None,
379
  ) -> np.ndarray:
380
+ if np.isscalar(distance) or isinstance(distance, (float, int)):
381
+ distances = np.full(len(azs), distance)
382
+ else:
383
+ distances = np.array(distance)
384
+ if len(distances) != len(azs):
385
+ raise ValueError(
386
+ f"Length of distances ({len(distances)}) must match length of azs ({len(azs)})"
387
+ )
388
+
389
  cam_pts = _az_el_to_points(azs, els) * distances[:, None]
390
 
391
  if extra_pts is not None:
 
719
 
720
  num_images: int
721
  elevation: list[float]
722
+ distance: float | list[float]
723
  resolution_hw: tuple[int, int]
724
  fov: float
725
  at: tuple[float, float, float] = field(
 
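With the field change above, `CameraSetting.distance` accepts either a single radius or one radius per view; the length check added to `_compute_cam_pts_by_az_el` suggests a per-view list must match the number of views. A short sketch (field values mirror the docstring examples elsewhere in this commit):

```python
import math

from embodied_gen.data.utils import CameraSetting

# One shared camera radius (previous behaviour).
cam_fixed = CameraSetting(
    num_images=6,
    elevation=[20, -10],
    distance=5,
    resolution_hw=(512, 512),
    fov=math.radians(30),
    device="cuda",
)

# One radius per rendered view (new behaviour); length matches the view count.
cam_per_view = CameraSetting(
    num_images=6,
    elevation=[20, -10],
    distance=[5, 5, 6, 6, 4, 4],
    resolution_hw=(512, 512),
    fov=math.radians(30),
    device="cuda",
)
```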
833
  return mesh
834
 
835
 
836
+ def kaolin_to_opencv_view(raw_matrix):
837
+ R_orig = raw_matrix[:, :3, :3]
838
+ t_orig = raw_matrix[:, :3, 3]
839
+
840
+ R_target = torch.zeros_like(R_orig)
841
+ R_target[:, :, 0] = R_orig[:, :, 2]
842
+ R_target[:, :, 1] = R_orig[:, :, 0]
843
+ R_target[:, :, 2] = R_orig[:, :, 1]
844
+
845
+ t_target = t_orig
846
+
847
+ target_matrix = (
848
+ torch.eye(4, device=raw_matrix.device)
849
+ .unsqueeze(0)
850
+ .repeat(raw_matrix.size(0), 1, 1)
851
+ )
852
+ target_matrix[:, :3, :3] = R_target
853
+ target_matrix[:, :3, 3] = t_target
854
+
855
+ return target_matrix
856
+
857
+
858
  def save_mesh_with_mtl(
859
  vertices: np.ndarray,
860
  faces: np.ndarray,
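The new `kaolin_to_opencv_view` helper reorders the rotation columns of a Kaolin world-to-camera matrix. The sketch below mirrors how `backproject_v3` uses it right after `init_kal_camera`; the camera values are placeholders:

```python
import math

from embodied_gen.data.utils import (
    CameraSetting,
    init_kal_camera,
    kaolin_to_opencv_view,
)

camera_params = CameraSetting(
    num_images=6,
    elevation=[20, -10],
    distance=5,
    resolution_hw=(512, 512),
    fov=math.radians(30),
    device="cuda",
)

camera = init_kal_camera(camera_params)
mv = camera.view_matrix()         # (n_cam, 4, 4) world2cam, Kaolin convention
w2cs = kaolin_to_opencv_view(mv)  # same poses, OpenCV-style axis ordering
print(w2cs.shape)                 # torch.Size([6, 4, 4])
```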
embodied_gen/envs/pick_embodiedgen.py CHANGED
@@ -51,6 +51,33 @@ __all__ = ["PickEmbodiedGen"]
51
 
52
  @register_env("PickEmbodiedGen-v1", max_episode_steps=100)
53
  class PickEmbodiedGen(BaseEnv):
54
  SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
55
  goal_thresh = 0.0
56
 
@@ -63,6 +90,19 @@ class PickEmbodiedGen(BaseEnv):
63
  reconfiguration_freq: int = None,
64
  **kwargs,
65
  ):
66
  self.robot_init_qpos_noise = robot_init_qpos_noise
67
  if reconfiguration_freq is None:
68
  if num_envs == 1:
@@ -116,6 +156,22 @@ class PickEmbodiedGen(BaseEnv):
116
  def init_env_layouts(
117
  layout_file: str, num_envs: int, replace_objs: bool
118
  ) -> list[LayoutInfo]:
119
  layouts = []
120
  for env_idx in range(num_envs):
121
  if replace_objs and env_idx > 0:
@@ -136,6 +192,18 @@ class PickEmbodiedGen(BaseEnv):
136
  def compute_robot_init_pose(
137
  layouts: list[str], num_envs: int, z_offset: float = 0.0
138
  ) -> list[list[float]]:
139
  robot_pose = []
140
  for env_idx in range(num_envs):
141
  layout = json.load(open(layouts[env_idx], "r"))
@@ -148,6 +216,11 @@ class PickEmbodiedGen(BaseEnv):
148
 
149
  @property
150
  def _default_sim_config(self):
151
  return SimConfig(
152
  scene_config=SceneConfig(
153
  solver_position_iterations=30,
@@ -163,6 +236,11 @@ class PickEmbodiedGen(BaseEnv):
163
 
164
  @property
165
  def _default_sensor_configs(self):
166
  pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
167
 
168
  return [
@@ -171,6 +249,11 @@ class PickEmbodiedGen(BaseEnv):
171
 
172
  @property
173
  def _default_human_render_camera_configs(self):
174
  pose = sapien_utils.look_at(
175
  eye=self.camera_cfg["camera_eye"],
176
  target=self.camera_cfg["camera_target_pt"],
@@ -187,10 +270,24 @@ class PickEmbodiedGen(BaseEnv):
187
  )
188
 
189
  def _load_agent(self, options: dict):
190
  self.ground = build_ground(self.scene)
191
  super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
192
 
193
  def _load_scene(self, options: dict):
194
  all_objects = []
195
  logger.info(f"Loading EmbodiedGen assets...")
196
  for env_idx in range(self.num_envs):
@@ -222,6 +319,15 @@ class PickEmbodiedGen(BaseEnv):
222
  self._hidden_objects.append(self.goal_site)
223
 
224
  def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
225
  with torch.device(self.device):
226
  b = len(env_idx)
227
  goal_xyz = torch.zeros((b, 3))
@@ -256,6 +362,21 @@ class PickEmbodiedGen(BaseEnv):
256
  def render_gs3d_images(
257
  self, layouts: list[str], num_envs: int, init_quat: list[float]
258
  ) -> dict[str, np.ndarray]:
259
  sim_coord_align = (
260
  torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
261
  )
@@ -293,6 +414,15 @@ class PickEmbodiedGen(BaseEnv):
293
  return bg_images
294
 
295
  def render(self):
296
  if self.render_mode is None:
297
  raise RuntimeError("render_mode is not set.")
298
  if self.render_mode == "human":
@@ -315,6 +445,17 @@ class PickEmbodiedGen(BaseEnv):
315
  def render_rgb_array(
316
  self, camera_name: str = None, return_alpha: bool = False
317
  ):
318
  for obj in self._hidden_objects:
319
  obj.show_visual()
320
  self.scene.update_render(
@@ -335,6 +476,11 @@ class PickEmbodiedGen(BaseEnv):
335
  return tile_images(images)
336
 
337
  def render_sensors(self):
338
  images = []
339
  sensor_images = self.get_sensor_images()
340
  for image in sensor_images.values():
@@ -343,6 +489,14 @@ class PickEmbodiedGen(BaseEnv):
343
  return tile_images(images)
344
 
345
  def hybrid_render(self):
346
  fg_images = self.render_rgb_array(
347
  return_alpha=True
348
  ) # (n_env, h, w, 3)
@@ -362,6 +516,16 @@ class PickEmbodiedGen(BaseEnv):
362
  return images[..., :3]
363
 
364
  def evaluate(self):
365
  obj_to_goal_pos = (
366
  self.obj.pose.p
367
  ) # self.goal_site.pose.p - self.obj.pose.p
@@ -381,10 +545,31 @@ class PickEmbodiedGen(BaseEnv):
381
  )
382
 
383
  def _get_obs_extra(self, info: dict):
384
 
385
  return dict()
386
 
387
  def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
388
  tcp_to_obj_dist = torch.linalg.norm(
389
  self.obj.pose.p - self.agent.tcp.pose.p, axis=1
390
  )
@@ -417,4 +602,14 @@ class PickEmbodiedGen(BaseEnv):
417
  def compute_normalized_dense_reward(
418
  self, obs: any, action: torch.Tensor, info: dict
419
  ):
420
  return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
 
51
 
52
  @register_env("PickEmbodiedGen-v1", max_episode_steps=100)
53
  class PickEmbodiedGen(BaseEnv):
54
+ """PickEmbodiedGen as gym env example for object pick-and-place tasks.
55
+
56
+ This environment simulates a robot interacting with 3D assets in an
57
+ EmbodiedGen-generated scene in SAPIEN. It supports multi-environment setups,
58
+ dynamic reconfiguration, and hybrid rendering with 3D Gaussian Splatting.
59
+
60
+ Example:
61
+ Use `gym.make` to create the `PickEmbodiedGen-v1` parallel environment.
62
+ ```python
63
+ import gymnasium as gym
64
+ env = gym.make(
65
+ "PickEmbodiedGen-v1",
66
+ num_envs=cfg.num_envs,
67
+ render_mode=cfg.render_mode,
68
+ enable_shadow=cfg.enable_shadow,
69
+ layout_file=cfg.layout_file,
70
+ control_mode=cfg.control_mode,
71
+ camera_cfg=dict(
72
+ camera_eye=cfg.camera_eye,
73
+ camera_target_pt=cfg.camera_target_pt,
74
+ image_hw=cfg.image_hw,
75
+ fovy_deg=cfg.fovy_deg,
76
+ ),
77
+ )
78
+ ```
79
+ """
80
+
81
  SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
82
  goal_thresh = 0.0
83
 
 
90
  reconfiguration_freq: int = None,
91
  **kwargs,
92
  ):
93
+ """Initializes the PickEmbodiedGen environment.
94
+
95
+ Args:
96
+ *args: Variable length argument list for the base class.
97
+ robot_uids: The robot(s) to use in the environment.
98
+ robot_init_qpos_noise: Noise added to the robot's initial joint
99
+ positions.
100
+ num_envs: The number of parallel environments to create.
101
+ reconfiguration_freq: How often to reconfigure the scene. If None,
102
+ it is set based on num_envs.
103
+ **kwargs: Additional keyword arguments for environment setup,
104
+ including layout_file, replace_objs, enable_grasp, etc.
105
+ """
106
  self.robot_init_qpos_noise = robot_init_qpos_noise
107
  if reconfiguration_freq is None:
108
  if num_envs == 1:
 
156
  def init_env_layouts(
157
  layout_file: str, num_envs: int, replace_objs: bool
158
  ) -> list[LayoutInfo]:
159
+ """Initializes and saves layout files for each environment instance.
160
+
161
+ For each environment, this method creates a layout configuration. If
162
+ `replace_objs` is True, it generates new object placements for each
163
+ subsequent environment. The generated layouts are saved as new JSON
164
+ files.
165
+
166
+ Args:
167
+ layout_file: Path to the base layout JSON file.
168
+ num_envs: The number of environments to create layouts for.
169
+ replace_objs: If True, generates new object placements for each
170
+ environment after the first one using BFS placement.
171
+
172
+ Returns:
173
+ A list of file paths to the generated layout for each environment.
174
+ """
175
  layouts = []
176
  for env_idx in range(num_envs):
177
  if replace_objs and env_idx > 0:
 
192
  def compute_robot_init_pose(
193
  layouts: list[str], num_envs: int, z_offset: float = 0.0
194
  ) -> list[list[float]]:
195
+ """Computes the initial pose for the robot in each environment.
196
+
197
+ Args:
198
+ layouts: A list of file paths to the environment layouts.
199
+ num_envs: The number of environments.
200
+ z_offset: An optional vertical offset to apply to the robot's
201
+ position to prevent collisions.
202
+
203
+ Returns:
204
+ A list of initial poses ([x, y, z, qw, qx, qy, qz]) for the robot
205
+ in each environment.
206
+ """
207
  robot_pose = []
208
  for env_idx in range(num_envs):
209
  layout = json.load(open(layouts[env_idx], "r"))
 
216
 
217
  @property
218
  def _default_sim_config(self):
219
+ """Returns the default simulation configuration.
220
+
221
+ Returns:
222
+ The default simulation configuration object.
223
+ """
224
  return SimConfig(
225
  scene_config=SceneConfig(
226
  solver_position_iterations=30,
 
236
 
237
  @property
238
  def _default_sensor_configs(self):
239
+ """Returns the default sensor configurations for the agent.
240
+
241
+ Returns:
242
+ A list containing the default camera configuration.
243
+ """
244
  pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
245
 
246
  return [
 
249
 
250
  @property
251
  def _default_human_render_camera_configs(self):
252
+ """Returns the default camera configuration for human-friendly rendering.
253
+
254
+ Returns:
255
+ The default camera configuration for the renderer.
256
+ """
257
  pose = sapien_utils.look_at(
258
  eye=self.camera_cfg["camera_eye"],
259
  target=self.camera_cfg["camera_target_pt"],
 
270
  )
271
 
272
  def _load_agent(self, options: dict):
273
+ """Loads the agent (robot) and a ground plane into the scene.
274
+
275
+ Args:
276
+ options: A dictionary of options for loading the agent.
277
+ """
278
  self.ground = build_ground(self.scene)
279
  super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
280
 
281
  def _load_scene(self, options: dict):
282
+ """Loads all assets, objects, and the goal site into the scene.
283
+
284
+ This method iterates through the layouts for each environment, loads the
285
+ specified assets, and adds them to the simulation. It also creates a
286
+ kinematic sphere to represent the goal site.
287
+
288
+ Args:
289
+ options: A dictionary of options for loading the scene.
290
+ """
291
  all_objects = []
292
  logger.info(f"Loading EmbodiedGen assets...")
293
  for env_idx in range(self.num_envs):
 
319
  self._hidden_objects.append(self.goal_site)
320
 
321
  def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
322
+ """Initializes an episode for a given set of environments.
323
+
324
+ This method sets the goal position, resets the robot's joint positions
325
+ with optional noise, and sets its root pose.
326
+
327
+ Args:
328
+ env_idx: A tensor of environment indices to initialize.
329
+ options: A dictionary of options for initialization.
330
+ """
331
  with torch.device(self.device):
332
  b = len(env_idx)
333
  goal_xyz = torch.zeros((b, 3))
 
362
  def render_gs3d_images(
363
  self, layouts: list[str], num_envs: int, init_quat: list[float]
364
  ) -> dict[str, np.ndarray]:
365
+ """Renders background images using a pre-trained Gaussian Splatting model.
366
+
367
+ This method pre-renders the static background for each environment from
368
+ the perspective of all cameras to be used for hybrid rendering.
369
+
370
+ Args:
371
+ layouts: A list of file paths to the environment layouts.
372
+ num_envs: The number of environments.
373
+ init_quat: An initial quaternion to orient the Gaussian Splatting
374
+ model.
375
+
376
+ Returns:
377
+ A dictionary mapping a unique key (e.g., 'camera-env_idx') to the
378
+ rendered background image as a numpy array.
379
+ """
380
  sim_coord_align = (
381
  torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
382
  )
 
414
  return bg_images
415
 
416
  def render(self):
417
+ """Renders the environment based on the configured render_mode.
418
+
419
+ Raises:
420
+ RuntimeError: If `render_mode` is not set.
421
+ NotImplementedError: If the `render_mode` is not supported.
422
+
423
+ Returns:
424
+ The rendered output, which varies depending on the render mode.
425
+ """
426
  if self.render_mode is None:
427
  raise RuntimeError("render_mode is not set.")
428
  if self.render_mode == "human":
 
445
  def render_rgb_array(
446
  self, camera_name: str = None, return_alpha: bool = False
447
  ):
448
+ """Renders an RGB image from the human-facing render camera.
449
+
450
+ Args:
451
+ camera_name: The name of the camera to render from. If None, uses
452
+ all human render cameras.
453
+ return_alpha: Whether to include the alpha channel in the output.
454
+
455
+ Returns:
456
+ A numpy array representing the rendered image(s). If multiple
457
+ cameras are used, the images are tiled.
458
+ """
459
  for obj in self._hidden_objects:
460
  obj.show_visual()
461
  self.scene.update_render(
 
476
  return tile_images(images)
477
 
478
  def render_sensors(self):
479
+ """Renders images from all on-board sensor cameras.
480
+
481
+ Returns:
482
+ A tiled image of all sensor outputs as a numpy array.
483
+ """
484
  images = []
485
  sensor_images = self.get_sensor_images()
486
  for image in sensor_images.values():
 
489
  return tile_images(images)
490
 
491
  def hybrid_render(self):
492
+ """Renders a hybrid image by blending simulated foreground with a background.
493
+
494
+ The foreground is rendered with an alpha channel and then blended with
495
+ the pre-rendered Gaussian Splatting background image.
496
+
497
+ Returns:
498
+ A torch tensor of the final blended RGB images.
499
+ """
500
  fg_images = self.render_rgb_array(
501
  return_alpha=True
502
  ) # (n_env, h, w, 3)
 
516
  return images[..., :3]
517
 
518
  def evaluate(self):
519
+ """Evaluates the current state of the environment.
520
+
521
+ Checks for task success criteria such as whether the object is grasped,
522
+ placed at the goal, and if the robot is static.
523
+
524
+ Returns:
525
+ A dictionary containing boolean tensors for various success
526
+ metrics, including 'is_grasped', 'is_obj_placed', and overall
527
+ 'success'.
528
+ """
529
  obj_to_goal_pos = (
530
  self.obj.pose.p
531
  ) # self.goal_site.pose.p - self.obj.pose.p
 
545
  )
546
 
547
  def _get_obs_extra(self, info: dict):
548
+ """Gets extra information for the observation dictionary.
549
+
550
+ Args:
551
+ info: A dictionary containing evaluation information.
552
+
553
+ Returns:
554
+ An empty dictionary, as no extra observations are added.
555
+ """
556
 
557
  return dict()
558
 
559
  def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
560
+ """Computes a dense reward for the current step.
561
+
562
+ The reward is a composite of reaching, grasping, placing, and
563
+ maintaining a static final pose.
564
+
565
+ Args:
566
+ obs: The current observation.
567
+ action: The action taken in the current step.
568
+ info: A dictionary containing evaluation information from `evaluate()`.
569
+
570
+ Returns:
571
+ A tensor containing the dense reward for each environment.
572
+ """
573
  tcp_to_obj_dist = torch.linalg.norm(
574
  self.obj.pose.p - self.agent.tcp.pose.p, axis=1
575
  )
 
602
  def compute_normalized_dense_reward(
603
  self, obs: any, action: torch.Tensor, info: dict
604
  ):
605
+ """Computes a dense reward normalized to be between 0 and 1.
606
+
607
+ Args:
608
+ obs: The current observation.
609
+ action: The action taken in the current step.
610
+ info: A dictionary containing evaluation information from `evaluate()`.
611
+
612
+ Returns:
613
+ A tensor containing the normalized dense reward for each environment.
614
+ """
615
  return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
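The composite reward documented above (reach, grasp, place, stay static, then scaled into [0, 1]) is a standard staged pattern; the sketch below is an illustrative, self-contained version of that idea rather than the exact reward in this environment — the function name, the tanh shaping factor, and the `/ 4.0` normalizer are assumptions for demonstration only.
```py
import torch

def staged_reward_sketch(
    tcp_to_obj: torch.Tensor,   # gripper-to-object distance, shape (n_env,)
    is_grasped: torch.Tensor,   # bool, shape (n_env,)
    obj_to_goal: torch.Tensor,  # object-to-goal distance, shape (n_env,)
    is_static: torch.Tensor,    # bool, robot static at the end, shape (n_env,)
) -> torch.Tensor:
    # Stage 1: reward approaching the object (~1 when touching, ~0 far away).
    reaching = 1.0 - torch.tanh(5.0 * tcp_to_obj)
    # Stage 2: bonus for a successful grasp.
    grasp = is_grasped.float()
    # Stage 3: reward moving the grasped object toward the goal.
    placing = (1.0 - torch.tanh(5.0 * obj_to_goal)) * grasp
    # Stage 4: bonus for holding still once the object is placed.
    static = is_static.float() * grasp
    reward = reaching + grasp + placing + static
    return reward / 4.0  # divide by the maximum so the result lies in [0, 1]
```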
embodied_gen/models/delight_model.py CHANGED
@@ -40,7 +40,7 @@ class DelightingModel(object):
40
  """A model to remove the lighting in image space.
41
 
42
  This model is encapsulated based on the Hunyuan3D-Delight model
43
- from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa
44
 
45
  Attributes:
46
  image_guide_scale (float): Weight of image guidance in diffusion process.
 
40
  """A model to remove the lighting in image space.
41
 
42
  This model is encapsulated based on the Hunyuan3D-Delight model
43
+ from `https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0` # noqa
44
 
45
  Attributes:
46
  image_guide_scale (float): Weight of image guidance in diffusion process.
embodied_gen/models/gs_model.py CHANGED
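The hunks below drop the `cv2` import, pull `normalize_vertices_array` into the utils import, and move `load_gs_model` into this module (it is imported from here by `render_gs.py` further down in this commit). A minimal usage sketch against the new location, reusing the demo `.ply` path from `__main__`:
```py
from embodied_gen.models.gs_model import load_gs_model

# Loads the Gaussian Splatting background, normalized to [-1, 1] and recentered;
# pre_quat defaults to [0.0, 0.7071, 0.0, -0.7071].
gs_model = load_gs_model(
    "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
)
print(type(gs_model), gs_model.device)
```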
@@ -21,14 +21,18 @@ import struct
21
  from dataclasses import dataclass
22
  from typing import Optional
23
 
24
- import cv2
25
  import numpy as np
26
  import torch
27
  from gsplat.cuda._wrapper import spherical_harmonics
28
  from gsplat.rendering import rasterization
29
  from plyfile import PlyData
30
  from scipy.spatial.transform import Rotation
31
- from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat
 
 
 
 
 
32
 
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
@@ -494,6 +498,21 @@ class GaussianOperator(GaussianBase):
494
  )
495
 
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  if __name__ == "__main__":
498
  input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
499
  output_gs = "./gs_model.ply"
 
21
  from dataclasses import dataclass
22
  from typing import Optional
23
 
 
24
  import numpy as np
25
  import torch
26
  from gsplat.cuda._wrapper import spherical_harmonics
27
  from gsplat.rendering import rasterization
28
  from plyfile import PlyData
29
  from scipy.spatial.transform import Rotation
30
+ from embodied_gen.data.utils import (
31
+ gamma_shs,
32
+ normalize_vertices_array,
33
+ quat_mult,
34
+ quat_to_rotmat,
35
+ )
36
 
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger(__name__)
 
498
  )
499
 
500
 
501
+ def load_gs_model(
502
+ input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
503
+ ) -> GaussianOperator:
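+ """Load a Gaussian Splatting model from a .ply file and normalize it.
+
+ Vertices are normalized to [-1, 1] and recentered at the origin (see the
+ inline comment below), with `pre_quat` applied as an initial orientation.
+
+ Args:
+ input_gs (str): Path to the input .ply Gaussian Splatting file.
+ pre_quat (list[float], optional): Quaternion used to pre-rotate the model.
+
+ Returns:
+ GaussianOperator: The normalized, recentered Gaussian model.
+ """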
504
+ gs_model = GaussianOperator.load_from_ply(input_gs)
505
+ # Normalize vertices to [-1, 1], center to (0, 0, 0).
506
+ _, scale, center = normalize_vertices_array(gs_model._means)
507
+ scale, center = float(scale), center.tolist()
508
+ transpose = [*[v for v in center], *pre_quat]
509
+ instance_pose = torch.tensor(transpose).to(gs_model.device)
510
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
511
+ gs_model.rescale(scale)
512
+
513
+ return gs_model
514
+
515
+
516
  if __name__ == "__main__":
517
  input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
518
  output_gs = "./gs_model.ply"
embodied_gen/models/image_comm_model.py CHANGED
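The docstrings added below document the loader/runner split; since `build_hf_image_pipeline` resolves a name through `PIPELINE_REGISTRY` into a `(LoaderClass, RunnerClass)` pair, adding a new backend is a two-class exercise. A hedged sketch — `"my_model"`, `MyLoader`, and `MyRunner` are hypothetical, and registering by dict assignment assumes `PIPELINE_REGISTRY` is a plain mutable dict, as the lookup code suggests:
```py
from PIL import Image

from embodied_gen.models.image_comm_model import (
    PIPELINE_REGISTRY,
    BasePipelineLoader,
    BasePipelineRunner,
)


class MyLoader(BasePipelineLoader):
    def load(self):
        # Stub: return a diffusers-style pipeline moved to self.device.
        raise NotImplementedError


class MyRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images


# Register the pair; build_hf_image_pipeline("my_model") would then resolve it
# the same way it resolves built-in keys such as "sd35".
PIPELINE_REGISTRY["my_model"] = (MyLoader, MyRunner)
```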
@@ -38,26 +38,61 @@ __all__ = [
38
 
39
 
40
  class BasePipelineLoader(ABC):
 
 
 
 
 
 
 
 
 
41
  def __init__(self, device="cuda"):
42
  self.device = device
43
 
44
  @abstractmethod
45
  def load(self):
 
46
  pass
47
 
48
 
49
  class BasePipelineRunner(ABC):
 
 
 
 
 
 
 
 
 
50
  def __init__(self, pipe):
51
  self.pipe = pipe
52
 
53
  @abstractmethod
54
  def run(self, prompt: str, **kwargs) -> Image.Image:
 
 
 
 
 
 
 
 
 
55
  pass
56
 
57
 
58
  # ===== SD3.5-medium =====
59
  class SD35Loader(BasePipelineLoader):
 
 
60
  def load(self):
 
 
 
 
 
61
  pipe = StableDiffusion3Pipeline.from_pretrained(
62
  "stabilityai/stable-diffusion-3.5-medium",
63
  torch_dtype=torch.float16,
@@ -70,12 +105,25 @@ class SD35Loader(BasePipelineLoader):
70
 
71
 
72
  class SD35Runner(BasePipelineRunner):
 
 
73
  def run(self, prompt: str, **kwargs) -> Image.Image:
 
 
 
 
 
 
 
 
 
74
  return self.pipe(prompt=prompt, **kwargs).images
75
 
76
 
77
  # ===== Cosmos2 =====
78
  class CosmosLoader(BasePipelineLoader):
 
 
79
  def __init__(
80
  self,
81
  model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
@@ -87,6 +135,8 @@ class CosmosLoader(BasePipelineLoader):
87
  self.local_dir = local_dir
88
 
89
  def _patch(self):
 
 
90
  def patch_model(cls):
91
  orig = cls.from_pretrained
92
 
@@ -110,6 +160,11 @@ class CosmosLoader(BasePipelineLoader):
110
  patch_processor(SiglipProcessor)
111
 
112
  def load(self):
 
 
 
 
 
113
  self._patch()
114
  snapshot_download(
115
  repo_id=self.model_id,
@@ -141,7 +196,19 @@ class CosmosLoader(BasePipelineLoader):
141
 
142
 
143
  class CosmosRunner(BasePipelineRunner):
 
 
144
  def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
145
  return self.pipe(
146
  prompt=prompt, negative_prompt=negative_prompt, **kwargs
147
  ).images
@@ -149,7 +216,14 @@ class CosmosRunner(BasePipelineRunner):
149
 
150
  # ===== Kolors =====
151
  class KolorsLoader(BasePipelineLoader):
 
 
152
  def load(self):
 
 
 
 
 
153
  pipe = KolorsPipeline.from_pretrained(
154
  "Kwai-Kolors/Kolors-diffusers",
155
  torch_dtype=torch.float16,
@@ -164,13 +238,31 @@ class KolorsLoader(BasePipelineLoader):
164
 
165
 
166
  class KolorsRunner(BasePipelineRunner):
 
 
167
  def run(self, prompt: str, **kwargs) -> Image.Image:
 
 
 
 
 
 
 
 
 
168
  return self.pipe(prompt=prompt, **kwargs).images
169
 
170
 
171
  # ===== Flux =====
172
  class FluxLoader(BasePipelineLoader):
 
 
173
  def load(self):
 
 
 
 
 
174
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
175
  pipe = FluxPipeline.from_pretrained(
176
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
@@ -182,20 +274,50 @@ class FluxLoader(BasePipelineLoader):
182
 
183
 
184
  class FluxRunner(BasePipelineRunner):
 
 
185
  def run(self, prompt: str, **kwargs) -> Image.Image:
 
 
 
 
 
 
 
 
 
186
  return self.pipe(prompt=prompt, **kwargs).images
187
 
188
 
189
  # ===== Chroma =====
190
  class ChromaLoader(BasePipelineLoader):
 
 
191
  def load(self):
 
 
 
 
 
192
  return ChromaPipeline.from_pretrained(
193
  "lodestones/Chroma", torch_dtype=torch.bfloat16
194
  ).to(self.device)
195
 
196
 
197
  class ChromaRunner(BasePipelineRunner):
 
 
198
  def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
199
  return self.pipe(
200
  prompt=prompt, negative_prompt=negative_prompt, **kwargs
201
  ).images
@@ -211,6 +333,22 @@ PIPELINE_REGISTRY = {
211
 
212
 
213
  def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  if name not in PIPELINE_REGISTRY:
215
  raise ValueError(f"Unsupported model: {name}")
216
  loader_cls, runner_cls = PIPELINE_REGISTRY[name]
 
38
 
39
 
40
  class BasePipelineLoader(ABC):
41
+ """Abstract base class for loading Hugging Face image generation pipelines.
42
+
43
+ Attributes:
44
+ device (str): Device to load the pipeline on.
45
+
46
+ Methods:
47
+ load(): Loads and returns the pipeline.
48
+ """
49
+
50
  def __init__(self, device="cuda"):
51
  self.device = device
52
 
53
  @abstractmethod
54
  def load(self):
55
+ """Load and return the pipeline instance."""
56
  pass
57
 
58
 
59
  class BasePipelineRunner(ABC):
60
+ """Abstract base class for running image generation pipelines.
61
+
62
+ Attributes:
63
+ pipe: The loaded pipeline.
64
+
65
+ Methods:
66
+ run(prompt, **kwargs): Runs the pipeline with a prompt.
67
+ """
68
+
69
  def __init__(self, pipe):
70
  self.pipe = pipe
71
 
72
  @abstractmethod
73
  def run(self, prompt: str, **kwargs) -> Image.Image:
74
+ """Run the pipeline with the given prompt.
75
+
76
+ Args:
77
+ prompt (str): Text prompt for image generation.
78
+ **kwargs: Additional pipeline arguments.
79
+
80
+ Returns:
81
+ Image.Image: Generated image(s).
82
+ """
83
  pass
84
 
85
 
86
  # ===== SD3.5-medium =====
87
  class SD35Loader(BasePipelineLoader):
88
+ """Loader for Stable Diffusion 3.5 medium pipeline."""
89
+
90
  def load(self):
91
+ """Load the Stable Diffusion 3.5 medium pipeline.
92
+
93
+ Returns:
94
+ StableDiffusion3Pipeline: Loaded pipeline.
95
+ """
96
  pipe = StableDiffusion3Pipeline.from_pretrained(
97
  "stabilityai/stable-diffusion-3.5-medium",
98
  torch_dtype=torch.float16,
 
105
 
106
 
107
  class SD35Runner(BasePipelineRunner):
108
+ """Runner for Stable Diffusion 3.5 medium pipeline."""
109
+
110
  def run(self, prompt: str, **kwargs) -> Image.Image:
111
+ """Generate images using Stable Diffusion 3.5 medium.
112
+
113
+ Args:
114
+ prompt (str): Text prompt.
115
+ **kwargs: Additional arguments.
116
+
117
+ Returns:
118
+ Image.Image: Generated image(s).
119
+ """
120
  return self.pipe(prompt=prompt, **kwargs).images
121
 
122
 
123
  # ===== Cosmos2 =====
124
  class CosmosLoader(BasePipelineLoader):
125
+ """Loader for Cosmos2 text-to-image pipeline."""
126
+
127
  def __init__(
128
  self,
129
  model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
 
135
  self.local_dir = local_dir
136
 
137
  def _patch(self):
138
+ """Patch model and processor for optimized loading."""
139
+
140
  def patch_model(cls):
141
  orig = cls.from_pretrained
142
 
 
160
  patch_processor(SiglipProcessor)
161
 
162
  def load(self):
163
+ """Load the Cosmos2 text-to-image pipeline.
164
+
165
+ Returns:
166
+ Cosmos2TextToImagePipeline: Loaded pipeline.
167
+ """
168
  self._patch()
169
  snapshot_download(
170
  repo_id=self.model_id,
 
196
 
197
 
198
  class CosmosRunner(BasePipelineRunner):
199
+ """Runner for Cosmos2 text-to-image pipeline."""
200
+
201
  def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
202
+ """Generate images using Cosmos2 pipeline.
203
+
204
+ Args:
205
+ prompt (str): Text prompt.
206
+ negative_prompt (str, optional): Negative prompt.
207
+ **kwargs: Additional arguments.
208
+
209
+ Returns:
210
+ Image.Image: Generated image(s).
211
+ """
212
  return self.pipe(
213
  prompt=prompt, negative_prompt=negative_prompt, **kwargs
214
  ).images
 
216
 
217
  # ===== Kolors =====
218
  class KolorsLoader(BasePipelineLoader):
219
+ """Loader for Kolors pipeline."""
220
+
221
  def load(self):
222
+ """Load the Kolors pipeline.
223
+
224
+ Returns:
225
+ KolorsPipeline: Loaded pipeline.
226
+ """
227
  pipe = KolorsPipeline.from_pretrained(
228
  "Kwai-Kolors/Kolors-diffusers",
229
  torch_dtype=torch.float16,
 
238
 
239
 
240
  class KolorsRunner(BasePipelineRunner):
241
+ """Runner for Kolors pipeline."""
242
+
243
  def run(self, prompt: str, **kwargs) -> Image.Image:
244
+ """Generate images using Kolors pipeline.
245
+
246
+ Args:
247
+ prompt (str): Text prompt.
248
+ **kwargs: Additional arguments.
249
+
250
+ Returns:
251
+ Image.Image: Generated image(s).
252
+ """
253
  return self.pipe(prompt=prompt, **kwargs).images
254
 
255
 
256
  # ===== Flux =====
257
  class FluxLoader(BasePipelineLoader):
258
+ """Loader for Flux pipeline."""
259
+
260
  def load(self):
261
+ """Load the Flux pipeline.
262
+
263
+ Returns:
264
+ FluxPipeline: Loaded pipeline.
265
+ """
266
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
267
  pipe = FluxPipeline.from_pretrained(
268
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
 
274
 
275
 
276
  class FluxRunner(BasePipelineRunner):
277
+ """Runner for Flux pipeline."""
278
+
279
  def run(self, prompt: str, **kwargs) -> Image.Image:
280
+ """Generate images using Flux pipeline.
281
+
282
+ Args:
283
+ prompt (str): Text prompt.
284
+ **kwargs: Additional arguments.
285
+
286
+ Returns:
287
+ Image.Image: Generated image(s).
288
+ """
289
  return self.pipe(prompt=prompt, **kwargs).images
290
 
291
 
292
  # ===== Chroma =====
293
  class ChromaLoader(BasePipelineLoader):
294
+ """Loader for Chroma pipeline."""
295
+
296
  def load(self):
297
+ """Load the Chroma pipeline.
298
+
299
+ Returns:
300
+ ChromaPipeline: Loaded pipeline.
301
+ """
302
  return ChromaPipeline.from_pretrained(
303
  "lodestones/Chroma", torch_dtype=torch.bfloat16
304
  ).to(self.device)
305
 
306
 
307
  class ChromaRunner(BasePipelineRunner):
308
+ """Runner for Chroma pipeline."""
309
+
310
  def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
311
+ """Generate images using Chroma pipeline.
312
+
313
+ Args:
314
+ prompt (str): Text prompt.
315
+ negative_prompt (str, optional): Negative prompt.
316
+ **kwargs: Additional arguments.
317
+
318
+ Returns:
319
+ Image.Image: Generated image(s).
320
+ """
321
  return self.pipe(
322
  prompt=prompt, negative_prompt=negative_prompt, **kwargs
323
  ).images
 
333
 
334
 
335
  def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
336
+ """Build a Hugging Face image generation pipeline runner by name.
337
+
338
+ Args:
339
+ name (str): Name of the pipeline (e.g., "sd35", "cosmos").
340
+ device (str): Device to load the pipeline on.
341
+
342
+ Returns:
343
+ BasePipelineRunner: Pipeline runner instance.
344
+
345
+ Example:
346
+ ```py
347
+ from embodied_gen.models.image_comm_model import build_hf_image_pipeline
348
+ runner = build_hf_image_pipeline("sd35")
349
+ images = runner.run(prompt="A robot holding a sign that says 'Hello'")
350
+ ```
351
+ """
352
  if name not in PIPELINE_REGISTRY:
353
  raise ValueError(f"Unsupported model: {name}")
354
  loader_cls, runner_cls = PIPELINE_REGISTRY[name]
embodied_gen/models/layout.py CHANGED
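Below, `LayoutDesigner.format_response` strips Markdown ```json fences before `json.loads`, and `format_response_repair` falls back to `json_repair` for malformed output. A standalone sketch of that cleanup pattern, independent of the class but using the same regex:
```py
import json
import re

import json_repair


def parse_gpt_json(response: str) -> dict:
    # Strip a leading ```json fence and a trailing ``` fence, then parse.
    cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # json_repair tolerates trailing commas, single quotes, etc.
        return json_repair.loads(cleaned)


print(parse_gpt_json('```json\n{"objects": ["table", "apple"]}\n```'))
```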
@@ -376,6 +376,21 @@ LAYOUT_DESCRIBER_PROMPT = """
376
 
377
 
378
  class LayoutDesigner(object):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  def __init__(
380
  self,
381
  gpt_client: GPTclient,
@@ -387,6 +402,15 @@ class LayoutDesigner(object):
387
  self.gpt_client = gpt_client
388
 
389
  def query(self, prompt: str, params: dict = None) -> str:
 
 
 
 
 
 
 
 
 
390
  full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
391
 
392
  response = self.gpt_client.query(
@@ -400,6 +424,17 @@ class LayoutDesigner(object):
400
  return response
401
 
402
  def format_response(self, response: str) -> dict:
 
 
 
 
 
 
 
 
 
 
 
403
  cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
404
  try:
405
  output = json.loads(cleaned)
@@ -411,9 +446,23 @@ class LayoutDesigner(object):
411
  return output
412
 
413
  def format_response_repair(self, response: str) -> dict:
 
 
 
 
 
 
 
 
414
  return json_repair.loads(response)
415
 
416
  def save_output(self, output: dict, save_path: str) -> None:
 
 
 
 
 
 
417
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
418
  with open(save_path, 'w') as f:
419
  json.dump(output, f, indent=4)
@@ -421,6 +470,16 @@ class LayoutDesigner(object):
421
  def __call__(
422
  self, prompt: str, save_path: str = None, params: dict = None
423
  ) -> dict | str:
 
 
 
 
 
 
 
 
 
 
424
  response = self.query(prompt, params=params)
425
  output = self.format_response_repair(response)
426
  self.save_output(output, save_path) if save_path else None
@@ -442,6 +501,29 @@ LAYOUT_DESCRIBER = LayoutDesigner(
442
  def build_scene_layout(
443
  task_desc: str, output_path: str = None, gpt_params: dict = None
444
  ) -> LayoutInfo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
446
  layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
447
  object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
 
376
 
377
 
378
  class LayoutDesigner(object):
379
+ """A class for querying GPT-based scene layout reasoning and formatting responses.
380
+
381
+ Attributes:
382
+ prompt (str): The system prompt for GPT.
383
+ verbose (bool): Whether to log responses.
384
+ gpt_client (GPTclient): The GPT client instance.
385
+
386
+ Methods:
387
+ query(prompt, params): Query GPT with a prompt and parameters.
388
+ format_response(response): Parse and clean JSON response.
389
+ format_response_repair(response): Repair and parse JSON response.
390
+ save_output(output, save_path): Save output to file.
391
+ __call__(prompt, save_path, params): Query and process output.
392
+ """
393
+
394
  def __init__(
395
  self,
396
  gpt_client: GPTclient,
 
402
  self.gpt_client = gpt_client
403
 
404
  def query(self, prompt: str, params: dict = None) -> str:
405
+ """Query GPT with the system prompt and user prompt.
406
+
407
+ Args:
408
+ prompt (str): User prompt.
409
+ params (dict, optional): GPT parameters.
410
+
411
+ Returns:
412
+ str: GPT response.
413
+ """
414
  full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
415
 
416
  response = self.gpt_client.query(
 
424
  return response
425
 
426
  def format_response(self, response: str) -> dict:
427
+ """Format and parse GPT response as JSON.
428
+
429
+ Args:
430
+ response (str): Raw GPT response.
431
+
432
+ Returns:
433
+ dict: Parsed JSON output.
434
+
435
+ Raises:
436
+ json.JSONDecodeError: If parsing fails.
437
+ """
438
  cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
439
  try:
440
  output = json.loads(cleaned)
 
446
  return output
447
 
448
  def format_response_repair(self, response: str) -> dict:
449
+ """Repair and parse possibly broken JSON response.
450
+
451
+ Args:
452
+ response (str): Raw GPT response.
453
+
454
+ Returns:
455
+ dict: Parsed JSON output.
456
+ """
457
  return json_repair.loads(response)
458
 
459
  def save_output(self, output: dict, save_path: str) -> None:
460
+ """Save output dictionary to a file.
461
+
462
+ Args:
463
+ output (dict): Output data.
464
+ save_path (str): Path to save the file.
465
+ """
466
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
467
  with open(save_path, 'w') as f:
468
  json.dump(output, f, indent=4)
 
470
  def __call__(
471
  self, prompt: str, save_path: str = None, params: dict = None
472
  ) -> dict | str:
473
+ """Query GPT and process the output.
474
+
475
+ Args:
476
+ prompt (str): User prompt.
477
+ save_path (str, optional): Path to save output.
478
+ params (dict, optional): GPT parameters.
479
+
480
+ Returns:
481
+ dict | str: Output data.
482
+ """
483
  response = self.query(prompt, params=params)
484
  output = self.format_response_repair(response)
485
  self.save_output(output, save_path) if save_path else None
 
501
  def build_scene_layout(
502
  task_desc: str, output_path: str = None, gpt_params: dict = None
503
  ) -> LayoutInfo:
504
+ """Build a 3D scene layout from a natural language task description.
505
+
506
+ This function uses GPT-based reasoning to generate a structured scene layout,
507
+ including object hierarchy, spatial relations, and style descriptions.
508
+
509
+ Args:
510
+ task_desc (str): Natural language description of the robotic task.
511
+ output_path (str, optional): Path to save the visualized scene tree.
512
+ gpt_params (dict, optional): Parameters for GPT queries.
513
+
514
+ Returns:
515
+ LayoutInfo: Structured layout information for the scene.
516
+
517
+ Example:
518
+ ```py
519
+ from embodied_gen.models.layout import build_scene_layout
520
+ layout_info = build_scene_layout(
521
+ task_desc="Put the apples on the table on the plate",
522
+ output_path="outputs/scene_tree.jpg",
523
+ )
524
+ print(layout_info)
525
+ ```
526
+ """
527
  layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
528
  layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
529
  object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
embodied_gen/models/segment_model.py CHANGED
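The new docstrings below spell out the segmentation helpers; combined, a typical background-removal chain looks like the hedged sketch here. The argument order for `get_segmented_image_by_agent` follows the docstring added below, the input and output paths are placeholders, and `seg_checker` is left as `None`:
```py
from PIL import Image

from embodied_gen.models.segment_model import (
    RembgRemover,
    SAMRemover,
    get_segmented_image_by_agent,
)

sam_remover = SAMRemover(model_type="vit_h")
rbg_remover = RembgRemover()

image = Image.open("input.jpg")
seg = get_segmented_image_by_agent(
    image,
    sam_remover,
    rbg_remover,
    seg_checker=None,
    save_path="output.png",
    mode="loose",
)
```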
@@ -48,12 +48,19 @@ __all__ = [
48
 
49
 
50
  class SAMRemover(object):
51
- """Loading SAM models and performing background removal on images.
52
 
53
  Attributes:
54
  checkpoint (str): Path to the model checkpoint.
55
- model_type (str): Type of the SAM model to load (default: "vit_h").
56
- area_ratio (float): Area ratio filtering small connected components.
 
 
 
 
 
 
 
57
  """
58
 
59
  def __init__(
@@ -78,6 +85,14 @@ class SAMRemover(object):
78
  self.mask_generator = self._load_sam_model(checkpoint)
79
 
80
  def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
 
 
 
 
 
 
 
 
81
  sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
82
  sam.to(device=self.device)
83
 
@@ -89,13 +104,11 @@ class SAMRemover(object):
89
  """Removes the background from an image using the SAM model.
90
 
91
  Args:
92
- image (Union[str, Image.Image, np.ndarray]): Input image,
93
- can be a file path, PIL Image, or numpy array.
94
- save_path (str): Path to save the output image (default: None).
95
 
96
  Returns:
97
- Image.Image: The image with background removed,
98
- including an alpha channel.
99
  """
100
  # Convert input to numpy array
101
  if isinstance(image, str):
@@ -134,6 +147,15 @@ class SAMRemover(object):
134
 
135
 
136
  class SAMPredictor(object):
 
 
 
 
 
 
 
 
 
137
  def __init__(
138
  self,
139
  checkpoint: str = None,
@@ -157,12 +179,28 @@ class SAMPredictor(object):
157
  self.binary_thresh = binary_thresh
158
 
159
  def _load_sam_model(self, checkpoint: str) -> SamPredictor:
 
 
 
 
 
 
 
 
160
  sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
161
  sam.to(device=self.device)
162
 
163
  return SamPredictor(sam)
164
 
165
  def preprocess_image(self, image: Image.Image) -> np.ndarray:
 
 
 
 
 
 
 
 
166
  if isinstance(image, str):
167
  image = Image.open(image)
168
  elif isinstance(image, np.ndarray):
@@ -178,6 +216,15 @@ class SAMPredictor(object):
178
  image: np.ndarray,
179
  selected_points: list[list[int]],
180
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
181
  if len(selected_points) == 0:
182
  return []
183
 
@@ -220,6 +267,15 @@ class SAMPredictor(object):
220
  def get_segmented_image(
221
  self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
222
  ) -> Image.Image:
 
 
 
 
 
 
 
 
 
223
  seg_image = Image.fromarray(image, mode="RGB")
224
  alpha_channel = np.zeros(
225
  (seg_image.height, seg_image.width), dtype=np.uint8
@@ -241,6 +297,15 @@ class SAMPredictor(object):
241
  image: Union[str, Image.Image, np.ndarray],
242
  selected_points: list[list[int]],
243
  ) -> Image.Image:
 
 
 
 
 
 
 
 
 
244
  image = self.preprocess_image(image)
245
  self.predictor.set_image(image)
246
  masks = self.generate_masks(image, selected_points)
@@ -249,12 +314,32 @@ class SAMPredictor(object):
249
 
250
 
251
  class RembgRemover(object):
 
 
 
 
 
 
 
 
 
 
252
  def __init__(self):
 
253
  self.rembg_session = rembg.new_session("u2net")
254
 
255
  def __call__(
256
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
257
  ) -> Image.Image:
 
 
 
 
 
 
 
 
 
258
  if isinstance(image, str):
259
  image = Image.open(image)
260
  elif isinstance(image, np.ndarray):
@@ -271,7 +356,18 @@ class RembgRemover(object):
271
 
272
 
273
  class BMGG14Remover(object):
 
 
 
 
 
 
 
 
 
 
274
  def __init__(self) -> None:
 
275
  self.model = pipeline(
276
  "image-segmentation",
277
  model="briaai/RMBG-1.4",
@@ -281,6 +377,15 @@ class BMGG14Remover(object):
281
  def __call__(
282
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
283
  ):
 
 
 
 
 
 
 
 
 
284
  if isinstance(image, str):
285
  image = Image.open(image)
286
  elif isinstance(image, np.ndarray):
@@ -299,6 +404,16 @@ class BMGG14Remover(object):
299
  def invert_rgba_pil(
300
  image: Image.Image, mask: Image.Image, save_path: str = None
301
  ) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
302
  mask = (255 - np.array(mask))[..., None]
303
  image_array = np.concatenate([np.array(image), mask], axis=-1)
304
  inverted_image = Image.fromarray(image_array, "RGBA")
@@ -318,6 +433,20 @@ def get_segmented_image_by_agent(
318
  save_path: str = None,
319
  mode: Literal["loose", "strict"] = "loose",
320
  ) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
322
  if seg_checker is None:
323
  return True
 
48
 
49
 
50
  class SAMRemover(object):
51
+ """Loads SAM models and performs background removal on images.
52
 
53
  Attributes:
54
  checkpoint (str): Path to the model checkpoint.
55
+ model_type (str): Type of the SAM model to load.
56
+ area_ratio (float): Area ratio for filtering small connected components.
57
+
58
+ Example:
59
+ ```py
60
+ from embodied_gen.models.segment_model import SAMRemover
61
+ remover = SAMRemover(model_type="vit_h")
62
+ result = remover("input.jpg", "output.png")
63
+ ```
64
  """
65
 
66
  def __init__(
 
85
  self.mask_generator = self._load_sam_model(checkpoint)
86
 
87
  def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
88
+ """Loads the SAM model and returns a mask generator.
89
+
90
+ Args:
91
+ checkpoint (str): Path to model checkpoint.
92
+
93
+ Returns:
94
+ SamAutomaticMaskGenerator: Mask generator instance.
95
+ """
96
  sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
97
  sam.to(device=self.device)
98
 
 
104
  """Removes the background from an image using the SAM model.
105
 
106
  Args:
107
+ image (Union[str, Image.Image, np.ndarray]): Input image.
108
+ save_path (str, optional): Path to save the output image.
 
109
 
110
  Returns:
111
+ Image.Image: Image with background removed (RGBA).
 
112
  """
113
  # Convert input to numpy array
114
  if isinstance(image, str):
 
147
 
148
 
149
  class SAMPredictor(object):
150
+ """Loads SAM models and predicts segmentation masks from user points.
151
+
152
+ Args:
153
+ checkpoint (str, optional): Path to model checkpoint.
154
+ model_type (str, optional): SAM model type.
155
+ binary_thresh (float, optional): Threshold for binary mask.
156
+ device (str, optional): Device for inference.
157
+ """
158
+
159
  def __init__(
160
  self,
161
  checkpoint: str = None,
 
179
  self.binary_thresh = binary_thresh
180
 
181
  def _load_sam_model(self, checkpoint: str) -> SamPredictor:
182
+ """Loads the SAM model and returns a predictor.
183
+
184
+ Args:
185
+ checkpoint (str): Path to model checkpoint.
186
+
187
+ Returns:
188
+ SamPredictor: Predictor instance.
189
+ """
190
  sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
191
  sam.to(device=self.device)
192
 
193
  return SamPredictor(sam)
194
 
195
  def preprocess_image(self, image: Image.Image) -> np.ndarray:
196
+ """Preprocesses input image for SAM prediction.
197
+
198
+ Args:
199
+ image (Image.Image): Input image.
200
+
201
+ Returns:
202
+ np.ndarray: Preprocessed image array.
203
+ """
204
  if isinstance(image, str):
205
  image = Image.open(image)
206
  elif isinstance(image, np.ndarray):
 
216
  image: np.ndarray,
217
  selected_points: list[list[int]],
218
  ) -> np.ndarray:
219
+ """Generates segmentation masks from selected points.
220
+
221
+ Args:
222
+ image (np.ndarray): Input image array.
223
+ selected_points (list[list[int]]): List of points and labels.
224
+
225
+ Returns:
226
+ list[tuple[np.ndarray, str]]: List of masks and names.
227
+ """
228
  if len(selected_points) == 0:
229
  return []
230
 
 
267
  def get_segmented_image(
268
  self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
269
  ) -> Image.Image:
270
+ """Combines masks and returns segmented image with alpha channel.
271
+
272
+ Args:
273
+ image (np.ndarray): Input image array.
274
+ masks (list[tuple[np.ndarray, str]]): List of masks.
275
+
276
+ Returns:
277
+ Image.Image: Segmented RGBA image.
278
+ """
279
  seg_image = Image.fromarray(image, mode="RGB")
280
  alpha_channel = np.zeros(
281
  (seg_image.height, seg_image.width), dtype=np.uint8
 
297
  image: Union[str, Image.Image, np.ndarray],
298
  selected_points: list[list[int]],
299
  ) -> Image.Image:
300
+ """Segments image using selected points.
301
+
302
+ Args:
303
+ image (Union[str, Image.Image, np.ndarray]): Input image.
304
+ selected_points (list[list[int]]): List of points and labels.
305
+
306
+ Returns:
307
+ Image.Image: Segmented RGBA image.
308
+ """
309
  image = self.preprocess_image(image)
310
  self.predictor.set_image(image)
311
  masks = self.generate_masks(image, selected_points)
 
314
 
315
 
316
  class RembgRemover(object):
317
+ """Removes background from images using the rembg library.
318
+
319
+ Example:
320
+ ```py
321
+ from embodied_gen.models.segment_model import RembgRemover
322
+ remover = RembgRemover()
323
+ result = remover("input.jpg", "output.png")
324
+ ```
325
+ """
326
+
327
  def __init__(self):
328
+ """Initializes the RembgRemover."""
329
  self.rembg_session = rembg.new_session("u2net")
330
 
331
  def __call__(
332
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
333
  ) -> Image.Image:
334
+ """Removes background from an image.
335
+
336
+ Args:
337
+ image (Union[str, Image.Image, np.ndarray]): Input image.
338
+ save_path (str, optional): Path to save the output image.
339
+
340
+ Returns:
341
+ Image.Image: Image with background removed (RGBA).
342
+ """
343
  if isinstance(image, str):
344
  image = Image.open(image)
345
  elif isinstance(image, np.ndarray):
 
356
 
357
 
358
  class BMGG14Remover(object):
359
+ """Removes background using the RMBG-1.4 segmentation model.
360
+
361
+ Example:
362
+ ```py
363
+ from embodied_gen.models.segment_model import BMGG14Remover
364
+ remover = BMGG14Remover()
365
+ result = remover("input.jpg", "output.png")
366
+ ```
367
+ """
368
+
369
  def __init__(self) -> None:
370
+ """Initializes the BMGG14Remover."""
371
  self.model = pipeline(
372
  "image-segmentation",
373
  model="briaai/RMBG-1.4",
 
377
  def __call__(
378
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
379
  ):
380
+ """Removes background from an image.
381
+
382
+ Args:
383
+ image (Union[str, Image.Image, np.ndarray]): Input image.
384
+ save_path (str, optional): Path to save the output image.
385
+
386
+ Returns:
387
+ Image.Image: Image with background removed.
388
+ """
389
  if isinstance(image, str):
390
  image = Image.open(image)
391
  elif isinstance(image, np.ndarray):
 
404
  def invert_rgba_pil(
405
  image: Image.Image, mask: Image.Image, save_path: str = None
406
  ) -> Image.Image:
407
+ """Inverts the alpha channel of an RGBA image using a mask.
408
+
409
+ Args:
410
+ image (Image.Image): Input RGB image.
411
+ mask (Image.Image): Mask image for alpha inversion.
412
+ save_path (str, optional): Path to save the output image.
413
+
414
+ Returns:
415
+ Image.Image: RGBA image with inverted alpha.
416
+ """
417
  mask = (255 - np.array(mask))[..., None]
418
  image_array = np.concatenate([np.array(image), mask], axis=-1)
419
  inverted_image = Image.fromarray(image_array, "RGBA")
 
433
  save_path: str = None,
434
  mode: Literal["loose", "strict"] = "loose",
435
  ) -> Image.Image:
436
+ """Segments an image using SAM and rembg, with quality checking.
437
+
438
+ Args:
439
+ image (Image.Image): Input image.
440
+ sam_remover (SAMRemover): SAM-based remover.
441
+ rbg_remover (RembgRemover): rembg-based remover.
442
+ seg_checker (ImageSegChecker, optional): Quality checker.
443
+ save_path (str, optional): Path to save the output image.
444
+ mode (Literal["loose", "strict"], optional): Segmentation mode.
445
+
446
+ Returns:
447
+ Image.Image: Segmented RGBA image.
448
+ """
449
+
450
  def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
451
  if seg_checker is None:
452
  return True
embodied_gen/models/sr_model.py CHANGED
@@ -39,13 +39,38 @@ __all__ = [
39
 
40
 
41
  class ImageStableSR:
42
- """Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def __init__(
45
  self,
46
  model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
47
  device="cuda",
48
  ) -> None:
 
 
 
 
 
 
49
  from diffusers import StableDiffusionUpscalePipeline
50
 
51
  self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
@@ -62,6 +87,16 @@ class ImageStableSR:
62
  prompt: str = "",
63
  infer_step: int = 20,
64
  ) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
65
  if isinstance(image, np.ndarray):
66
  image = Image.fromarray(image)
67
 
@@ -86,9 +121,26 @@ class ImageRealESRGAN:
86
  Attributes:
87
  outscale (int): The output image scale factor (e.g., 2, 4).
88
  model_path (str): Path to the pre-trained model weights.
 
 
 
 
 
 
 
 
 
 
 
89
  """
90
 
91
  def __init__(self, outscale: int, model_path: str = None) -> None:
 
 
 
 
 
 
92
  # monkey patch to support torchvision>=0.16
93
  import torchvision
94
  from packaging import version
@@ -122,6 +174,7 @@ class ImageRealESRGAN:
122
  self.model_path = model_path
123
 
124
  def _lazy_init(self):
 
125
  if self.upsampler is None:
126
  from basicsr.archs.rrdbnet_arch import RRDBNet
127
  from realesrgan import RealESRGANer
@@ -145,6 +198,14 @@ class ImageRealESRGAN:
145
 
146
  @spaces.GPU
147
  def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
 
 
 
 
 
 
 
 
148
  self._lazy_init()
149
 
150
  if isinstance(image, Image.Image):
 
39
 
40
 
41
  class ImageStableSR:
42
+ """Super-resolution image upscaler using Stable Diffusion x4 upscaling model.
43
+
44
+ This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality
45
+ image super-resolution.
46
+
47
+ Args:
48
+ model_path (str, optional): Path or HuggingFace repo for the model.
49
+ device (str, optional): Device for inference.
50
+
51
+ Example:
52
+ ```py
53
+ from embodied_gen.models.sr_model import ImageStableSR
54
+ from PIL import Image
55
+
56
+ sr_model = ImageStableSR()
57
+ img = Image.open("input.png")
58
+ upscaled = sr_model(img)
59
+ upscaled.save("output.png")
60
+ ```
61
+ """
62
 
63
  def __init__(
64
  self,
65
  model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
66
  device="cuda",
67
  ) -> None:
68
+ """Initializes the Stable Diffusion x4 upscaler.
69
+
70
+ Args:
71
+ model_path (str, optional): Model path or repo.
72
+ device (str, optional): Device for inference.
73
+ """
74
  from diffusers import StableDiffusionUpscalePipeline
75
 
76
  self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
 
87
  prompt: str = "",
88
  infer_step: int = 20,
89
  ) -> Image.Image:
90
+ """Performs super-resolution on the input image.
91
+
92
+ Args:
93
+ image (Union[Image.Image, np.ndarray]): Input image.
94
+ prompt (str, optional): Text prompt for upscaling.
95
+ infer_step (int, optional): Number of inference steps.
96
+
97
+ Returns:
98
+ Image.Image: Upscaled image.
99
+ """
100
  if isinstance(image, np.ndarray):
101
  image = Image.fromarray(image)
102
 
 
121
  Attributes:
122
  outscale (int): The output image scale factor (e.g., 2, 4).
123
  model_path (str): Path to the pre-trained model weights.
124
+
125
+ Example:
126
+ ```py
127
+ from embodied_gen.models.sr_model import ImageRealESRGAN
128
+ from PIL import Image
129
+
130
+ sr_model = ImageRealESRGAN(outscale=4)
131
+ img = Image.open("input.png")
132
+ upscaled = sr_model(img)
133
+ upscaled.save("output.png")
134
+ ```
135
  """
136
 
137
  def __init__(self, outscale: int, model_path: str = None) -> None:
138
+ """Initializes the RealESRGAN upscaler.
139
+
140
+ Args:
141
+ outscale (int): Output scale factor.
142
+ model_path (str, optional): Path to model weights.
143
+ """
144
  # monkey patch to support torchvision>=0.16
145
  import torchvision
146
  from packaging import version
 
174
  self.model_path = model_path
175
 
176
  def _lazy_init(self):
177
+ """Lazily initializes the RealESRGAN model."""
178
  if self.upsampler is None:
179
  from basicsr.archs.rrdbnet_arch import RRDBNet
180
  from realesrgan import RealESRGANer
 
198
 
199
  @spaces.GPU
200
  def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
201
+ """Performs super-resolution on the input image.
202
+
203
+ Args:
204
+ image (Union[Image.Image, np.ndarray]): Input image.
205
+
206
+ Returns:
207
+ Image.Image: Upscaled image.
208
+ """
209
  self._lazy_init()
210
 
211
  if isinstance(image, Image.Image):
embodied_gen/models/text_model.py CHANGED
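The examples added below cover pipeline construction and generation separately; wiring them together with an IP-Adapter reference image looks roughly like this sketch (parameter names follow the `text2img_gen` docstring below; the `ip_image` path is a placeholder):
```py
from embodied_gen.models.text_model import (
    build_text2img_ip_pipeline,
    text2img_gen,
)

pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
images = text2img_gen(
    prompt="banana",
    n_sample=1,
    guidance_scale=7.5,
    pipeline=pipe,
    ip_image="reference.png",  # placeholder reference image for the IP-Adapter
)
images[0].save("banana_ip.png")
```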
@@ -60,6 +60,11 @@ PROMPT_KAPPEND = "Single {object}, in the center of the image, white background,
60
 
61
 
62
  def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
 
 
 
 
 
63
  logger.info(f"Download kolors weights from huggingface...")
64
  os.makedirs(local_dir, exist_ok=True)
65
  subprocess.run(
@@ -93,6 +98,22 @@ def build_text2img_ip_pipeline(
93
  ref_scale: float,
94
  device: str = "cuda",
95
  ) -> StableDiffusionXLPipelineIP:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  download_kolors_weights(ckpt_dir)
97
 
98
  text_encoder = ChatGLMModel.from_pretrained(
@@ -146,6 +167,21 @@ def build_text2img_pipeline(
146
  ckpt_dir: str,
147
  device: str = "cuda",
148
  ) -> StableDiffusionXLPipeline:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  download_kolors_weights(ckpt_dir)
150
 
151
  text_encoder = ChatGLMModel.from_pretrained(
@@ -185,6 +221,29 @@ def text2img_gen(
185
  ip_image_size: int = 512,
186
  seed: int = None,
187
  ) -> list[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  prompt = PROMPT_KAPPEND.format(object=prompt.strip())
189
  logger.info(f"Processing prompt: {prompt}")
190
 
 
60
 
61
 
62
  def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
63
+ """Downloads Kolors model weights from HuggingFace.
64
+
65
+ Args:
66
+ local_dir (str, optional): Local directory to store weights.
67
+ """
68
  logger.info(f"Download kolors weights from huggingface...")
69
  os.makedirs(local_dir, exist_ok=True)
70
  subprocess.run(
 
98
  ref_scale: float,
99
  device: str = "cuda",
100
  ) -> StableDiffusionXLPipelineIP:
101
+ """Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.
102
+
103
+ Args:
104
+ ckpt_dir (str): Directory containing model checkpoints.
105
+ ref_scale (float): Reference scale for IP-Adapter.
106
+ device (str, optional): Device for inference.
107
+
108
+ Returns:
109
+ StableDiffusionXLPipelineIP: Configured pipeline.
110
+
111
+ Example:
112
+ ```py
113
+ from embodied_gen.models.text_model import build_text2img_ip_pipeline
114
+ pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
115
+ ```
116
+ """
117
  download_kolors_weights(ckpt_dir)
118
 
119
  text_encoder = ChatGLMModel.from_pretrained(
 
167
  ckpt_dir: str,
168
  device: str = "cuda",
169
  ) -> StableDiffusionXLPipeline:
170
+ """Builds a Stable Diffusion XL pipeline for text-to-image generation.
171
+
172
+ Args:
173
+ ckpt_dir (str): Directory containing model checkpoints.
174
+ device (str, optional): Device for inference.
175
+
176
+ Returns:
177
+ StableDiffusionXLPipeline: Configured pipeline.
178
+
179
+ Example:
180
+ ```py
181
+ from embodied_gen.models.text_model import build_text2img_pipeline
182
+ pipe = build_text2img_pipeline("weights/Kolors")
183
+ ```
184
+ """
185
  download_kolors_weights(ckpt_dir)
186
 
187
  text_encoder = ChatGLMModel.from_pretrained(
 
221
  ip_image_size: int = 512,
222
  seed: int = None,
223
  ) -> list[Image.Image]:
224
+ """Generates images from text prompts using a Stable Diffusion XL pipeline.
225
+
226
+ Args:
227
+ prompt (str): Text prompt for image generation.
228
+ n_sample (int): Number of images to generate.
229
+ guidance_scale (float): Guidance scale for diffusion.
230
+ pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance.
231
+ ip_image (Image.Image | str, optional): Reference image for IP-Adapter.
232
+ image_wh (tuple[int, int], optional): Output image size (width, height).
233
+ infer_step (int, optional): Number of inference steps.
234
+ ip_image_size (int, optional): Size for IP-Adapter image.
235
+ seed (int, optional): Random seed.
236
+
237
+ Returns:
238
+ list[Image.Image]: List of generated images.
239
+
240
+ Example:
241
+ ```py
242
+ from embodied_gen.models.text_model import text2img_gen
243
+ images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5)
244
+ images[0].save("banana.png")
245
+ ```
246
+ """
247
  prompt = PROMPT_KAPPEND.format(object=prompt.strip())
248
  logger.info(f"Processing prompt: {prompt}")
249
 
embodied_gen/models/texture_model.py CHANGED
@@ -42,6 +42,56 @@ def build_texture_gen_pipe(
42
  ip_adapt_scale: float = 0,
43
  device: str = "cuda",
44
  ) -> DiffusionPipeline:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  download_kolors_weights(f"{base_ckpt_dir}/Kolors")
46
  logger.info(f"Load Kolors weights...")
47
  tokenizer = ChatGLMTokenizer.from_pretrained(
 
42
  ip_adapt_scale: float = 0,
43
  device: str = "cuda",
44
  ) -> DiffusionPipeline:
45
+ """Build and initialize the Kolors + ControlNet (optional IP-Adapter) texture generation pipeline.
46
+
47
+ Loads Kolors tokenizer, text encoder (ChatGLM), VAE, UNet, scheduler and (optionally)
48
+ a ControlNet checkpoint plus IP-Adapter vision encoder. If ``controlnet_ckpt`` is
49
+ not provided, the default multi-view texture ControlNet weights are downloaded
50
+ automatically from the hub. When ``ip_adapt_scale > 0`` an IP-Adapter vision
51
+ encoder and its weights are also loaded and activated.
52
+
53
+ Args:
54
+ base_ckpt_dir (str):
55
+ Root directory where Kolors (and optionally Kolors-IP-Adapter-Plus) weights
56
+ are or will be stored. Required subfolders: ``Kolors/{text_encoder,vae,unet,scheduler}``.
57
+ controlnet_ckpt (str, optional):
58
+ Directory containing a ControlNet checkpoint (safetensors). If ``None``,
59
+ downloads the default ``texture_gen_mv_v1`` snapshot.
60
+ ip_adapt_scale (float, optional):
61
+ Strength (>=0) of IP-Adapter conditioning. Set >0 to enable IP-Adapter;
62
+ typical values: 0.4-0.8. Default: 0 (disabled).
63
+ device (str, optional):
64
+ Target device to move the pipeline to (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``).
65
+ Default: ``"cuda"``.
66
+
67
+ Returns:
68
+ DiffusionPipeline: A configured
69
+ ``StableDiffusionXLControlNetImg2ImgPipeline`` ready for multi-view texture
70
+ generation (with optional IP-Adapter support).
71
+
72
+ Example:
73
+ Initialize pipeline with IP-Adapter enabled.
74
+ ```python
75
+ from embodied_gen.models.texture_model import build_texture_gen_pipe
76
+ ip_adapt_scale = 0.7
77
+ PIPELINE = build_texture_gen_pipe(
78
+ base_ckpt_dir="./weights",
79
+ ip_adapt_scale=ip_adapt_scale,
80
+ device="cuda",
81
+ )
82
+ PIPELINE.set_ip_adapter_scale([ip_adapt_scale])
83
+ ```
84
+ Initialize pipeline without IP-Adapter.
85
+ ```python
86
+ from embodied_gen.models.texture_model import build_texture_gen_pipe
87
+ PIPELINE = build_texture_gen_pipe(
88
+ base_ckpt_dir="./weights",
89
+ ip_adapt_scale=0,
90
+ device="cuda",
91
+ )
92
+ ```
93
+ """
94
+
95
  download_kolors_weights(f"{base_ckpt_dir}/Kolors")
96
  logger.info(f"Load Kolors weights...")
97
  tokenizer = ChatGLMTokenizer.from_pretrained(
embodied_gen/scripts/render_gs.py CHANGED
@@ -29,7 +29,7 @@ from embodied_gen.data.utils import (
29
  init_kal_camera,
30
  normalize_vertices_array,
31
  )
32
- from embodied_gen.models.gs_model import GaussianOperator
33
  from embodied_gen.utils.process_media import combine_images_to_grid
34
 
35
  logging.basicConfig(
@@ -97,21 +97,6 @@ def parse_args():
97
  return args
98
 
99
 
100
- def load_gs_model(
101
- input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
102
- ) -> GaussianOperator:
103
- gs_model = GaussianOperator.load_from_ply(input_gs)
104
- # Normalize vertices to [-1, 1], center to (0, 0, 0).
105
- _, scale, center = normalize_vertices_array(gs_model._means)
106
- scale, center = float(scale), center.tolist()
107
- transpose = [*[v for v in center], *pre_quat]
108
- instance_pose = torch.tensor(transpose).to(gs_model.device)
109
- gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
110
- gs_model.rescale(scale)
111
-
112
- return gs_model
113
-
114
-
115
  @spaces.GPU
116
  def entrypoint(**kwargs) -> None:
117
  args = parse_args()
 
29
  init_kal_camera,
30
  normalize_vertices_array,
31
  )
32
+ from embodied_gen.models.gs_model import load_gs_model
33
  from embodied_gen.utils.process_media import combine_images_to_grid
34
 
35
  logging.basicConfig(
 
97
  return args
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @spaces.GPU
101
  def entrypoint(**kwargs) -> None:
102
  args = parse_args()
embodied_gen/trainer/pono2mesh_trainer.py CHANGED
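Among the docstrings added below, `read_camera_pose_file` is documented as flattening whitespace-separated floats into a 4x4 pose matrix; a standalone re-implementation sketch of that parsing step (the reshape to 4x4 is an assumption based on the documented return type, not a copy of the method):
```py
import numpy as np


def read_camera_pose_file_sketch(filepath: str) -> np.ndarray:
    # Read all whitespace-separated floats, row-major, and reshape to 4x4.
    with open(filepath, "r") as f:
        values = [float(num) for line in f for num in line.split()]
    return np.array(values, dtype=np.float64).reshape(4, 4)
```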
@@ -53,26 +53,31 @@ from thirdparty.pano2room.utils.functions import (
53
 
54
 
55
  class Pano2MeshSRPipeline:
56
- """Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
57
 
58
- This class integrates several key components including:
59
- - Depth estimation from RGB panorama
60
- - Inpainting of missing regions under offsets
61
- - RGB-D to mesh conversion
62
- - Multi-view mesh repair
63
- - 3D Gaussian Splatting (3DGS) dataset generation
64
 
65
  Args:
66
  config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
67
 
68
  Example:
69
- ```python
 
 
 
 
70
  pipeline = Pano2MeshSRPipeline(config)
71
  pipeline(pano_image='example.png', output_dir='./output')
72
  ```
73
  """
74
 
75
  def __init__(self, config: Pano2MeshSRConfig) -> None:
 
 
 
 
 
76
  self.cfg = config
77
  self.device = config.device
78
 
@@ -93,6 +98,7 @@ class Pano2MeshSRPipeline:
93
  self.kernel = torch.from_numpy(kernel).float().to(self.device)
94
 
95
  def init_mesh_params(self) -> None:
 
96
  torch.set_default_device(self.device)
97
  self.inpaint_mask = torch.ones(
98
  (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
@@ -103,6 +109,14 @@ class Pano2MeshSRPipeline:
103
 
104
  @staticmethod
105
  def read_camera_pose_file(filepath: str) -> np.ndarray:
 
 
 
 
 
 
 
 
106
  with open(filepath, "r") as f:
107
  values = [float(num) for line in f for num in line.split()]
108
 
@@ -111,6 +125,14 @@ class Pano2MeshSRPipeline:
111
  def load_camera_poses(
112
  self, trajectory_dir: str
113
  ) -> tuple[np.ndarray, list[torch.Tensor]]:
 
 
 
 
 
 
 
 
114
  pose_filenames = sorted(
115
  [
116
  fname
@@ -148,6 +170,14 @@ class Pano2MeshSRPipeline:
148
  def load_inpaint_poses(
149
  self, poses: torch.Tensor
150
  ) -> dict[int, torch.Tensor]:
 
 
 
 
 
 
 
 
151
  inpaint_poses = dict()
152
  sampled_views = poses[:: self.cfg.inpaint_frame_stride]
153
  init_pose = torch.eye(4)
@@ -162,6 +192,14 @@ class Pano2MeshSRPipeline:
162
  return inpaint_poses
163
 
164
  def project(self, world_to_cam: torch.Tensor):
 
 
 
 
 
 
 
 
165
  (
166
  project_image,
167
  project_depth,
@@ -185,6 +223,14 @@ class Pano2MeshSRPipeline:
185
  return project_image[:3, ...], inpaint_mask, project_depth
186
 
187
  def render_pano(self, pose: torch.Tensor):
 
 
 
 
 
 
 
 
188
  cubemap_list = []
189
  for cubemap_pose in self.cubemap_w2cs:
190
  project_pose = cubemap_pose @ pose
@@ -213,6 +259,15 @@ class Pano2MeshSRPipeline:
213
  world_to_cam: torch.Tensor = None,
214
  using_distance_map: bool = True,
215
  ) -> None:
 
 
 
 
 
 
 
 
 
216
  if world_to_cam is None:
217
  world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
218
 
@@ -239,6 +294,15 @@ class Pano2MeshSRPipeline:
239
  def get_edge_image_by_depth(
240
  self, depth: torch.Tensor, dilate_iter: int = 1
241
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
242
  if isinstance(depth, torch.Tensor):
243
  depth = depth.cpu().detach().numpy()
244
 
@@ -253,6 +317,15 @@ class Pano2MeshSRPipeline:
253
  def mesh_repair_by_greedy_view_selection(
254
  self, pose_dict: dict[str, torch.Tensor], output_dir: str
255
  ) -> list:
 
 
 
 
 
 
 
 
 
256
  inpainted_panos_w_pose = []
257
  while len(pose_dict) > 0:
258
  logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
@@ -343,6 +416,17 @@ class Pano2MeshSRPipeline:
343
  distances: torch.Tensor,
344
  pano_mask: torch.Tensor,
345
  ) -> tuple[torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
346
  mask = (pano_mask[None, ..., None] > 0.5).float()
347
  mask = mask.permute(0, 3, 1, 2)
348
  mask = dilation(mask, kernel=self.kernel)
@@ -364,6 +448,14 @@ class Pano2MeshSRPipeline:
364
  def preprocess_pano(
365
  self, image: Image.Image | str
366
  ) -> tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
367
  if isinstance(image, str):
368
  image = Image.open(image)
369
 
@@ -387,6 +479,17 @@ class Pano2MeshSRPipeline:
387
  def pano_to_perpective(
388
  self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
389
  ) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
390
  rots = dict(
391
  roll=0,
392
  pitch=pitch,
@@ -404,6 +507,14 @@ class Pano2MeshSRPipeline:
404
  return perspective
405
 
406
  def pano_to_cubemap(self, pano_rgb: torch.Tensor):
 
 
 
 
 
 
 
 
407
  # Define six canonical cube directions in (pitch, yaw)
408
  directions = [
409
  (0, 0),
@@ -424,6 +535,11 @@ class Pano2MeshSRPipeline:
424
  return cubemaps_rgb
425
 
426
  def save_mesh(self, output_path: str) -> None:
 
 
 
 
 
427
  vertices_np = self.vertices.T.cpu().numpy()
428
  colors_np = self.colors.T.cpu().numpy()
429
  faces_np = self.faces.T.cpu().numpy()
@@ -434,6 +550,14 @@ class Pano2MeshSRPipeline:
434
  mesh.export(output_path)
435
 
436
  def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
 
 
 
 
 
 
 
 
437
  pose = mesh_pose.clone()
438
  pose[0, :] *= -1
439
  pose[1, :] *= -1
@@ -450,6 +574,15 @@ class Pano2MeshSRPipeline:
450
  return c2w
451
 
452
  def __call__(self, pano_image: Image.Image | str, output_dir: str):
 
 
 
 
 
 
 
 
 
453
  self.init_mesh_params()
454
  pano_rgb, pano_depth = self.preprocess_pano(pano_image)
455
  self.sup_pool = SupInfoPool()
 
53
 
54
 
55
  class Pano2MeshSRPipeline:
56
+ """Pipeline for converting panoramic RGB images into 3D mesh representations.
57
 
58
+ This class integrates depth estimation, inpainting, mesh conversion, multi-view mesh repair,
59
+ and 3D Gaussian Splatting (3DGS) dataset generation.
 
 
 
 
60
 
61
  Args:
62
  config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
63
 
64
  Example:
65
+ ```py
66
+ from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
67
+ from embodied_gen.utils.config import Pano2MeshSRConfig
68
+
69
+ config = Pano2MeshSRConfig()
70
  pipeline = Pano2MeshSRPipeline(config)
71
  pipeline(pano_image='example.png', output_dir='./output')
72
  ```
73
  """
74
 
75
  def __init__(self, config: Pano2MeshSRConfig) -> None:
76
+ """Initializes the pipeline with models and camera poses.
77
+
78
+ Args:
79
+ config (Pano2MeshSRConfig): Configuration object.
80
+ """
81
  self.cfg = config
82
  self.device = config.device
83
 
 
98
  self.kernel = torch.from_numpy(kernel).float().to(self.device)
99
 
100
  def init_mesh_params(self) -> None:
101
+ """Initializes mesh parameters and inpaint mask."""
102
  torch.set_default_device(self.device)
103
  self.inpaint_mask = torch.ones(
104
  (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
 
109
 
110
  @staticmethod
111
  def read_camera_pose_file(filepath: str) -> np.ndarray:
112
+ """Reads a camera pose file and returns the pose matrix.
113
+
114
+ Args:
115
+ filepath (str): Path to the camera pose file.
116
+
117
+ Returns:
118
+ np.ndarray: 4x4 camera pose matrix.
119
+ """
120
  with open(filepath, "r") as f:
121
  values = [float(num) for line in f for num in line.split()]
122
 
 
125
  def load_camera_poses(
126
  self, trajectory_dir: str
127
  ) -> tuple[np.ndarray, list[torch.Tensor]]:
128
+ """Loads camera poses from a directory.
129
+
130
+ Args:
131
+ trajectory_dir (str): Directory containing camera pose files.
132
+
133
+ Returns:
134
+ tuple[np.ndarray, list[torch.Tensor]]: Camera poses array and the list of relative camera pose tensors.
135
+ """
136
  pose_filenames = sorted(
137
  [
138
  fname
 
170
  def load_inpaint_poses(
171
  self, poses: torch.Tensor
172
  ) -> dict[int, torch.Tensor]:
173
+ """Samples and loads poses for inpainting.
174
+
175
+ Args:
176
+ poses (torch.Tensor): Tensor of camera poses.
177
+
178
+ Returns:
179
+ dict[int, torch.Tensor]: Dictionary mapping indices to pose tensors.
180
+ """
181
  inpaint_poses = dict()
182
  sampled_views = poses[:: self.cfg.inpaint_frame_stride]
183
  init_pose = torch.eye(4)
 
192
  return inpaint_poses
193
 
194
  def project(self, world_to_cam: torch.Tensor):
195
+ """Projects the mesh to an image using the given camera pose.
196
+
197
+ Args:
198
+ world_to_cam (torch.Tensor): World-to-camera transformation matrix.
199
+
200
+ Returns:
201
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected RGB image, inpaint mask, and depth map.
202
+ """
203
  (
204
  project_image,
205
  project_depth,
 
223
  return project_image[:3, ...], inpaint_mask, project_depth
224
 
225
  def render_pano(self, pose: torch.Tensor):
226
+ """Renders a panorama from the mesh using the given pose.
227
+
228
+ Args:
229
+ pose (torch.Tensor): Camera pose.
230
+
231
+ Returns:
232
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: RGB panorama, depth map, and mask.
233
+ """
234
  cubemap_list = []
235
  for cubemap_pose in self.cubemap_w2cs:
236
  project_pose = cubemap_pose @ pose
 
259
  world_to_cam: torch.Tensor = None,
260
  using_distance_map: bool = True,
261
  ) -> None:
262
+ """Converts RGB-D images to mesh and updates mesh parameters.
263
+
264
+ Args:
265
+ rgb (torch.Tensor): RGB image tensor.
266
+ depth (torch.Tensor): Depth map tensor.
267
+ inpaint_mask (torch.Tensor): Inpaint mask tensor.
268
+ world_to_cam (torch.Tensor, optional): Camera pose.
269
+ using_distance_map (bool, optional): Whether to use distance map.
270
+ """
271
  if world_to_cam is None:
272
  world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
273
 
 
294
  def get_edge_image_by_depth(
295
  self, depth: torch.Tensor, dilate_iter: int = 1
296
  ) -> np.ndarray:
297
+ """Computes edge image from depth map.
298
+
299
+ Args:
300
+ depth (torch.Tensor): Depth map tensor.
301
+ dilate_iter (int, optional): Number of dilation iterations.
302
+
303
+ Returns:
304
+ np.ndarray: Edge image.
305
+ """
306
  if isinstance(depth, torch.Tensor):
307
  depth = depth.cpu().detach().numpy()
308
 
 
317
  def mesh_repair_by_greedy_view_selection(
318
  self, pose_dict: dict[str, torch.Tensor], output_dir: str
319
  ) -> list:
320
+ """Repairs mesh by selecting views greedily and inpainting missing regions.
321
+
322
+ Args:
323
+ pose_dict (dict[str, torch.Tensor]): Dictionary of poses for inpainting.
324
+ output_dir (str): Directory to save visualizations.
325
+
326
+ Returns:
327
+ list: List of inpainted panoramas with poses.
328
+ """
329
  inpainted_panos_w_pose = []
330
  while len(pose_dict) > 0:
331
  logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
 
416
  distances: torch.Tensor,
417
  pano_mask: torch.Tensor,
418
  ) -> tuple[torch.Tensor]:
419
+ """Inpaints missing regions in a panorama.
420
+
421
+ Args:
422
+ idx (int): Index of the panorama.
423
+ colors (torch.Tensor): RGB image tensor.
424
+ distances (torch.Tensor): Distance map tensor.
425
+ pano_mask (torch.Tensor): Mask tensor.
426
+
427
+ Returns:
428
+ tuple[torch.Tensor]: Inpainted RGB image, distances, and normals.
429
+ """
430
  mask = (pano_mask[None, ..., None] > 0.5).float()
431
  mask = mask.permute(0, 3, 1, 2)
432
  mask = dilation(mask, kernel=self.kernel)
 
448
  def preprocess_pano(
449
  self, image: Image.Image | str
450
  ) -> tuple[torch.Tensor, torch.Tensor]:
451
+ """Preprocesses a panoramic image for mesh generation.
452
+
453
+ Args:
454
+ image (Image.Image | str): Input image or path.
455
+
456
+ Returns:
457
+ tuple[torch.Tensor, torch.Tensor]: Preprocessed RGB and depth tensors.
458
+ """
459
  if isinstance(image, str):
460
  image = Image.open(image)
461
 
 
479
  def pano_to_perpective(
480
  self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
481
  ) -> torch.Tensor:
482
+ """Converts a panoramic image to a perspective view.
483
+
484
+ Args:
485
+ pano_image (torch.Tensor): Panoramic image tensor.
486
+ pitch (float): Pitch angle.
487
+ yaw (float): Yaw angle.
488
+ fov (float): Field of view.
489
+
490
+ Returns:
491
+ torch.Tensor: Perspective image tensor.
492
+ """
493
  rots = dict(
494
  roll=0,
495
  pitch=pitch,
 
507
  return perspective
508
 
509
  def pano_to_cubemap(self, pano_rgb: torch.Tensor):
510
+ """Converts a panoramic RGB image to six cubemap views.
511
+
512
+ Args:
513
+ pano_rgb (torch.Tensor): Panoramic RGB image tensor.
514
+
515
+ Returns:
516
+ list: List of cubemap RGB tensors.
517
+ """
518
  # Define six canonical cube directions in (pitch, yaw)
519
  directions = [
520
  (0, 0),
 
535
  return cubemaps_rgb
536
 
537
  def save_mesh(self, output_path: str) -> None:
538
+ """Saves the mesh to a file.
539
+
540
+ Args:
541
+ output_path (str): Path to save the mesh file.
542
+ """
543
  vertices_np = self.vertices.T.cpu().numpy()
544
  colors_np = self.colors.T.cpu().numpy()
545
  faces_np = self.faces.T.cpu().numpy()
 
550
  mesh.export(output_path)
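For reference, a minimal sketch of the mesh export that `save_mesh` describes, using `trimesh` directly; the array values and the output path below are illustrative stand-ins, not taken from the repository.

```py
import numpy as np
import trimesh

# Stand-ins for self.vertices.T / self.colors.T / self.faces.T as numpy arrays.
vertices_np = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
colors_np = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
faces_np = np.array([[0, 1, 2]], dtype=np.int64)

# trimesh accepts per-vertex colors and infers the export format from the suffix.
mesh = trimesh.Trimesh(
    vertices=vertices_np, faces=faces_np, vertex_colors=colors_np, process=False
)
mesh.export("example_mesh.ply")
```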
551
 
552
  def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
553
+ """Converts mesh pose to 3D Gaussian Splatting pose.
554
+
555
+ Args:
556
+ mesh_pose (torch.Tensor): Mesh pose tensor.
557
+
558
+ Returns:
559
+ np.ndarray: Converted pose matrix.
560
+ """
561
  pose = mesh_pose.clone()
562
  pose[0, :] *= -1
563
  pose[1, :] *= -1
 
574
  return c2w
575
 
576
  def __call__(self, pano_image: Image.Image | str, output_dir: str):
577
+ """Runs the pipeline to generate mesh and 3DGS data from a panoramic image.
578
+
579
+ Args:
580
+ pano_image (Image.Image | str): Input panoramic image or path.
581
+ output_dir (str): Directory to save outputs.
582
+
583
+ Returns:
584
+ None
585
+ """
586
  self.init_mesh_params()
587
  pano_rgb, pano_depth = self.preprocess_pano(pano_image)
588
  self.sup_pool = SupInfoPool()
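A hedged usage sketch of the staged methods documented above; the module path follows the class docstring example, and calling the stages individually outside `__call__` is an assumption about the API, not something shown in this diff.

```py
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
from embodied_gen.utils.config import Pano2MeshSRConfig

config = Pano2MeshSRConfig()
pipeline = Pano2MeshSRPipeline(config)

# Full run: panorama in, mesh + 3DGS dataset out (see __call__ above).
pipeline(pano_image="example_pano.png", output_dir="./pano_mesh_out")

# The documented intermediate stages, e.g. preprocessing a panorama and
# splitting it into six cubemap views.
pano_rgb, pano_depth = pipeline.preprocess_pano("example_pano.png")
cubemaps = pipeline.pano_to_cubemap(pano_rgb)
```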
embodied_gen/utils/enum.py CHANGED
@@ -24,11 +24,27 @@ __all__ = [
24
  "Scene3DItemEnum",
25
  "SpatialRelationEnum",
26
  "RobotItemEnum",
 
 
 
27
  ]
28
 
29
 
30
  @dataclass
31
  class RenderItems(str, Enum):
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  IMAGE = "image_color"
33
  ALPHA = "image_mask"
34
  VIEW_NORMAL = "image_view_normal"
@@ -41,6 +57,21 @@ class RenderItems(str, Enum):
41
 
42
  @dataclass
43
  class Scene3DItemEnum(str, Enum):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  BACKGROUND = "background"
45
  CONTEXT = "context"
46
  ROBOT = "robot"
@@ -50,6 +81,14 @@ class Scene3DItemEnum(str, Enum):
50
 
51
  @classmethod
52
  def object_list(cls, layout_relation: dict) -> list:
 
 
 
 
 
 
 
 
53
  return (
54
  [
55
  layout_relation[cls.BACKGROUND.value],
@@ -61,6 +100,14 @@ class Scene3DItemEnum(str, Enum):
61
 
62
  @classmethod
63
  def object_mapping(cls, layout_relation):
 
 
 
 
 
 
 
 
64
  relation_mapping = {
65
  # layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
66
  layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
@@ -84,6 +131,15 @@ class Scene3DItemEnum(str, Enum):
84
 
85
  @dataclass
86
  class SpatialRelationEnum(str, Enum):
 
 
 
 
 
 
 
 
 
87
  ON = "ON" # objects on the table
88
  IN = "IN" # objects in the room
89
  INSIDE = "INSIDE" # objects inside the shelf/rack
@@ -92,6 +148,14 @@ class SpatialRelationEnum(str, Enum):
92
 
93
  @dataclass
94
  class RobotItemEnum(str, Enum):
 
 
 
 
 
 
 
 
95
  FRANKA = "franka"
96
  UR5 = "ur5"
97
  PIPER = "piper"
@@ -99,6 +163,18 @@ class RobotItemEnum(str, Enum):
99
 
100
  @dataclass
101
  class LayoutInfo(DataClassJsonMixin):
 
 
 
 
 
 
 
 
 
 
 
 
102
  tree: dict[str, list]
103
  relation: dict[str, str | list[str]]
104
  objs_desc: dict[str, str] = field(default_factory=dict)
@@ -106,3 +182,64 @@ class LayoutInfo(DataClassJsonMixin):
106
  assets: dict[str, str] = field(default_factory=dict)
107
  quality: dict[str, str] = field(default_factory=dict)
108
  position: dict[str, list[float]] = field(default_factory=dict)
24
  "Scene3DItemEnum",
25
  "SpatialRelationEnum",
26
  "RobotItemEnum",
27
+ "LayoutInfo",
28
+ "AssetType",
29
+ "SimAssetMapper",
30
  ]
31
 
32
 
33
  @dataclass
34
  class RenderItems(str, Enum):
35
+ """Enumeration of render item types for 3D scenes.
36
+
37
+ Attributes:
38
+ IMAGE: Color image.
39
+ ALPHA: Mask image.
40
+ VIEW_NORMAL: View-space normal image.
41
+ GLOBAL_NORMAL: World-space normal image.
42
+ POSITION_MAP: Position map image.
43
+ DEPTH: Depth image.
44
+ ALBEDO: Albedo image.
45
+ DIFFUSE: Diffuse image.
46
+ """
47
+
48
  IMAGE = "image_color"
49
  ALPHA = "image_mask"
50
  VIEW_NORMAL = "image_view_normal"
 
57
 
58
  @dataclass
59
  class Scene3DItemEnum(str, Enum):
60
+ """Enumeration of 3D scene item categories.
61
+
62
+ Attributes:
63
+ BACKGROUND: Background objects.
64
+ CONTEXT: Contextual objects.
65
+ ROBOT: Robot entity.
66
+ MANIPULATED_OBJS: Objects manipulated by the robot.
67
+ DISTRACTOR_OBJS: Distractor objects.
68
+ OTHERS: Other objects.
69
+
70
+ Methods:
71
+ object_list(layout_relation): Returns a list of objects in the scene.
72
+ object_mapping(layout_relation): Returns a mapping from object to category.
73
+ """
74
+
75
  BACKGROUND = "background"
76
  CONTEXT = "context"
77
  ROBOT = "robot"
 
81
 
82
  @classmethod
83
  def object_list(cls, layout_relation: dict) -> list:
84
+ """Returns a list of objects in the scene.
85
+
86
+ Args:
87
+ layout_relation: Dictionary mapping categories to objects.
88
+
89
+ Returns:
90
+ List of objects in the scene.
91
+ """
92
  return (
93
  [
94
  layout_relation[cls.BACKGROUND.value],
 
100
 
101
  @classmethod
102
  def object_mapping(cls, layout_relation):
103
+ """Returns a mapping from object to category.
104
+
105
+ Args:
106
+ layout_relation: Dictionary mapping categories to objects.
107
+
108
+ Returns:
109
+ Dictionary mapping object names to their category.
110
+ """
111
  relation_mapping = {
112
  # layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
113
  layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
 
131
 
132
  @dataclass
133
  class SpatialRelationEnum(str, Enum):
134
+ """Enumeration of spatial relations for objects in a scene.
135
+
136
+ Attributes:
137
+ ON: Objects on a surface (e.g., table).
138
+ IN: Objects in a container or room.
139
+ INSIDE: Objects inside a shelf or rack.
140
+ FLOOR: Objects on the floor.
141
+ """
142
+
143
  ON = "ON" # objects on the table
144
  IN = "IN" # objects in the room
145
  INSIDE = "INSIDE" # objects inside the shelf/rack
 
148
 
149
  @dataclass
150
  class RobotItemEnum(str, Enum):
151
+ """Enumeration of supported robot types.
152
+
153
+ Attributes:
154
+ FRANKA: Franka robot.
155
+ UR5: UR5 robot.
156
+ PIPER: Piper robot.
157
+ """
158
+
159
  FRANKA = "franka"
160
  UR5 = "ur5"
161
  PIPER = "piper"
 
163
 
164
  @dataclass
165
  class LayoutInfo(DataClassJsonMixin):
166
+ """Data structure for layout information in a 3D scene.
167
+
168
+ Attributes:
169
+ tree: Hierarchical structure of scene objects.
170
+ relation: Spatial relations between objects.
171
+ objs_desc: Descriptions of objects.
172
+ objs_mapping: Mapping from object names to categories.
173
+ assets: Asset file paths for objects.
174
+ quality: Quality information for assets.
175
+ position: Position coordinates for objects.
176
+ """
177
+
178
  tree: dict[str, list]
179
  relation: dict[str, str | list[str]]
180
  objs_desc: dict[str, str] = field(default_factory=dict)
 
182
  assets: dict[str, str] = field(default_factory=dict)
183
  quality: dict[str, str] = field(default_factory=dict)
184
  position: dict[str, list[float]] = field(default_factory=dict)
185
+
186
+
187
+ @dataclass
188
+ class AssetType(str):
189
+ """Enumeration for asset types.
190
+
191
+ Supported types:
192
+ MJCF: MuJoCo XML format.
193
+ USD: Universal Scene Description format.
194
+ URDF: Unified Robot Description Format.
195
+ MESH: Mesh file format.
196
+ """
197
+
198
+ MJCF = "mjcf"
199
+ USD = "usd"
200
+ URDF = "urdf"
201
+ MESH = "mesh"
202
+
203
+
204
+ class SimAssetMapper:
205
+ """Maps simulator names to asset types.
206
+
207
+ Provides a mapping from simulator names to their corresponding asset type.
208
+
209
+ Example:
210
+ ```py
211
+ from embodied_gen.utils.enum import SimAssetMapper
212
+ asset_type = SimAssetMapper["isaacsim"]
213
+ print(asset_type) # Output: 'usd'
214
+ ```
215
+
216
+ Methods:
217
+ __class_getitem__(key): Returns the asset type for a given simulator name.
218
+ """
219
+
220
+ _mapping = dict(
221
+ ISAACSIM=AssetType.USD,
222
+ ISAACGYM=AssetType.URDF,
223
+ MUJOCO=AssetType.MJCF,
224
+ GENESIS=AssetType.MJCF,
225
+ SAPIEN=AssetType.URDF,
226
+ PYBULLET=AssetType.URDF,
227
+ )
228
+
229
+ @classmethod
230
+ def __class_getitem__(cls, key: str):
231
+ """Returns the asset type for a given simulator name.
232
+
233
+ Args:
234
+ key: Name of the simulator.
235
+
236
+ Returns:
237
+ AssetType corresponding to the simulator.
238
+
239
+ Raises:
240
+ KeyError: If the simulator name is not recognized.
241
+ """
242
+ key = key.upper()
243
+ if key.startswith("SAPIEN"):
244
+ key = "SAPIEN"
245
+ return cls._mapping[key]
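A standalone sketch of the class-subscription pattern `SimAssetMapper` uses (`__class_getitem__` over an internal mapping); the simulator-to-format pairs mirror the mapping shown above, while the class name below is only illustrative.

```py
class _DemoAssetMapper:
    """Minimal re-implementation of the SimAssetMapper lookup pattern."""

    _mapping = {
        "ISAACSIM": "usd",
        "ISAACGYM": "urdf",
        "MUJOCO": "mjcf",
        "GENESIS": "mjcf",
        "SAPIEN": "urdf",
        "PYBULLET": "urdf",
    }

    @classmethod
    def __class_getitem__(cls, key: str) -> str:
        key = key.upper()
        if key.startswith("SAPIEN"):  # e.g. "sapien3" also maps to SAPIEN
            key = "SAPIEN"
        return cls._mapping[key]


print(_DemoAssetMapper["isaacsim"])  # usd
print(_DemoAssetMapper["sapien3"])   # urdf
```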
embodied_gen/utils/geometry.py CHANGED
@@ -45,13 +45,13 @@ __all__ = [
45
 
46
 
47
  def matrix_to_pose(matrix: np.ndarray) -> list[float]:
48
- """Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
49
 
50
  Args:
51
  matrix (np.ndarray): 4x4 transformation matrix.
52
 
53
  Returns:
54
- List[float]: Pose as [x, y, z, qx, qy, qz, qw].
55
  """
56
  x, y, z = matrix[:3, 3]
57
  rot_mat = matrix[:3, :3]
@@ -62,13 +62,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]:
62
 
63
 
64
  def pose_to_matrix(pose: list[float]) -> np.ndarray:
65
- """Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
66
 
67
  Args:
68
- List[float]: Pose as [x, y, z, qx, qy, qz, qw].
69
 
70
  Returns:
71
- matrix (np.ndarray): 4x4 transformation matrix.
72
  """
73
  x, y, z, qx, qy, qz, qw = pose
74
  r = R.from_quat([qx, qy, qz, qw])
@@ -82,6 +82,16 @@ def pose_to_matrix(pose: list[float]) -> np.ndarray:
82
  def compute_xy_bbox(
83
  vertices: np.ndarray, col_x: int = 0, col_y: int = 1
84
  ) -> list[float]:
 
 
 
 
 
 
 
 
 
 
85
  x_vals = vertices[:, col_x]
86
  y_vals = vertices[:, col_y]
87
  return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
@@ -92,6 +102,16 @@ def has_iou_conflict(
92
  placed_boxes: list[list[float]],
93
  iou_threshold: float = 0.0,
94
  ) -> bool:
 
 
 
 
 
 
 
 
 
 
95
  new_min_x, new_max_x, new_min_y, new_max_y = new_box
96
  for min_x, max_x, min_y, max_y in placed_boxes:
97
  ix1 = max(new_min_x, min_x)
@@ -105,7 +125,14 @@ def has_iou_conflict(
105
 
106
 
107
  def with_seed(seed_attr_name: str = "seed"):
108
- """A parameterized decorator that temporarily sets the random seed."""
 
 
 
 
 
 
 
109
 
110
  def decorator(func):
111
  @wraps(func)
@@ -143,6 +170,20 @@ def compute_convex_hull_path(
143
  y_axis: int = 1,
144
  z_axis: int = 2,
145
  ) -> Path:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  top_vertices = vertices[
147
  vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
148
  ]
@@ -170,6 +211,15 @@ def compute_convex_hull_path(
170
 
171
 
172
  def find_parent_node(node: str, tree: dict) -> str | None:
 
 
 
 
 
 
 
 
 
173
  for parent, children in tree.items():
174
  if any(child[0] == node for child in children):
175
  return parent
@@ -177,6 +227,16 @@ def find_parent_node(node: str, tree: dict) -> str | None:
177
 
178
 
179
  def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
 
 
 
 
 
 
 
 
 
 
180
  x1, x2, y1, y2 = box
181
  corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
182
 
@@ -187,6 +247,15 @@ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
187
  def compute_axis_rotation_quat(
188
  axis: Literal["x", "y", "z"], angle_rad: float
189
  ) -> list[float]:
 
 
 
 
 
 
 
 
 
190
  if axis.lower() == "x":
191
  q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
192
  elif axis.lower() == "y":
@@ -202,6 +271,15 @@ def compute_axis_rotation_quat(
202
  def quaternion_multiply(
203
  init_quat: list[float], rotate_quat: list[float]
204
  ) -> list[float]:
 
 
 
 
 
 
 
 
 
205
  qx, qy, qz, qw = init_quat
206
  q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
207
  qx, qy, qz, qw = rotate_quat
@@ -217,7 +295,17 @@ def check_reachable(
217
  min_reach: float = 0.25,
218
  max_reach: float = 0.85,
219
  ) -> bool:
220
- """Check if the target point is within the reachable range."""
 
 
 
 
 
 
 
 
 
 
221
  distance = np.linalg.norm(reach_xyz - base_xyz)
222
 
223
  return min_reach < distance < max_reach
@@ -238,26 +326,31 @@ def bfs_placement(
238
  robot_dim: float = 0.12,
239
  seed: int = None,
240
  ) -> LayoutInfo:
241
- """Place objects in the layout using BFS traversal.
242
 
243
  Args:
244
- layout_file: Path to the JSON file defining the layout structure and assets.
245
- floor_margin: Z-offset for the background object, typically for objects placed on the floor.
246
- beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
247
- max_attempts: Maximum number of attempts to find a non-overlapping position for an object.
248
- init_rpy: Initial Roll-Pitch-Yaw rotation rad applied to all object meshes to align the mesh's
249
- coordinate system with the world's (e.g., Z-up).
250
- rotate_objs: If True, apply a random rotation around the Z-axis for manipulated and distractor objects.
251
- rotate_bg: If True, apply a random rotation around the Y-axis for the background object.
252
- rotate_context: If True, apply a random rotation around the Z-axis for the context object.
253
- limit_reach_range: If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
254
- max_orient_diff: If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
255
- robot_dim: The approximate dimension (e.g., diameter) of the robot for box representation.
256
- seed: Random seed for reproducible placement.
257
 
258
  Returns:
259
- A :class:`LayoutInfo` object containing the objects and their final computed 7D poses
260
- ([x, y, z, qx, qy, qz, qw]).
 
 
 
 
 
 
261
  """
262
  layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
263
  asset_dir = os.path.dirname(layout_file)
@@ -478,6 +571,13 @@ def bfs_placement(
478
  def compose_mesh_scene(
479
  layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
480
  ) -> None:
 
 
 
 
 
 
 
481
  object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
482
  scene = trimesh.Scene()
483
  for node in layout_info.assets:
@@ -505,6 +605,16 @@ def compose_mesh_scene(
505
  def compute_pinhole_intrinsics(
506
  image_w: int, image_h: int, fov_deg: float
507
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
508
  fov_rad = np.deg2rad(fov_deg)
509
  fx = image_w / (2 * np.tan(fov_rad / 2))
510
  fy = fx # assuming square pixels
 
45
 
46
 
47
  def matrix_to_pose(matrix: np.ndarray) -> list[float]:
48
+ """Converts a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
49
 
50
  Args:
51
  matrix (np.ndarray): 4x4 transformation matrix.
52
 
53
  Returns:
54
+ list[float]: Pose as [x, y, z, qx, qy, qz, qw].
55
  """
56
  x, y, z = matrix[:3, 3]
57
  rot_mat = matrix[:3, :3]
 
62
 
63
 
64
  def pose_to_matrix(pose: list[float]) -> np.ndarray:
65
+ """Converts pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
66
 
67
  Args:
68
+ pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw].
69
 
70
  Returns:
71
+ np.ndarray: 4x4 transformation matrix.
72
  """
73
  x, y, z, qx, qy, qz, qw = pose
74
  r = R.from_quat([qx, qy, qz, qw])
 
82
  def compute_xy_bbox(
83
  vertices: np.ndarray, col_x: int = 0, col_y: int = 1
84
  ) -> list[float]:
85
+ """Computes the bounding box in XY plane for given vertices.
86
+
87
+ Args:
88
+ vertices (np.ndarray): Vertex coordinates.
89
+ col_x (int, optional): Column index for X.
90
+ col_y (int, optional): Column index for Y.
91
+
92
+ Returns:
93
+ list[float]: [min_x, max_x, min_y, max_y]
94
+ """
95
  x_vals = vertices[:, col_x]
96
  y_vals = vertices[:, col_y]
97
  return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
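A small numpy sketch of how `compute_xy_bbox` and `has_iou_conflict` fit together when testing candidate placements; the overlap helper below is illustrative, not the library implementation.

```py
import numpy as np

def compute_xy_bbox(vertices: np.ndarray, col_x: int = 0, col_y: int = 1):
    # Footprint of the vertices in the XY plane: [min_x, max_x, min_y, max_y].
    x_vals, y_vals = vertices[:, col_x], vertices[:, col_y]
    return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()

def boxes_overlap(box_a, box_b) -> bool:
    # Axis-aligned overlap test on [min_x, max_x, min_y, max_y] boxes.
    ax1, ax2, ay1, ay2 = box_a
    bx1, bx2, by1, by2 = box_b
    return not (ax2 <= bx1 or bx2 <= ax1 or ay2 <= by1 or by2 <= ay1)

placed = compute_xy_bbox(np.random.rand(100, 3))  # footprint of an existing object
candidate = (0.2, 0.4, 0.2, 0.4)                  # proposed placement box
print(boxes_overlap(placed, candidate))
```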
 
102
  placed_boxes: list[list[float]],
103
  iou_threshold: float = 0.0,
104
  ) -> bool:
105
+ """Checks for intersection-over-union conflict between boxes.
106
+
107
+ Args:
108
+ new_box (list[float]): New box coordinates.
109
+ placed_boxes (list[list[float]]): List of placed box coordinates.
110
+ iou_threshold (float, optional): IOU threshold.
111
+
112
+ Returns:
113
+ bool: True if conflict exists, False otherwise.
114
+ """
115
  new_min_x, new_max_x, new_min_y, new_max_y = new_box
116
  for min_x, max_x, min_y, max_y in placed_boxes:
117
  ix1 = max(new_min_x, min_x)
 
125
 
126
 
127
  def with_seed(seed_attr_name: str = "seed"):
128
+ """Decorator to temporarily set the random seed for reproducibility.
129
+
130
+ Args:
131
+ seed_attr_name (str, optional): Name of the seed argument.
132
+
133
+ Returns:
134
+ function: Decorator function.
135
+ """
136
 
137
  def decorator(func):
138
  @wraps(func)
 
170
  y_axis: int = 1,
171
  z_axis: int = 2,
172
  ) -> Path:
173
+ """Computes a dense convex hull path for the top surface of a mesh.
174
+
175
+ Args:
176
+ vertices (np.ndarray): Mesh vertices.
177
+ z_threshold (float, optional): Z threshold for top surface.
178
+ interp_per_edge (int, optional): Interpolation points per edge.
179
+ margin (float, optional): Margin for polygon buffer.
180
+ x_axis (int, optional): X axis index.
181
+ y_axis (int, optional): Y axis index.
182
+ z_axis (int, optional): Z axis index.
183
+
184
+ Returns:
185
+ Path: Matplotlib path object for the convex hull.
186
+ """
187
  top_vertices = vertices[
188
  vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
189
  ]
 
211
 
212
 
213
  def find_parent_node(node: str, tree: dict) -> str | None:
214
+ """Finds the parent node of a given node in a tree.
215
+
216
+ Args:
217
+ node (str): Node name.
218
+ tree (dict): Tree structure.
219
+
220
+ Returns:
221
+ str | None: Parent node name or None.
222
+ """
223
  for parent, children in tree.items():
224
  if any(child[0] == node for child in children):
225
  return parent
 
227
 
228
 
229
  def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
230
+ """Checks if at least `threshold` corners of a box are inside a hull.
231
+
232
+ Args:
233
+ hull (Path): Convex hull path.
234
+ box (list): Box coordinates [x1, x2, y1, y2].
235
+ threshold (int, optional): Minimum corners inside.
236
+
237
+ Returns:
238
+ bool: True if enough corners are inside.
239
+ """
240
  x1, x2, y1, y2 = box
241
  corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
242
 
 
247
  def compute_axis_rotation_quat(
248
  axis: Literal["x", "y", "z"], angle_rad: float
249
  ) -> list[float]:
250
+ """Computes quaternion for rotation around a given axis.
251
+
252
+ Args:
253
+ axis (Literal["x", "y", "z"]): Axis of rotation.
254
+ angle_rad (float): Rotation angle in radians.
255
+
256
+ Returns:
257
+ list[float]: Quaternion [x, y, z, w].
258
+ """
259
  if axis.lower() == "x":
260
  q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
261
  elif axis.lower() == "y":
 
271
  def quaternion_multiply(
272
  init_quat: list[float], rotate_quat: list[float]
273
  ) -> list[float]:
274
+ """Multiplies two quaternions.
275
+
276
+ Args:
277
+ init_quat (list[float]): Initial quaternion [x, y, z, w].
278
+ rotate_quat (list[float]): Rotation quaternion [x, y, z, w].
279
+
280
+ Returns:
281
+ list[float]: Resulting quaternion [x, y, z, w].
282
+ """
283
  qx, qy, qz, qw = init_quat
284
  q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
285
  qx, qy, qz, qw = rotate_quat
 
295
  min_reach: float = 0.25,
296
  max_reach: float = 0.85,
297
  ) -> bool:
298
+ """Checks if the target point is within the reachable range.
299
+
300
+ Args:
301
+ base_xyz (np.ndarray): Base position.
302
+ reach_xyz (np.ndarray): Target position.
303
+ min_reach (float, optional): Minimum reach distance.
304
+ max_reach (float, optional): Maximum reach distance.
305
+
306
+ Returns:
307
+ bool: True if reachable, False otherwise.
308
+ """
309
  distance = np.linalg.norm(reach_xyz - base_xyz)
310
 
311
  return min_reach < distance < max_reach
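A sketch of the quaternion conventions used by `compute_axis_rotation_quat` and `quaternion_multiply` above: `pyquaternion` stores elements as [w, x, y, z] internally, while these helpers document [x, y, z, w]. The composition order below is one common convention and is not copied from the omitted function bodies.

```py
import numpy as np
from pyquaternion import Quaternion

def axis_rotation_xyzw(axis_vec, angle_rad):
    # Quaternion for a rotation about axis_vec, returned as [x, y, z, w].
    q = Quaternion(axis=axis_vec, angle=angle_rad)
    w, x, y, z = q.elements
    return [x, y, z, w]

init_xyzw = [0.0, 0.0, 0.0, 1.0]                      # identity orientation
rot_xyzw = axis_rotation_xyzw([0, 0, 1], np.pi / 2)   # 90 deg about Z

# Compose: apply the new rotation on top of the initial orientation.
qx, qy, qz, qw = init_xyzw
rx, ry, rz, rw = rot_xyzw
composed = Quaternion(w=rw, x=rx, y=ry, z=rz) * Quaternion(w=qw, x=qx, y=qy, z=qz)
print([*composed.elements[1:], composed.elements[0]])  # back to [x, y, z, w]
```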
 
326
  robot_dim: float = 0.12,
327
  seed: int = None,
328
  ) -> LayoutInfo:
329
+ """Places objects in a scene layout using BFS traversal.
330
 
331
  Args:
332
+ layout_file (str): Path to layout JSON file generated from `layout-cli`.
333
+ floor_margin (float, optional): Z-offset for objects placed on the floor.
334
+ beside_margin (float, optional): Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
335
+ max_attempts (int, optional): Max attempts for a non-overlapping placement.
336
+ init_rpy (tuple, optional): Initial roll-pitch-yaw rotation in radians applied to all object meshes to align them with the world frame (e.g., Z-up).
337
+ rotate_objs (bool, optional): Whether to apply a random Z-axis rotation to manipulated and distractor objects.
338
+ rotate_bg (bool, optional): Whether to apply a random Y-axis rotation to the background object.
339
+ rotate_context (bool, optional): Whether to apply a random Z-axis rotation to the context object.
340
+ limit_reach_range (tuple[float, float] | None, optional): If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
341
+ max_orient_diff (float | None, optional): If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
342
+ robot_dim (float, optional): Approximate robot dimension (e.g., diameter) used for its box representation.
343
+ seed (int, optional): Random seed for reproducible placement.
 
344
 
345
  Returns:
346
+ LayoutInfo: Layout information with final 7D object poses ([x, y, z, qx, qy, qz, qw]).
347
+
348
+ Example:
349
+ ```py
350
+ from embodied_gen.utils.geometry import bfs_placement
351
+ layout = bfs_placement("scene_layout.json", seed=42)
352
+ print(layout.position)
353
+ ```
354
  """
355
  layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
356
  asset_dir = os.path.dirname(layout_file)
 
571
  def compose_mesh_scene(
572
  layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
573
  ) -> None:
574
+ """Composes a mesh scene from layout information and saves to file.
575
+
576
+ Args:
577
+ layout_info (LayoutInfo): Layout information.
578
+ out_scene_path (str): Output scene file path.
579
+ with_bg (bool, optional): Include background mesh.
580
+ """
581
  object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
582
  scene = trimesh.Scene()
583
  for node in layout_info.assets:
 
605
  def compute_pinhole_intrinsics(
606
  image_w: int, image_h: int, fov_deg: float
607
  ) -> np.ndarray:
608
+ """Computes pinhole camera intrinsic matrix from image size and FOV.
609
+
610
+ Args:
611
+ image_w (int): Image width.
612
+ image_h (int): Image height.
613
+ fov_deg (float): Field of view in degrees.
614
+
615
+ Returns:
616
+ np.ndarray: Intrinsic matrix K.
617
+ """
618
  fov_rad = np.deg2rad(fov_deg)
619
  fx = image_w / (2 * np.tan(fov_rad / 2))
620
  fy = fx # assuming square pixels
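A worked example of the intrinsics formula above (fx = W / (2·tan(FOV/2)), square pixels); placing the principal point at the image center is an assumption, since the remaining lines of the function are not shown in this diff.

```py
import numpy as np

image_w, image_h, fov_deg = 640, 480, 60.0
fov_rad = np.deg2rad(fov_deg)
fx = image_w / (2 * np.tan(fov_rad / 2))  # ~554.26 px
fy = fx                                   # square pixels
cx, cy = image_w / 2, image_h / 2         # assumed principal point

K = np.array([[fx, 0.0, cx],
              [0.0, fy, cy],
              [0.0, 0.0, 1.0]])
print(np.round(K, 2))
```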
embodied_gen/utils/gpt_clients.py CHANGED
@@ -45,7 +45,35 @@ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
45
 
46
 
47
  class GPTclient:
48
- """A client to interact with the GPT model via OpenAI or Azure API."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def __init__(
51
  self,
@@ -82,6 +110,7 @@ class GPTclient:
82
  stop=(stop_after_attempt(10) | stop_after_delay(30)),
83
  )
84
  def completion_with_backoff(self, **kwargs):
 
85
  return self.client.chat.completions.create(**kwargs)
86
 
87
  def query(
@@ -91,19 +120,16 @@ class GPTclient:
91
  system_role: Optional[str] = None,
92
  params: Optional[dict] = None,
93
  ) -> Optional[str]:
94
- """Queries the GPT model with a text and optional image prompts.
95
 
96
  Args:
97
- text_prompt (str): The main text input that the model responds to.
98
- image_base64 (Optional[List[str]]): A list of image base64 strings
99
- or local image paths or PIL.Image to accompany the text prompt.
100
- system_role (Optional[str]): Optional system-level instructions
101
- that specify the behavior of the assistant.
102
- params (Optional[dict]): Additional parameters for GPT setting.
103
 
104
  Returns:
105
- Optional[str]: The response content generated by the model based on
106
- the prompt. Returns `None` if an error occurs.
107
  """
108
  if system_role is None:
109
  system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
@@ -177,7 +203,11 @@ class GPTclient:
177
  return response
178
 
179
  def check_connection(self) -> None:
180
- """Check whether the GPT API connection is working."""
 
 
 
 
181
  try:
182
  response = self.completion_with_backoff(
183
  messages=[
 
45
 
46
 
47
  class GPTclient:
48
+ """A client to interact with GPT models via OpenAI or Azure API.
49
+
50
+ Supports text and image prompts, connection checking, and configurable parameters.
51
+
52
+ Args:
53
+ endpoint (str): API endpoint URL.
54
+ api_key (str): API key for authentication.
55
+ model_name (str, optional): Model name to use.
56
+ api_version (str, optional): API version (for Azure).
57
+ check_connection (bool, optional): Whether to check API connection.
58
+ verbose (bool, optional): Enable verbose logging.
59
+
60
+ Example:
61
+ ```sh
62
+ export ENDPOINT="https://yfb-openai-sweden.openai.azure.com"
63
+ export API_KEY="xxxxxx"
64
+ export API_VERSION="2025-03-01-preview"
65
+ export MODEL_NAME="yfb-gpt-4o-sweden"
66
+ ```
67
+ ```py
68
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
69
+
70
+ response = GPT_CLIENT.query("Describe the physics of a falling apple.")
71
+ response = GPT_CLIENT.query(
72
+ text_prompt="Describe the content in each image."
73
+ image_base64=["path/to/image1.png", "path/to/image2.jpg"],
74
+ )
75
+ ```
76
+ """
77
 
78
  def __init__(
79
  self,
 
110
  stop=(stop_after_attempt(10) | stop_after_delay(30)),
111
  )
112
  def completion_with_backoff(self, **kwargs):
113
+ """Performs a chat completion request with retry/backoff."""
114
  return self.client.chat.completions.create(**kwargs)
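The stop policy visible above (`stop_after_attempt(10) | stop_after_delay(30)`) comes from `tenacity`; here is a minimal sketch of the same retry/backoff pattern applied to a stand-in function. The wait policy and the `flaky_completion` function are assumptions for illustration, not copied from the repository.

```py
import random
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_random_exponential

@retry(
    wait=wait_random_exponential(min=1, max=20),        # assumed backoff policy
    stop=(stop_after_attempt(10) | stop_after_delay(30)),
)
def flaky_completion() -> str:
    # Hypothetical stand-in for client.chat.completions.create(**kwargs).
    if random.random() < 0.5:
        raise ConnectionError("transient API failure")
    return "ok"

print(flaky_completion())
```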
115
 
116
  def query(
 
120
  system_role: Optional[str] = None,
121
  params: Optional[dict] = None,
122
  ) -> Optional[str]:
123
+ """Queries the GPT model with text and optional image prompts.
124
 
125
  Args:
126
+ text_prompt (str): Main text input.
127
+ image_base64 (Optional[list[str | Image.Image]], optional): List of image base64 strings, file paths, or PIL Images.
128
+ system_role (Optional[str], optional): System-level instructions.
129
+ params (Optional[dict], optional): Additional GPT parameters.
 
 
130
 
131
  Returns:
132
+ Optional[str]: Model response content, or None if an error occurs.
 
133
  """
134
  if system_role is None:
135
  system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
 
203
  return response
204
 
205
  def check_connection(self) -> None:
206
+ """Checks whether the GPT API connection is working.
207
+
208
+ Raises:
209
+ ConnectionError: If connection fails.
210
+ """
211
  try:
212
  response = self.completion_with_backoff(
213
  messages=[
embodied_gen/utils/process_media.py CHANGED
@@ -69,6 +69,40 @@ def render_asset3d(
69
  no_index_file: bool = False,
70
  with_mtl: bool = True,
71
  ) -> list[str]:
72
  input_args = dict(
73
  mesh_path=mesh_path,
74
  output_root=output_root,
@@ -95,6 +129,13 @@ def render_asset3d(
95
 
96
 
97
  def merge_images_video(color_images, normal_images, output_path) -> None:
 
 
 
 
 
 
 
98
  width = color_images[0].shape[1]
99
  combined_video = [
100
  np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
@@ -108,7 +149,13 @@ def merge_images_video(color_images, normal_images, output_path) -> None:
108
  def merge_video_video(
109
  video_path1: str, video_path2: str, output_path: str
110
  ) -> None:
111
- """Merge two videos by the left half and the right half of the videos."""
 
 
 
 
 
 
112
  clip1 = VideoFileClip(video_path1)
113
  clip2 = VideoFileClip(video_path2)
114
 
@@ -127,6 +174,16 @@ def filter_small_connected_components(
127
  area_ratio: float,
128
  connectivity: int = 8,
129
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
130
  if isinstance(mask, Image.Image):
131
  mask = np.array(mask)
132
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
@@ -152,6 +209,16 @@ def filter_image_small_connected_components(
152
  area_ratio: float = 10,
153
  connectivity: int = 8,
154
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
155
  if isinstance(image, Image.Image):
156
  image = image.convert("RGBA")
157
  image = np.array(image)
@@ -169,6 +236,24 @@ def combine_images_to_grid(
169
  target_wh: tuple[int, int] = (512, 512),
170
  image_mode: str = "RGB",
171
  ) -> list[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  n_images = len(images)
173
  if n_images == 1:
174
  return images
@@ -196,6 +281,19 @@ def combine_images_to_grid(
196
 
197
 
198
  class SceneTreeVisualizer:
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def __init__(self, layout_info: LayoutInfo) -> None:
200
  self.tree = layout_info.tree
201
  self.relation = layout_info.relation
@@ -274,6 +372,14 @@ class SceneTreeVisualizer:
274
  dpi=300,
275
  title: str = "Scene 3D Hierarchy Tree",
276
  ):
 
 
 
 
 
 
 
 
277
  node_colors = [
278
  self.role_colors[self._get_node_role(n)] for n in self.G.nodes
279
  ]
@@ -350,6 +456,14 @@ class SceneTreeVisualizer:
350
 
351
 
352
  def load_scene_dict(file_path: str) -> dict:
 
 
 
 
 
 
 
 
353
  scene_dict = {}
354
  with open(file_path, "r", encoding='utf-8') as f:
355
  for line in f:
@@ -363,12 +477,28 @@ def load_scene_dict(file_path: str) -> dict:
363
 
364
 
365
  def is_image_file(filename: str) -> bool:
 
 
 
 
 
 
 
 
366
  mime_type, _ = mimetypes.guess_type(filename)
367
 
368
  return mime_type is not None and mime_type.startswith('image')
369
 
370
 
371
  def parse_text_prompts(prompts: list[str]) -> list[str]:
 
 
 
 
 
 
 
 
372
  if len(prompts) == 1 and prompts[0].endswith(".txt"):
373
  with open(prompts[0], "r") as f:
374
  prompts = [
@@ -386,13 +516,18 @@ def alpha_blend_rgba(
386
  """Alpha blends a foreground RGBA image over a background RGBA image.
387
 
388
  Args:
389
- fg_image: Foreground image. Can be a file path (str), a PIL Image,
390
- or a NumPy ndarray.
391
- bg_image: Background image. Can be a file path (str), a PIL Image,
392
- or a NumPy ndarray.
393
 
394
  Returns:
395
- A PIL Image representing the alpha-blended result in RGBA mode.
 
 
 
 
 
 
 
396
  """
397
  if isinstance(fg_image, str):
398
  fg_image = Image.open(fg_image)
@@ -421,13 +556,11 @@ def check_object_edge_truncated(
421
  """Checks if a binary object mask is truncated at the image edges.
422
 
423
  Args:
424
- mask: A 2D binary NumPy array where nonzero values indicate the object region.
425
- edge_threshold: Number of pixels from each image edge to consider for truncation.
426
- Defaults to 5.
427
 
428
  Returns:
429
- True if the object is fully enclosed (not truncated).
430
- False if the object touches or crosses any image boundary.
431
  """
432
  top = mask[:edge_threshold, :].any()
433
  bottom = mask[-edge_threshold:, :].any()
@@ -440,6 +573,22 @@ def check_object_edge_truncated(
440
  def vcat_pil_images(
441
  images: list[Image.Image], image_mode: str = "RGB"
442
  ) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  widths, heights = zip(*(img.size for img in images))
444
  total_height = sum(heights)
445
  max_width = max(widths)
 
69
  no_index_file: bool = False,
70
  with_mtl: bool = True,
71
  ) -> list[str]:
72
+ """Renders a 3D mesh asset and returns output image paths.
73
+
74
+ Args:
75
+ mesh_path (str): Path to the mesh file.
76
+ output_root (str): Directory to save outputs.
77
+ distance (float, optional): Camera distance.
78
+ num_images (int, optional): Number of views to render.
79
+ elevation (list[float], optional): Camera elevation angles.
80
+ pbr_light_factor (float, optional): PBR lighting factor.
81
+ return_key (str, optional): Glob pattern for output images.
82
+ output_subdir (str, optional): Subdirectory for outputs.
83
+ gen_color_mp4 (bool, optional): Generate color MP4 video.
84
+ gen_viewnormal_mp4 (bool, optional): Generate view normal MP4.
85
+ gen_glonormal_mp4 (bool, optional): Generate global normal MP4.
86
+ no_index_file (bool, optional): Skip index file saving.
87
+ with_mtl (bool, optional): Use mesh material.
88
+
89
+ Returns:
90
+ list[str]: List of output image file paths.
91
+
92
+ Example:
93
+ ```py
94
+ from embodied_gen.utils.process_media import render_asset3d
95
+
96
+ image_paths = render_asset3d(
97
+ mesh_path="path_to_mesh.obj",
98
+ output_root="path_to_save_dir",
99
+ num_images=6,
100
+ elevation=(30, -30),
101
+ output_subdir="renders",
102
+ no_index_file=True,
103
+ )
104
+ ```
105
+ """
106
  input_args = dict(
107
  mesh_path=mesh_path,
108
  output_root=output_root,
 
129
 
130
 
131
  def merge_images_video(color_images, normal_images, output_path) -> None:
132
+ """Merges color and normal images into a video.
133
+
134
+ Args:
135
+ color_images (list[np.ndarray]): List of color images.
136
+ normal_images (list[np.ndarray]): List of normal images.
137
+ output_path (str): Path to save the output video.
138
+ """
139
  width = color_images[0].shape[1]
140
  combined_video = [
141
  np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
 
149
  def merge_video_video(
150
  video_path1: str, video_path2: str, output_path: str
151
  ) -> None:
152
+ """Merges two videos by combining their left and right halves.
153
+
154
+ Args:
155
+ video_path1 (str): Path to first video.
156
+ video_path2 (str): Path to second video.
157
+ output_path (str): Path to save the merged video.
158
+ """
159
  clip1 = VideoFileClip(video_path1)
160
  clip2 = VideoFileClip(video_path2)
161
 
 
174
  area_ratio: float,
175
  connectivity: int = 8,
176
  ) -> np.ndarray:
177
+ """Removes small connected components from a binary mask.
178
+
179
+ Args:
180
+ mask (Union[Image.Image, np.ndarray]): Input mask.
181
+ area_ratio (float): Minimum area ratio for components.
182
+ connectivity (int, optional): Connectivity for labeling.
183
+
184
+ Returns:
185
+ np.ndarray: Mask with small components removed.
186
+ """
187
  if isinstance(mask, Image.Image):
188
  mask = np.array(mask)
189
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
 
209
  area_ratio: float = 10,
210
  connectivity: int = 8,
211
  ) -> np.ndarray:
212
+ """Removes small connected components from the alpha channel of an image.
213
+
214
+ Args:
215
+ image (Union[Image.Image, np.ndarray]): Input image.
216
+ area_ratio (float, optional): Minimum area ratio.
217
+ connectivity (int, optional): Connectivity for labeling.
218
+
219
+ Returns:
220
+ np.ndarray: Image with filtered alpha channel.
221
+ """
222
  if isinstance(image, Image.Image):
223
  image = image.convert("RGBA")
224
  image = np.array(image)
 
236
  target_wh: tuple[int, int] = (512, 512),
237
  image_mode: str = "RGB",
238
  ) -> list[Image.Image]:
239
+ """Combines multiple images into a grid.
240
+
241
+ Args:
242
+ images (list[str | Image.Image]): List of image paths or PIL Images.
243
+ cat_row_col (tuple[int, int], optional): Grid rows and columns.
244
+ target_wh (tuple[int, int], optional): Target image size.
245
+ image_mode (str, optional): Image mode.
246
+
247
+ Returns:
248
+ list[Image.Image]: List containing the grid image.
249
+
250
+ Example:
251
+ ```py
252
+ from embodied_gen.utils.process_media import combine_images_to_grid
253
+ grid = combine_images_to_grid(["img1.png", "img2.png"])
254
+ grid[0].save("grid.png")
255
+ ```
256
+ """
257
  n_images = len(images)
258
  if n_images == 1:
259
  return images
 
281
 
282
 
283
  class SceneTreeVisualizer:
284
+ """Visualizes a scene tree layout using networkx and matplotlib.
285
+
286
+ Args:
287
+ layout_info (LayoutInfo): Layout information for the scene.
288
+
289
+ Example:
290
+ ```py
291
+ from embodied_gen.utils.process_media import SceneTreeVisualizer
292
+ visualizer = SceneTreeVisualizer(layout_info)
293
+ visualizer.render(save_path="tree.png")
294
+ ```
295
+ """
296
+
297
  def __init__(self, layout_info: LayoutInfo) -> None:
298
  self.tree = layout_info.tree
299
  self.relation = layout_info.relation
 
372
  dpi=300,
373
  title: str = "Scene 3D Hierarchy Tree",
374
  ):
375
+ """Renders the scene tree and saves to file.
376
+
377
+ Args:
378
+ save_path (str): Path to save the rendered image.
379
+ figsize (tuple, optional): Figure size.
380
+ dpi (int, optional): Image DPI.
381
+ title (str, optional): Plot image title.
382
+ """
383
  node_colors = [
384
  self.role_colors[self._get_node_role(n)] for n in self.G.nodes
385
  ]
 
456
 
457
 
458
  def load_scene_dict(file_path: str) -> dict:
459
+ """Loads a scene description dictionary from a file.
460
+
461
+ Args:
462
+ file_path (str): Path to the scene description file.
463
+
464
+ Returns:
465
+ dict: Mapping from scene ID to description.
466
+ """
467
  scene_dict = {}
468
  with open(file_path, "r", encoding='utf-8') as f:
469
  for line in f:
 
477
 
478
 
479
  def is_image_file(filename: str) -> bool:
480
+ """Checks if a filename is an image file.
481
+
482
+ Args:
483
+ filename (str): Filename to check.
484
+
485
+ Returns:
486
+ bool: True if image file, False otherwise.
487
+ """
488
  mime_type, _ = mimetypes.guess_type(filename)
489
 
490
  return mime_type is not None and mime_type.startswith('image')
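Since `is_image_file` is fully visible above, a quick self-contained check of its behavior; the file names are hypothetical.

```py
import mimetypes

def is_image_file(filename: str) -> bool:
    # Same mimetypes-based check as the function above.
    mime_type, _ = mimetypes.guess_type(filename)
    return mime_type is not None and mime_type.startswith("image")

candidates = ["pano.png", "layout.json", "render.jpg", "mesh.obj"]
print([name for name in candidates if is_image_file(name)])  # ['pano.png', 'render.jpg']
```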
491
 
492
 
493
  def parse_text_prompts(prompts: list[str]) -> list[str]:
494
+ """Parses text prompts from a list or file.
495
+
496
+ Args:
497
+ prompts (list[str]): List of prompts or a file path.
498
+
499
+ Returns:
500
+ list[str]: List of parsed prompts.
501
+ """
502
  if len(prompts) == 1 and prompts[0].endswith(".txt"):
503
  with open(prompts[0], "r") as f:
504
  prompts = [
 
516
  """Alpha blends a foreground RGBA image over a background RGBA image.
517
 
518
  Args:
519
+ fg_image: Foreground image (str, PIL Image, or ndarray).
520
+ bg_image: Background image (str, PIL Image, or ndarray).
 
 
521
 
522
  Returns:
523
+ Image.Image: Alpha-blended RGBA image.
524
+
525
+ Example:
526
+ ```py
527
+ from embodied_gen.utils.process_media import alpha_blend_rgba
528
+ result = alpha_blend_rgba("fg.png", "bg.png")
529
+ result.save("blended.png")
530
+ ```
531
  """
532
  if isinstance(fg_image, str):
533
  fg_image = Image.open(fg_image)
 
556
  """Checks if a binary object mask is truncated at the image edges.
557
 
558
  Args:
559
+ mask (np.ndarray): 2D binary mask.
560
+ edge_threshold (int, optional): Edge pixel threshold.
 
561
 
562
  Returns:
563
+ bool: True if object is fully enclosed, False if truncated.
 
564
  """
565
  top = mask[:edge_threshold, :].any()
566
  bottom = mask[-edge_threshold:, :].any()
 
573
  def vcat_pil_images(
574
  images: list[Image.Image], image_mode: str = "RGB"
575
  ) -> Image.Image:
576
+ """Vertically concatenates a list of PIL images.
577
+
578
+ Args:
579
+ images (list[Image.Image]): List of images.
580
+ image_mode (str, optional): Image mode.
581
+
582
+ Returns:
583
+ Image.Image: Vertically concatenated image.
584
+
585
+ Example:
586
+ ```py
587
+ from embodied_gen.utils.process_media import vcat_pil_images
588
+ img = vcat_pil_images([Image.open("a.png"), Image.open("b.png")])
589
+ img.save("vcat.png")
590
+ ```
591
+ """
592
  widths, heights = zip(*(img.size for img in images))
593
  total_height = sum(heights)
594
  max_width = max(widths)
embodied_gen/utils/simulation.py CHANGED
@@ -69,6 +69,21 @@ def load_actor_from_urdf(
69
  update_mass: bool = False,
70
  scale: float | np.ndarray = 1.0,
71
  ) -> sapien.pysapien.Entity:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
73
  local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
74
  if origin_tag is not None:
@@ -154,14 +169,17 @@ def load_assets_from_layout_file(
154
  init_quat: list[float] = [0, 0, 0, 1],
155
  env_idx: int = None,
156
  ) -> dict[str, sapien.pysapien.Entity]:
157
- """Load assets from `EmbodiedGen` layout-gen output and create actors in the scene.
158
 
159
  Args:
160
- scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into.
161
- layout (str): The layout file path.
162
- z_offset (float): Offset to apply to the Z-coordinate of non-context objects.
163
- init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment.
164
- env_idx (int): Environment index for multi-environment setup.
 
 
 
165
  """
166
  asset_root = os.path.dirname(layout)
167
  layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
@@ -206,6 +224,19 @@ def load_mani_skill_robot(
206
  control_mode: str = "pd_joint_pos",
207
  backend_str: tuple[str, str] = ("cpu", "gpu"),
208
  ) -> BaseAgent:
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  from mani_skill.agents import REGISTERED_AGENTS
210
  from mani_skill.envs.scene import ManiSkillScene
211
  from mani_skill.envs.utils.system.backend import (
@@ -278,14 +309,14 @@ def render_images(
278
  ]
279
  ] = None,
280
  ) -> dict[str, Image.Image]:
281
- """Render images from a given sapien camera.
282
 
283
  Args:
284
- camera (sapien.render.RenderCameraComponent): The camera to render from.
285
- render_keys (List[str]): Types of images to render (e.g., Color, Segmentation).
286
 
287
  Returns:
288
- Dict[str, Image.Image]: Dictionary of rendered images.
289
  """
290
  if render_keys is None:
291
  render_keys = [
@@ -341,11 +372,33 @@ def render_images(
341
 
342
 
343
  class SapienSceneManager:
344
- """A class to manage SAPIEN simulator."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
  def __init__(
347
  self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
348
  ) -> None:
 
 
 
 
 
 
 
349
  self.sim_freq = sim_freq
350
  self.ray_tracing = ray_tracing
351
  self.device = device
@@ -355,7 +408,11 @@ class SapienSceneManager:
355
  self.actors: dict[str, sapien.pysapien.Entity] = {}
356
 
357
  def _setup_scene(self) -> sapien.Scene:
358
- """Set up the SAPIEN scene with lighting and ground."""
 
 
 
 
359
  # Ray tracing settings
360
  if self.ray_tracing:
361
  sapien.render.set_camera_shader_dir("rt")
@@ -397,6 +454,18 @@ class SapienSceneManager:
397
  render_keys: list[str],
398
  sim_steps_per_control: int = 1,
399
  ) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
400
  agent.set_action(action)
401
  frames = defaultdict(list)
402
  for _ in range(sim_steps_per_control):
@@ -417,13 +486,13 @@ class SapienSceneManager:
417
  image_hw: tuple[int, int],
418
  fovy_deg: float,
419
  ) -> sapien.render.RenderCameraComponent:
420
- """Create a single camera in the scene.
421
 
422
  Args:
423
- cam_name (str): Name of the camera.
424
- pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z)
425
- image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
426
- fovy_deg (float): Field of view in degrees for cameras.
427
 
428
  Returns:
429
  sapien.render.RenderCameraComponent: The created camera.
@@ -456,15 +525,15 @@ class SapienSceneManager:
456
  """Initialize multiple cameras arranged in a circle.
457
 
458
  Args:
459
- num_cameras (int): Number of cameras to create.
460
- radius (float): Radius of the camera circle.
461
- height (float): Fixed Z-coordinate of the cameras.
462
- target_pt (list[float]): 3D point (x, y, z) that cameras look at.
463
- image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
464
- fovy_deg (float): Field of view in degrees for cameras.
465
 
466
  Returns:
467
- List[sapien.render.RenderCameraComponent]: List of created cameras.
468
  """
469
  angle_step = 2 * np.pi / num_cameras
470
  world_up_vec = np.array([0.0, 0.0, 1.0])
@@ -510,6 +579,19 @@ class SapienSceneManager:
510
 
511
 
512
  class FrankaPandaGrasper(object):
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  def __init__(
514
  self,
515
  agent: BaseAgent,
@@ -518,6 +600,7 @@ class FrankaPandaGrasper(object):
518
  joint_acc_limits: float = 1.0,
519
  finger_length: float = 0.025,
520
  ) -> None:
 
521
  self.agent = agent
522
  self.robot = agent.robot
523
  self.control_freq = control_freq
@@ -553,6 +636,15 @@ class FrankaPandaGrasper(object):
553
  gripper_state: Literal[-1, 1],
554
  n_step: int = 10,
555
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
556
  qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
557
  actions = []
558
  for _ in range(n_step):
@@ -571,6 +663,20 @@ class FrankaPandaGrasper(object):
571
  action_key: str = "position",
572
  env_idx: int = 0,
573
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  result = self.planners[env_idx].plan_qpos_to_pose(
575
  np.concatenate([pose.p, pose.q]),
576
  self.robot.get_qpos().cpu().numpy()[0],
@@ -608,6 +714,17 @@ class FrankaPandaGrasper(object):
608
  offset: tuple[float, float, float] = [0, 0, -0.05],
609
  env_idx: int = 0,
610
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
611
  physx_rigid = actor.components[1]
612
  mesh = get_component_mesh(physx_rigid, to_world_frame=True)
613
  obb = mesh.bounding_box_oriented
 
69
  update_mass: bool = False,
70
  scale: float | np.ndarray = 1.0,
71
  ) -> sapien.pysapien.Entity:
72
+ """Load an sapien actor from a URDF file and add it to the scene.
73
+
74
+ Args:
75
+ scene (sapien.Scene | ManiSkillScene): The simulation scene.
76
+ file_path (str): Path to the URDF file.
77
+ pose (sapien.Pose | None): Initial pose of the actor.
78
+ env_idx (int): Environment index for multi-env setup.
79
+ use_static (bool): Whether the actor is static.
80
+ update_mass (bool): Whether to update the actor's mass from URDF.
81
+ scale (float | np.ndarray): Scale factor for the actor.
82
+
83
+ Returns:
84
+ sapien.pysapien.Entity: The created actor entity.
85
+ """
86
+
87
  def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
88
  local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
89
  if origin_tag is not None:
 
169
  init_quat: list[float] = [0, 0, 0, 1],
170
  env_idx: int = None,
171
  ) -> dict[str, sapien.pysapien.Entity]:
172
+ """Load assets from an EmbodiedGen layout file and create sapien actors in the scene.
173
 
174
  Args:
175
+ scene (ManiSkillScene | sapien.Scene): The SAPIEN simulation scene.
176
+ layout (str): Path to the EmbodiedGen layout file.
177
+ z_offset (float): Z offset for non-context objects.
178
+ init_quat (list[float]): Initial quaternion for orientation.
179
+ env_idx (int): Environment index.
180
+
181
+ Returns:
182
+ dict[str, sapien.pysapien.Entity]: Mapping from object names to actor entities.
183
  """
184
  asset_root = os.path.dirname(layout)
185
  layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
 
224
  control_mode: str = "pd_joint_pos",
225
  backend_str: tuple[str, str] = ("cpu", "gpu"),
226
  ) -> BaseAgent:
227
+ """Load a ManiSkill robot agent into the scene.
228
+
229
+ Args:
230
+ scene (sapien.Scene | ManiSkillScene): The simulation scene.
231
+ layout (LayoutInfo | str): Layout info or path to layout file.
232
+ control_freq (int): Control frequency.
233
+ robot_init_qpos_noise (float): Noise for initial joint positions.
234
+ control_mode (str): Robot control mode.
235
+ backend_str (tuple[str, str]): Simulation/render backend.
236
+
237
+ Returns:
238
+ BaseAgent: The loaded robot agent.
239
+ """
240
  from mani_skill.agents import REGISTERED_AGENTS
241
  from mani_skill.envs.scene import ManiSkillScene
242
  from mani_skill.envs.utils.system.backend import (
 
309
  ]
310
  ] = None,
311
  ) -> dict[str, Image.Image]:
312
+ """Render images from a given SAPIEN camera.
313
 
314
  Args:
315
+ camera (sapien.render.RenderCameraComponent): Camera to render from.
316
+ render_keys (list[str], optional): Types of images to render.
317
 
318
  Returns:
319
+ dict[str, Image.Image]: Dictionary of rendered images.
320
  """
321
  if render_keys is None:
322
  render_keys = [
 
372
 
373
 
374
  class SapienSceneManager:
375
+ """Manages SAPIEN simulation scenes, cameras, and rendering.
376
+
377
+ This class provides utilities for setting up scenes, adding cameras,
378
+ stepping simulation, and rendering images.
379
+
380
+ Attributes:
381
+ sim_freq (int): Simulation frequency.
382
+ ray_tracing (bool): Whether to use ray tracing.
383
+ device (str): Device for simulation.
384
+ renderer (sapien.SapienRenderer): SAPIEN renderer.
385
+ scene (sapien.Scene): Simulation scene.
386
+ cameras (list): List of camera components.
387
+ actors (dict): Mapping of actor names to entities.
388
+
389
+ Example see `embodied_gen/scripts/simulate_sapien.py`.
390
+ """
391
 
392
  def __init__(
393
  self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
394
  ) -> None:
395
+ """Initialize the scene manager.
396
+
397
+ Args:
398
+ sim_freq (int): Simulation frequency.
399
+ ray_tracing (bool): Enable ray tracing.
400
+ device (str): Device for simulation.
401
+ """
402
  self.sim_freq = sim_freq
403
  self.ray_tracing = ray_tracing
404
  self.device = device
 
408
  self.actors: dict[str, sapien.pysapien.Entity] = {}
409
 
410
  def _setup_scene(self) -> sapien.Scene:
411
+ """Set up the SAPIEN scene with lighting and ground.
412
+
413
+ Returns:
414
+ sapien.Scene: The initialized scene.
415
+ """
416
  # Ray tracing settings
417
  if self.ray_tracing:
418
  sapien.render.set_camera_shader_dir("rt")
 
454
  render_keys: list[str],
455
  sim_steps_per_control: int = 1,
456
  ) -> dict:
457
+ """Step the simulation and render images from cameras.
458
+
459
+ Args:
460
+ agent (BaseAgent): The robot agent.
461
+ action (torch.Tensor): Action to apply.
462
+ cameras (list): List of camera components.
463
+ render_keys (list[str]): Types of images to render.
464
+ sim_steps_per_control (int): Simulation steps per control.
465
+
466
+ Returns:
467
+ dict: Dictionary of rendered frames per camera.
468
+ """
469
  agent.set_action(action)
470
  frames = defaultdict(list)
471
  for _ in range(sim_steps_per_control):
 
486
  image_hw: tuple[int, int],
487
  fovy_deg: float,
488
  ) -> sapien.render.RenderCameraComponent:
489
+ """Create a camera in the scene.
490
 
491
  Args:
492
+ cam_name (str): Camera name.
493
+ pose (sapien.Pose): Camera pose.
494
+ image_hw (tuple[int, int]): Image resolution (height, width).
495
+ fovy_deg (float): Field of view in degrees.
496
 
497
  Returns:
498
  sapien.render.RenderCameraComponent: The created camera.
 
525
  """Initialize multiple cameras arranged in a circle.
526
 
527
  Args:
528
+ num_cameras (int): Number of cameras.
529
+ radius (float): Circle radius.
530
+ height (float): Camera height.
531
+ target_pt (list[float]): Target point to look at.
532
+ image_hw (tuple[int, int]): Image resolution.
533
+ fovy_deg (float): Field of view in degrees.
534
 
535
  Returns:
536
+ list[sapien.render.RenderCameraComponent]: List of cameras.
537
  """
538
  angle_step = 2 * np.pi / num_cameras
539
  world_up_vec = np.array([0.0, 0.0, 1.0])
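A pure-numpy sketch of the circular look-at placement this method describes: camera centers spaced evenly on a circle at a fixed height, each oriented toward a target point. The x-forward, y-left, z-up camera frame used here is an assumption (SAPIEN-style), not read from the omitted lines.

```py
import numpy as np

def circle_camera_positions(num_cameras: int, radius: float, height: float) -> np.ndarray:
    # Evenly spaced camera centers on a circle at fixed Z.
    angles = np.arange(num_cameras) * (2 * np.pi / num_cameras)
    return np.stack(
        [radius * np.cos(angles), radius * np.sin(angles), np.full_like(angles, height)],
        axis=1,
    )

def look_at_rotation(cam_pos, target, up=(0.0, 0.0, 1.0)) -> np.ndarray:
    # Rotation whose x-axis points at the target (assumed x-forward, z-up camera frame).
    forward = np.asarray(target, dtype=float) - cam_pos
    forward /= np.linalg.norm(forward)
    left = np.cross(np.asarray(up, dtype=float), forward)
    left /= np.linalg.norm(left)
    true_up = np.cross(forward, left)
    return np.stack([forward, left, true_up], axis=1)  # columns: camera x, y, z axes

positions = circle_camera_positions(num_cameras=4, radius=2.0, height=1.5)
print(look_at_rotation(positions[0], target=[0.0, 0.0, 0.8]))
```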
 
579
 
580
 
581
  class FrankaPandaGrasper(object):
582
+ """Provides grasp planning and control for Franka Panda robot.
583
+
584
+ Attributes:
585
+ agent (BaseAgent): The robot agent.
586
+ robot: The robot instance.
587
+ control_freq (float): Control frequency.
588
+ control_timestep (float): Control timestep.
589
+ joint_vel_limits (float): Joint velocity limits.
590
+ joint_acc_limits (float): Joint acceleration limits.
591
+ finger_length (float): Length of gripper fingers.
592
+ planners: Motion planners for each environment.
593
+ """
594
+
595
  def __init__(
596
  self,
597
  agent: BaseAgent,
 
600
  joint_acc_limits: float = 1.0,
601
  finger_length: float = 0.025,
602
  ) -> None:
603
+ """Initialize the grasper."""
604
  self.agent = agent
605
  self.robot = agent.robot
606
  self.control_freq = control_freq
 
636
  gripper_state: Literal[-1, 1],
637
  n_step: int = 10,
638
  ) -> np.ndarray:
639
+ """Generate gripper control actions.
640
+
641
+ Args:
642
+ gripper_state (Literal[-1, 1]): Desired gripper state.
643
+ n_step (int): Number of steps.
644
+
645
+ Returns:
646
+ np.ndarray: Array of gripper actions.
647
+ """
648
  qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
649
  actions = []
650
  for _ in range(n_step):
 
663
  action_key: str = "position",
664
  env_idx: int = 0,
665
  ) -> np.ndarray:
666
+ """Plan and execute motion to a target pose.
667
+
668
+ Args:
669
+ pose (sapien.Pose): Target pose.
670
+ control_timestep (float): Control timestep.
671
+ gripper_state (Literal[-1, 1]): Desired gripper state.
672
+ use_point_cloud (bool): Use point cloud for planning.
673
+ n_max_step (int): Max number of steps.
674
+ action_key (str): Key for action in result.
675
+ env_idx (int): Environment index.
676
+
677
+ Returns:
678
+ np.ndarray: Array of actions to reach the pose.
679
+ """
680
  result = self.planners[env_idx].plan_qpos_to_pose(
681
  np.concatenate([pose.p, pose.q]),
682
  self.robot.get_qpos().cpu().numpy()[0],
 
714
  offset: tuple[float, float, float] = [0, 0, -0.05],
715
  env_idx: int = 0,
716
  ) -> np.ndarray:
717
+ """Compute grasp actions for a target actor.
718
+
719
+ Args:
720
+ actor (sapien.pysapien.Entity): Target actor to grasp.
721
+ reach_target_only (bool): Only reach the target pose if True.
722
+ offset (tuple[float, float, float]): Offset for reach pose.
723
+ env_idx (int): Environment index.
724
+
725
+ Returns:
726
+ np.ndarray: Array of grasp actions.
727
+ """
728
  physx_rigid = actor.components[1]
729
  mesh = get_component_mesh(physx_rigid, to_world_frame=True)
730
  obb = mesh.bounding_box_oriented
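The docstrings added above describe a step-and-render loop on the scene-manager side and a grasp planner that returns an array of actions on the FrankaPandaGrasper side. A minimal sketch of how the two could be wired together follows; the scene-manager variable and the method names `step` and `compute_grasp_action` are assumptions inferred from the docstrings, and the repository's reference usage is `embodied_gen/scripts/simulate_sapien.py`.

```py
# Hedged sketch: `scene_mgr.step` and `grasper.compute_grasp_action` are
# hypothetical method names inferred from the docstrings above; see
# embodied_gen/scripts/simulate_sapien.py for the real entry point.
def run_grasp(scene_mgr, grasper, agent, actor, cameras):
    """Plan a grasp for `actor`, then replay the actions while rendering."""
    actions = grasper.compute_grasp_action(actor)  # assumed method name
    rollouts = []
    for action in actions:
        # One batch of rendered frames per control step; the "Color" key is
        # an assumption about the renderer's image types.
        frames = scene_mgr.step(
            agent,
            action,
            cameras,
            render_keys=["Color"],
            sim_steps_per_control=4,
        )
        rollouts.append(frames)
    return rollouts
```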
embodied_gen/utils/tags.py CHANGED
@@ -1 +1 @@
-VERSION = "v0.1.5"
+VERSION = "v0.1.6"
embodied_gen/validators/aesthetic_predictor.py CHANGED
@@ -27,14 +27,22 @@ from PIL import Image
 
 
 class AestheticPredictor:
-    """Aesthetic Score Predictor.
+    """Aesthetic Score Predictor using CLIP and a pre-trained MLP.
 
-    Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main
+    Checkpoints from `https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main`.
 
     Args:
-        clip_model_dir (str): Path to the directory of the CLIP model.
-        sac_model_path (str): Path to the pre-trained SAC model.
-        device (str): Device to use for computation ("cuda" or "cpu").
+        clip_model_dir (str, optional): Path to CLIP model directory.
+        sac_model_path (str, optional): Path to SAC model weights.
+        device (str, optional): Device for computation ("cuda" or "cpu").
+
+    Example:
+        ```py
+        from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
+        predictor = AestheticPredictor(device="cuda")
+        score = predictor.predict("image.png")
+        print("Aesthetic score:", score)
+        ```
     """
 
     def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
@@ -109,7 +117,7 @@ class AestheticPredictor:
         return model
 
     def predict(self, image_path):
-        """Predict the aesthetic score for a given image.
+        """Predicts the aesthetic score for a given image.
 
         Args:
             image_path (str): Path to the image file.
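Since the new docstring documents both the constructor and `predict`, a small batch-scoring sketch follows (the `renders/` directory is illustrative):

```py
# Batch-scoring sketch using the documented AestheticPredictor API.
# The "renders/" directory is illustrative.
import glob

from embodied_gen.validators.aesthetic_predictor import AestheticPredictor

predictor = AestheticPredictor(device="cuda")
scores = {path: predictor.predict(path) for path in glob.glob("renders/*.png")}
for path, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{score:.2f}  {path}")
```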
embodied_gen/validators/quality_checkers.py CHANGED
@@ -40,6 +40,16 @@ __all__ = [
 
 
 class BaseChecker:
+    """Base class for quality checkers using GPT clients.
+
+    Provides a common interface for querying and validating responses.
+    Subclasses must implement the `query` method.
+
+    Attributes:
+        prompt (str): The prompt used for queries.
+        verbose (bool): Whether to enable verbose logging.
+    """
+
     def __init__(self, prompt: str = None, verbose: bool = False) -> None:
         self.prompt = prompt
         self.verbose = verbose
@@ -70,6 +80,15 @@ class BaseChecker:
     def validate(
         checkers: list["BaseChecker"], images_list: list[list[str]]
     ) -> list:
+        """Validates a list of checkers against corresponding image lists.
+
+        Args:
+            checkers (list[BaseChecker]): List of checker instances.
+            images_list (list[list[str]]): List of image path lists.
+
+        Returns:
+            list: Validation results with overall outcome.
+        """
         assert len(checkers) == len(images_list)
         results = []
         overall_result = True
@@ -192,7 +211,7 @@ class ImageSegChecker(BaseChecker):
 
 
 class ImageAestheticChecker(BaseChecker):
-    """A class for evaluating the aesthetic quality of images.
+    """Evaluates the aesthetic quality of images using a CLIP-based predictor.
 
     Attributes:
         clip_model_dir (str): Path to the CLIP model directory.
@@ -200,6 +219,14 @@ class ImageAestheticChecker(BaseChecker):
         thresh (float): Threshold above which images are considered aesthetically acceptable.
         verbose (bool): Whether to print detailed log messages.
         predictor (AestheticPredictor): The model used to predict aesthetic scores.
+
+    Example:
+        ```py
+        from embodied_gen.validators.quality_checkers import ImageAestheticChecker
+        checker = ImageAestheticChecker(thresh=4.5)
+        flag, score = checker(["image1.png", "image2.png"])
+        print("Aesthetic OK:", flag, "Score:", score)
+        ```
     """
 
     def __init__(
@@ -227,6 +254,16 @@ class ImageAestheticChecker(BaseChecker):
 
 
 class SemanticConsistChecker(BaseChecker):
+    """Checks semantic consistency between text descriptions and segmented images.
+
+    Uses GPT to evaluate if the image matches the text in object type, geometry, and color.
+
+    Attributes:
+        gpt_client (GPTclient): GPT client for queries.
+        prompt (str): Prompt for consistency evaluation.
+        verbose (bool): Whether to enable verbose logging.
+    """
+
     def __init__(
         self,
         gpt_client: GPTclient,
@@ -276,6 +313,16 @@ class SemanticConsistChecker(BaseChecker):
 
 
 class TextGenAlignChecker(BaseChecker):
+    """Evaluates alignment between text prompts and generated 3D asset images.
+
+    Assesses if the rendered images match the text description in category and geometry.
+
+    Attributes:
+        gpt_client (GPTclient): GPT client for queries.
+        prompt (str): Prompt for alignment evaluation.
+        verbose (bool): Whether to enable verbose logging.
+    """
+
     def __init__(
        self,
        gpt_client: GPTclient,
@@ -489,6 +536,17 @@ class PanoHeightEstimator(object):
 
 
 class SemanticMatcher(BaseChecker):
+    """Matches query text to semantically similar scene descriptions.
+
+    Uses GPT to find the most similar scene IDs from a dictionary.
+
+    Attributes:
+        gpt_client (GPTclient): GPT client for queries.
+        prompt (str): Prompt for semantic matching.
+        verbose (bool): Whether to enable verbose logging.
+        seed (int): Random seed for selection.
+    """
+
     def __init__(
        self,
        gpt_client: GPTclient,
@@ -543,6 +601,17 @@ class SemanticMatcher(BaseChecker):
     def query(
         self, text: str, context: dict, rand: bool = True, params: dict = None
     ) -> str:
+        """Queries for semantically similar scene IDs.
+
+        Args:
+            text (str): Query text.
+            context (dict): Dictionary of scene descriptions.
+            rand (bool, optional): Whether to randomly select from top matches.
+            params (dict, optional): Additional GPT parameters.
+
+        Returns:
+            str: Matched scene ID.
+        """
         match_list = self.gpt_client.query(
             self.prompt.format(context=context, text=text),
             params=params,
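The added `validate` docstring pairs each checker with its own list of images. A short sketch combining it with `ImageAestheticChecker`, assuming `validate` is exposed as a staticmethod as its signature suggests (image paths are illustrative):

```py
# Sketch of running an aesthetic check through BaseChecker.validate,
# following the documented signature validate(checkers, images_list).
# Image paths are illustrative; `validate` is assumed to be a staticmethod.
from embodied_gen.validators.quality_checkers import (
    BaseChecker,
    ImageAestheticChecker,
)

aesthetic_checker = ImageAestheticChecker(thresh=4.5)
results = BaseChecker.validate(
    checkers=[aesthetic_checker],
    images_list=[["asset/render_0.png", "asset/render_1.png"]],
)
print(results)
```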
embodied_gen/validators/urdf_convertor.py CHANGED
@@ -80,6 +80,31 @@ URDF_TEMPLATE = """
 
 
 class URDFGenerator(object):
+    """Generates URDF files for 3D assets with physical and semantic attributes.
+
+    Uses GPT to estimate object properties and generates a URDF file with mesh, friction, mass, and metadata.
+
+    Args:
+        gpt_client (GPTclient): GPT client for attribute estimation.
+        mesh_file_list (list[str], optional): Additional mesh files to copy.
+        prompt_template (str, optional): Prompt template for GPT queries.
+        attrs_name (list[str], optional): List of attribute names to include.
+        render_dir (str, optional): Directory for rendered images.
+        render_view_num (int, optional): Number of views to render.
+        decompose_convex (bool, optional): Whether to decompose mesh for collision.
+        rotate_xyzw (list[float], optional): Quaternion for mesh rotation.
+
+    Example:
+        ```py
+        from embodied_gen.validators.urdf_convertor import URDFGenerator
+        from embodied_gen.utils.gpt_clients import GPT_CLIENT
+
+        urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
+        urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir")
+        print("Generated URDF:", urdf_path)
+        ```
+    """
+
     def __init__(
         self,
         gpt_client: GPTclient,
@@ -168,6 +193,14 @@ class URDFGenerator(object):
         self.rotate_xyzw = rotate_xyzw
 
     def parse_response(self, response: str) -> dict[str, any]:
+        """Parses GPT response to extract asset attributes.
+
+        Args:
+            response (str): GPT response string.
+
+        Returns:
+            dict[str, any]: Parsed attributes.
+        """
         lines = response.split("\n")
         lines = [line.strip() for line in lines if line]
         category = lines[0].split(": ")[1]
@@ -207,11 +240,9 @@ class URDFGenerator(object):
 
         Args:
             input_mesh (str): Path to the input mesh file.
-            output_dir (str): Directory to store the generated URDF
-                and processed mesh.
-            attr_dict (dict): Dictionary containing attributes like height,
-                mass, and friction coefficients.
-            output_name (str, optional): Name for the generated URDF and robot.
+            output_dir (str): Directory to store the generated URDF and mesh.
+            attr_dict (dict): Dictionary of asset attributes.
+            output_name (str, optional): Name for the URDF and robot.
 
         Returns:
             str: Path to the generated URDF file.
@@ -336,6 +367,16 @@ class URDFGenerator(object):
         attr_root: str = ".//link/extra_info",
         attr_name: str = "scale",
     ) -> float:
+        """Extracts an attribute value from a URDF file.
+
+        Args:
+            urdf_path (str): Path to the URDF file.
+            attr_root (str, optional): XML path to attribute root.
+            attr_name (str, optional): Attribute name.
+
+        Returns:
+            float: Attribute value, or None if not found.
+        """
         if not os.path.exists(urdf_path):
             raise FileNotFoundError(f"URDF file not found: {urdf_path}")
 
@@ -358,6 +399,13 @@ class URDFGenerator(object):
     def add_quality_tag(
         urdf_path: str, results: list, output_path: str = None
     ) -> None:
+        """Adds a quality tag to a URDF file.
+
+        Args:
+            urdf_path (str): Path to the URDF file.
+            results (list): List of [checker_name, result] pairs.
+            output_path (str, optional): Output file path.
+        """
         if output_path is None:
             output_path = urdf_path
 
@@ -382,6 +430,14 @@ class URDFGenerator(object):
         logger.info(f"URDF files saved to {output_path}")
 
     def get_estimated_attributes(self, asset_attrs: dict):
+        """Calculates estimated attributes from asset properties.
+
+        Args:
+            asset_attrs (dict): Asset attributes.
+
+        Returns:
+            dict: Estimated attributes (height, mass, mu, category).
+        """
         estimated_attrs = {
             "height": round(
                 (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
@@ -403,6 +459,18 @@ class URDFGenerator(object):
         category: str = "unknown",
         **kwargs,
     ):
+        """Generates a URDF file for a mesh asset.
+
+        Args:
+            mesh_path (str): Path to mesh file.
+            output_root (str): Directory for outputs.
+            text_prompt (str, optional): Prompt for GPT.
+            category (str, optional): Asset category.
+            **kwargs: Additional attributes.
+
+        Returns:
+            str: Path to generated URDF file.
+        """
         if text_prompt is None or len(text_prompt) == 0:
             text_prompt = self.prompt_template
         text_prompt = text_prompt.format(category=category.lower())
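Beyond the Example already embedded in the class docstring, the newly documented `add_quality_tag` can record checker outcomes back into the generated URDF. A hedged sketch (mesh path, output directory, and the result strings are illustrative):

```py
# Sketch extending the docstring Example: generate a URDF, then record
# checker outcomes with the newly documented add_quality_tag.
# Mesh path, output directory, and the result strings are illustrative.
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.validators.urdf_convertor import URDFGenerator

urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir")

# results is a list of [checker_name, outcome] pairs per the docstring.
urdf_gen.add_quality_tag(urdf_path, results=[["ImageAestheticChecker", "YES"]])
```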
thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py CHANGED
@@ -440,7 +440,7 @@ def to_glb(
     vertices, faces, uvs = parametrize_mesh(vertices, faces)
 
     # bake texture
-    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
+    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=200)
     masks = [np.any(observation > 0, axis=-1) for observation in observations]
     extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
     intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]