xinjie.wang commited on
Commit
a8ea627
·
1 Parent(s): b800513
common.py CHANGED
@@ -34,7 +34,7 @@ from PIL import Image
34
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
35
  from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
36
  from embodied_gen.data.differentiable_render import entrypoint as render_api
37
- from embodied_gen.data.utils import trellis_preprocess, zip_files
38
  from embodied_gen.models.delight_model import DelightingModel
39
  from embodied_gen.models.gs_model import GaussianOperator
40
  from embodied_gen.models.segment_model import (
@@ -208,7 +208,7 @@ def preprocess_image_fn(
208
  elif isinstance(image, np.ndarray):
209
  image = Image.fromarray(image)
210
 
211
- image_cache = image.copy().resize((512, 512))
212
 
213
  bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
214
  image = bg_remover(image)
@@ -224,7 +224,7 @@ def preprocess_sam_image_fn(
224
  image = Image.fromarray(image)
225
 
226
  sam_image = SAM_PREDICTOR.preprocess_image(image)
227
- image_cache = Image.fromarray(sam_image).resize((512, 512))
228
  SAM_PREDICTOR.predictor.set_image(sam_image)
229
 
230
  return sam_image, image_cache
 
34
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
35
  from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
36
  from embodied_gen.data.differentiable_render import entrypoint as render_api
37
+ from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files
38
  from embodied_gen.models.delight_model import DelightingModel
39
  from embodied_gen.models.gs_model import GaussianOperator
40
  from embodied_gen.models.segment_model import (
 
208
  elif isinstance(image, np.ndarray):
209
  image = Image.fromarray(image)
210
 
211
+ image_cache = resize_pil(image.copy(), 1024)
212
 
213
  bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
214
  image = bg_remover(image)
 
224
  image = Image.fromarray(image)
225
 
226
  sam_image = SAM_PREDICTOR.preprocess_image(image)
227
+ image_cache = sam_image.copy()
228
  SAM_PREDICTOR.predictor.set_image(sam_image)
229
 
230
  return sam_image, image_cache
embodied_gen/data/backproject_v3.py CHANGED
@@ -14,7 +14,7 @@
14
  # implied. See the License for the specific language governing
15
  # permissions and limitations under the License.
16
 
17
-
18
  import argparse
19
  import logging
20
  import math
@@ -425,6 +425,7 @@ def parse_args():
425
  return args
426
 
427
 
 
428
  def entrypoint(
429
  delight_model: DelightingModel = None,
430
  imagesr_model: ImageRealESRGAN = None,
 
14
  # implied. See the License for the specific language governing
15
  # permissions and limitations under the License.
16
 
17
+ import os
18
  import argparse
19
  import logging
20
  import math
 
425
  return args
426
 
427
 
428
+ @spaces.GPU
429
  def entrypoint(
430
  delight_model: DelightingModel = None,
431
  imagesr_model: ImageRealESRGAN = None,
embodied_gen/scripts/imageto3d.py CHANGED
@@ -26,12 +26,14 @@ import numpy as np
26
  import torch
27
  import trimesh
28
  from PIL import Image
29
- from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
30
  from embodied_gen.data.utils import delete_dir, trellis_preprocess
31
- from embodied_gen.models.delight_model import DelightingModel
 
32
  from embodied_gen.models.gs_model import GaussianOperator
33
  from embodied_gen.models.segment_model import RembgRemover
34
- from embodied_gen.models.sr_model import ImageRealESRGAN
 
35
  from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
36
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
37
  from embodied_gen.utils.log import logger
@@ -59,8 +61,8 @@ os.environ["SPCONV_ALGO"] = "native"
59
  random.seed(0)
60
 
61
  logger.info("Loading Image3D Models...")
62
- DELIGHT = DelightingModel()
63
- IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
64
  RBG_REMOVER = RembgRemover()
65
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
66
  "microsoft/TRELLIS-image-large"
@@ -108,9 +110,7 @@ def parse_args():
108
  default=2,
109
  )
110
  parser.add_argument("--disable_decompose_convex", action="store_true")
111
- parser.add_argument(
112
- "--texture_wh", type=int, nargs=2, default=[2048, 2048]
113
- )
114
  args, unknown = parser.parse_known_args()
115
 
116
  return args
@@ -248,16 +248,14 @@ def entrypoint(**kwargs):
248
  mesh.export(mesh_obj_path)
249
 
250
  mesh = backproject_api(
251
- delight_model=DELIGHT,
252
- imagesr_model=IMAGESR_MODEL,
253
- color_path=color_path,
254
  mesh_path=mesh_obj_path,
255
  output_path=mesh_obj_path,
256
  skip_fix_mesh=False,
257
- delight=True,
258
- texture_wh=args.texture_wh,
259
- elevation=[20, -10, 60, -50],
260
- num_images=12,
261
  )
262
 
263
  mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
 
26
  import torch
27
  import trimesh
28
  from PIL import Image
29
+ from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
30
  from embodied_gen.data.utils import delete_dir, trellis_preprocess
31
+
32
+ # from embodied_gen.models.delight_model import DelightingModel
33
  from embodied_gen.models.gs_model import GaussianOperator
34
  from embodied_gen.models.segment_model import RembgRemover
35
+
36
+ # from embodied_gen.models.sr_model import ImageRealESRGAN
37
  from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
38
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
39
  from embodied_gen.utils.log import logger
 
61
  random.seed(0)
62
 
63
  logger.info("Loading Image3D Models...")
64
+ # DELIGHT = DelightingModel()
65
+ # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
66
  RBG_REMOVER = RembgRemover()
67
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
68
  "microsoft/TRELLIS-image-large"
 
110
  default=2,
111
  )
112
  parser.add_argument("--disable_decompose_convex", action="store_true")
113
+ parser.add_argument("--texture_size", type=int, default=2048)
 
 
114
  args, unknown = parser.parse_known_args()
115
 
116
  return args
 
248
  mesh.export(mesh_obj_path)
249
 
250
  mesh = backproject_api(
251
+ # delight_model=DELIGHT,
252
+ # imagesr_model=IMAGESR_MODEL,
253
+ gs_path=aligned_gs_path,
254
  mesh_path=mesh_obj_path,
255
  output_path=mesh_obj_path,
256
  skip_fix_mesh=False,
257
+ texture_size=args.texture_size,
258
+ delight=False,
 
 
259
  )
260
 
261
  mesh_glb_path = os.path.join(output_root, f"{filename}.glb")