import os
from dataclasses import dataclass, field

import pytorch_lightning as pl
import torch.nn.functional as F

import threestudio
from threestudio.models.exporters.base import Exporter, ExporterOutput
from threestudio.systems.utils import parse_optimizer, parse_scheduler
from threestudio.utils.base import Updateable, update_if_possible
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import C, cleanup, get_device, load_module_weights
from threestudio.utils.saving import SaverMixin
from threestudio.utils.typing import *


class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
    @dataclass
    class Config:
        loggers: dict = field(default_factory=dict)
        loss: dict = field(default_factory=dict)
        optimizer: dict = field(default_factory=dict)
        scheduler: Optional[dict] = None
        weights: Optional[str] = None
        weights_ignore_modules: Optional[List[str]] = None
        cleanup_after_validation_step: bool = False
        cleanup_after_test_step: bool = False

    cfg: Config

    def __init__(self, cfg, resumed=False) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self._save_dir: Optional[str] = None
        self._resumed: bool = resumed
        self._resumed_eval: bool = False
        self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}
        if "loggers" in cfg:
            self.create_loggers(cfg.loggers)

        self.configure()
        if self.cfg.weights is not None:
            self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules)
        self.post_configure()

    def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None):
        state_dict, epoch, global_step = load_module_weights(
            weights, ignore_modules=ignore_modules, map_location="cpu"
        )
        # strict=False tolerates keys pruned by ignore_modules
        self.load_state_dict(state_dict, strict=False)
        # restore step-dependent states
        self.do_update_step(epoch, global_step, on_load_weights=True)

    def set_resume_status(self, current_epoch: int, global_step: int):
        # restore correct epoch and global step in eval
        self._resumed_eval = True
        self._resumed_eval_status["current_epoch"] = current_epoch
        self._resumed_eval_status["global_step"] = global_step

    @property
    def resumed(self):
        # whether from resumed checkpoint
        return self._resumed

    @property
    def true_global_step(self):
        if self._resumed_eval:
            return self._resumed_eval_status["global_step"]
        else:
            return self.global_step

    @property
    def true_current_epoch(self):
        if self._resumed_eval:
            return self._resumed_eval_status["current_epoch"]
        else:
            return self.current_epoch
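
    # When a checkpoint is evaluated without resuming training, Lightning resets
    # global_step/current_epoch to 0; set_resume_status() lets the true_* properties
    # report the checkpoint's counters instead, so saved artifacts keep consistent
    # step-based names (e.g. the it{step}- prefixes used below).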

    def configure(self) -> None:
        pass

    def post_configure(self) -> None:
        """
        executed after weights are loaded
        """
        pass

    def C(self, value: Any) -> float:
        return C(value, self.true_current_epoch, self.true_global_step)
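
    # Sketch of how scheduled hyper-parameters resolve through self.C()
    # (values are hypothetical; the exact schedule format is defined by
    # threestudio.utils.misc.C):
    #   loss:
    #     lambda_sds: 1.0          # a plain number is returned unchanged
    #     lambda_orient: [0, 10.0, 1000.0, 5000]  # interpolated over steps
    # and in a subclass: loss += self.C(self.cfg.loss.lambda_orient) * loss_orient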

    def configure_optimizers(self):
        optim = parse_optimizer(self.cfg.optimizer, self)
        ret = {
            "optimizer": optim,
        }
        if self.cfg.scheduler is not None:
            ret.update(
                {
                    "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
                }
            )
        return ret
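
    # A minimal sketch of the expected config (hypothetical values; the exact
    # schema is whatever parse_optimizer/parse_scheduler accept), as it would
    # appear in an experiment YAML:
    #   optimizer:
    #     name: Adam
    #     args: {lr: 0.01, betas: [0.9, 0.99], eps: 1.e-15}
    #   scheduler:
    #     name: ExponentialLR
    #     args: {gamma: 0.999}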

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_validation_batch_end(self, outputs, batch, batch_idx):
        if self.cfg.cleanup_after_validation_step:
            # cleanup to save vram
            cleanup()

    def on_validation_epoch_end(self):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_test_batch_end(self, outputs, batch, batch_idx):
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def on_predict_epoch_end(self):
        pass

    def preprocess_data(self, batch, stage):
        pass
| """ | |
| Implementing on_after_batch_transfer of DataModule does the same. | |
| But on_after_batch_transfer does not support DP. | |
| """ | |

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.preprocess_data(batch, "train")
        self.dataset = self.trainer.train_dataloader.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "validation")
        self.dataset = self.trainer.val_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "test")
        self.dataset = self.trainer.test_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "predict")
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        pass

    def on_before_optimizer_step(self, optimizer):
        """
        # some gradient-related debugging goes here, example:
        from lightning.pytorch.utilities import grad_norm
        norms = grad_norm(self.geometry, norm_type=2)
        print(norms)
        """
        pass


class BaseLift3DSystem(BaseSystem):
    @dataclass
    class Config(BaseSystem.Config):
        geometry_type: str = ""
        geometry: dict = field(default_factory=dict)
        geometry_convert_from: Optional[str] = None
        geometry_convert_inherit_texture: bool = False
        # used to override configurations of the previous geometry being converted from,
        # for example isosurface_threshold
        geometry_convert_override: dict = field(default_factory=dict)

        material_type: str = ""
        material: dict = field(default_factory=dict)

        background_type: str = ""
        background: dict = field(default_factory=dict)

        renderer_type: str = ""
        renderer: dict = field(default_factory=dict)

        guidance_type: str = ""
        guidance: dict = field(default_factory=dict)

        prompt_processor_type: str = ""
        prompt_processor: dict = field(default_factory=dict)

        # geometry export configurations, no need to specify in training
        exporter_type: str = "mesh-exporter"
        exporter: dict = field(default_factory=dict)

    cfg: Config
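
    # Each *_type string above is looked up in threestudio's module registry via
    # threestudio.find(), e.g. (illustrative names; the available registrations
    # depend on the threestudio version):
    #   geometry_type: "implicit-volume"
    #   guidance_type: "stable-diffusion-guidance"
    #   prompt_processor_type: "stable-diffusion-prompt-processor"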

    def configure(self) -> None:
        if (
            self.cfg.geometry_convert_from  # from_coarse must be specified
            and not self.cfg.weights  # not initialized from coarse when weights are specified
            and not self.resumed  # not initialized from coarse when resumed from checkpoints
        ):
            threestudio.info("Initializing geometry from a given checkpoint ...")
            from threestudio.utils.config import load_config, parse_structured

            prev_cfg = load_config(
                os.path.join(
                    os.path.dirname(self.cfg.geometry_convert_from),
                    "../configs/parsed.yaml",
                )
            )  # TODO: hard-coded relative path
            prev_system_cfg: BaseLift3DSystem.Config = parse_structured(
                self.Config, prev_cfg.system
            )
            prev_geometry_cfg = prev_system_cfg.geometry
            prev_geometry_cfg.update(self.cfg.geometry_convert_override)
            prev_geometry = threestudio.find(prev_system_cfg.geometry_type)(
                prev_geometry_cfg
            )
            state_dict, epoch, global_step = load_module_weights(
                self.cfg.geometry_convert_from,
                module_name="geometry",
                map_location="cpu",
            )
            prev_geometry.load_state_dict(state_dict, strict=False)
            # restore step-dependent states
            prev_geometry.do_update_step(epoch, global_step, on_load_weights=True)
            # convert from coarse stage geometry
            prev_geometry = prev_geometry.to(get_device())
            self.geometry = threestudio.find(self.cfg.geometry_type).create_from(
                prev_geometry,
                self.cfg.geometry,
                copy_net=self.cfg.geometry_convert_inherit_texture,
            )
            del prev_geometry
            cleanup()
        else:
            self.geometry = threestudio.find(self.cfg.geometry_type)(self.cfg.geometry)

        self.material = threestudio.find(self.cfg.material_type)(self.cfg.material)
        self.background = threestudio.find(self.cfg.background_type)(
            self.cfg.background
        )
        self.renderer = threestudio.find(self.cfg.renderer_type)(
            self.cfg.renderer,
            geometry=self.geometry,
            material=self.material,
            background=self.background,
        )
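
    # Typical two-stage usage (hypothetical path): a refinement stage sets
    #   system.geometry_convert_from: outputs/coarse/.../ckpts/last.ckpt
    # so the coarse geometry is loaded, optionally overridden (e.g. a different
    # isosurface_threshold via geometry_convert_override), and converted into this
    # stage's geometry type through create_from(); texture networks are copied
    # when geometry_convert_inherit_texture is true.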

    def on_fit_start(self) -> None:
        if self._save_dir is not None:
            threestudio.info(f"Validation results will be saved to {self._save_dir}")
        else:
            threestudio.warn(
                "Saving directory not set for the system, visualization results will not be saved"
            )

    def on_test_end(self) -> None:
        if self._save_dir is not None:
            threestudio.info(f"Test results saved to {self._save_dir}")

    def on_predict_start(self) -> None:
        self.exporter: Exporter = threestudio.find(self.cfg.exporter_type)(
            self.cfg.exporter,
            geometry=self.geometry,
            material=self.material,
            background=self.background,
        )

    def predict_step(self, batch, batch_idx):
        if self.exporter.cfg.save_video:
            self.test_step(batch, batch_idx)

    def on_predict_epoch_end(self) -> None:
        if self.exporter.cfg.save_video:
            self.on_test_epoch_end()
        exporter_output: List[ExporterOutput] = self.exporter()
        for out in exporter_output:
            save_func_name = f"save_{out.save_type}"
            if not hasattr(self, save_func_name):
                raise ValueError(f"{save_func_name} not supported by the SaverMixin")
            save_func = getattr(self, save_func_name)
            save_func(f"it{self.true_global_step}-export/{out.save_name}", **out.params)
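
    # Note on the dispatch above: each ExporterOutput.save_type names a SaverMixin
    # method by suffix, e.g. a save_type of "obj" would route to save_obj
    # (illustrative; the available save_* methods depend on the SaverMixin
    # implementation).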

    def on_predict_end(self) -> None:
        if self._save_dir is not None:
            threestudio.info(f"Export assets saved to {self._save_dir}")

    def guidance_evaluation_save(self, comp_rgb, guidance_eval_out):
        # comp_rgb: (B, H, W, 3) rendered images; guidance images are resized to match
        B, size = comp_rgb.shape[:2]
        resize = lambda x: F.interpolate(
            x.permute(0, 3, 1, 2), (size, size), mode="bilinear", align_corners=False
        ).permute(0, 2, 3, 1)
        filename = f"it{self.true_global_step}-train.png"

        def merge12(x):
            # stack the batch vertically into one tall image: (B, H, W, C) -> (B*H, W, C)
            return x.reshape(-1, *x.shape[2:])

        self.save_image_grid(
            filename,
            [
                {
                    "type": "rgb",
                    "img": merge12(comp_rgb),
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + [
                {
                    "type": "rgb",
                    "img": merge12(resize(guidance_eval_out["imgs_noisy"])),
                    "kwargs": {"data_format": "HWC"},
                }
            ]
            + [
                {
                    "type": "rgb",
                    "img": merge12(resize(guidance_eval_out["imgs_1step"])),
                    "kwargs": {"data_format": "HWC"},
                }
            ]
            + [
                {
                    "type": "rgb",
                    "img": merge12(resize(guidance_eval_out["imgs_1orig"])),
                    "kwargs": {"data_format": "HWC"},
                }
            ]
            + [
                {
                    "type": "rgb",
                    "img": merge12(resize(guidance_eval_out["imgs_final"])),
                    "kwargs": {"data_format": "HWC"},
                }
            ],
            name="train_step",
            step=self.true_global_step,
            texts=guidance_eval_out["texts"],
        )
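

# A minimal sketch of a concrete system built on BaseLift3DSystem (hypothetical
# registration name and loss weight; real systems in threestudio/systems/ follow
# this shape):
#
#   @threestudio.register("my-dreamfusion-like-system")
#   class MySystem(BaseLift3DSystem):
#       def configure(self):
#           super().configure()  # builds geometry/material/background/renderer
#           self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
#           self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
#               self.cfg.prompt_processor
#           )
#
#       def training_step(self, batch, batch_idx):
#           out = self.renderer(**batch)
#           guidance_out = self.guidance(
#               out["comp_rgb"], self.prompt_processor(), **batch
#           )
#           loss = self.C(self.cfg.loss.lambda_sds) * guidance_out["loss_sds"]
#           return {"loss": loss}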