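"""Single-image data module for threestudio.

Serves a fixed reference view (RGB, foreground mask, and optional depth and
normal maps) together with its camera parameters, and, when
``use_random_camera`` is enabled, batches of randomly sampled novel-view
cameras for training.
"""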
import bisect
import math
import os
from dataclasses import dataclass, field

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset

import threestudio
from threestudio import register
from threestudio.data.uncond import (
    RandomCameraDataModuleConfig,
    RandomCameraDataset,
    RandomCameraIterableDataset,
)
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import (
    get_mvp_matrix,
    get_projection_matrix,
    get_ray_directions,
    get_rays,
)
from threestudio.utils.typing import *
@dataclass
class SingleImageDataModuleConfig:
    # height and width should be Union[int, List[int]]
    # but OmegaConf does not support Union of containers
    height: Any = 96
    width: Any = 96
    resolution_milestones: List[int] = field(default_factory=lambda: [])
    default_elevation_deg: float = 0.0
    default_azimuth_deg: float = -180.0
    default_camera_distance: float = 1.2
    default_fovy_deg: float = 60.0
    image_path: str = ""
    use_random_camera: bool = True
    random_camera: dict = field(default_factory=dict)
    rays_noise_scale: float = 2e-3
    batch_size: int = 1
    requires_depth: bool = False
    requires_normal: bool = False
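
# Example data config (illustrative values only, e.g. as overridden from a
# YAML file parsed by OmegaConf; note that N resolutions require N - 1
# milestones):
#
#   data:
#     image_path: "load/images/hamburger_rgba.png"
#     height: [128, 256, 512]
#     width: [128, 256, 512]
#     resolution_milestones: [200, 300]
#     requires_depth: false
#     random_camera:
#       batch_size: 1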

class SingleImageDataBase:
    def setup(self, cfg, split):
        self.split = split
        self.rank = get_rank()
        self.cfg: SingleImageDataModuleConfig = cfg

        if self.cfg.use_random_camera:
            random_camera_cfg = parse_structured(
                RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
            )
            if split == "train":
                self.random_pose_generator = RandomCameraIterableDataset(
                    random_camera_cfg
                )
            else:
                self.random_pose_generator = RandomCameraDataset(
                    random_camera_cfg, split
                )

        # default (reference) camera pose from spherical coordinates
        elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
        azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
        camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])

        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180
        camera_position: Float[Tensor, "1 3"] = torch.stack(
            [
                camera_distance * torch.cos(elevation) * torch.cos(azimuth),
                camera_distance * torch.cos(elevation) * torch.sin(azimuth),
                camera_distance * torch.sin(elevation),
            ],
            dim=-1,
        )

        # build a look-at camera-to-world matrix with +z as the world up axis
        center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
        up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None
        ]

        light_position: Float[Tensor, "1 3"] = camera_position
        lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
        right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
        up = F.normalize(torch.cross(right, lookat), dim=-1)
        self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
            dim=-1,
        )

        self.camera_position = camera_position
        self.light_position = light_position
        self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
        self.camera_distance = camera_distance
        self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))

        # support multiple training resolutions, switched at resolution_milestones
        self.heights: List[int] = (
            [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
        )
        self.widths: List[int] = (
            [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
        )
        assert len(self.heights) == len(self.widths)
        self.resolution_milestones: List[int]
        if len(self.heights) == 1 and len(self.widths) == 1:
            if len(self.cfg.resolution_milestones) > 0:
                threestudio.warn(
                    "Ignoring resolution_milestones since height and width are not changing"
                )
            self.resolution_milestones = [-1]
        else:
            assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
            self.resolution_milestones = [-1] + self.cfg.resolution_milestones

        self.directions_unit_focals = [
            get_ray_directions(H=height, W=width, focal=1.0)
            for (height, width) in zip(self.heights, self.widths)
        ]
        self.focal_lengths = [
            0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
        ]

        self.height: int = self.heights[0]
        self.width: int = self.widths[0]
        self.directions_unit_focal = self.directions_unit_focals[0]
        self.focal_length = self.focal_lengths[0]

        self.set_rays()
        self.load_images()
        self.prev_height = self.height
    def set_rays(self):
        # get directions by dividing directions_unit_focal by focal length
        directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None]
        directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length

        rays_o, rays_d = get_rays(
            directions, self.c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale
        )

        proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix(
            self.fovy, self.width / self.height, 0.1, 100.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx)

        self.rays_o, self.rays_d = rays_o, rays_d
        self.mvp_mtx = mvp_mtx
    def load_images(self):
        # load the RGBA reference image; alpha > 0.5 becomes the foreground mask
        assert os.path.exists(
            self.cfg.image_path
        ), f"Could not find image {self.cfg.image_path}!"
        rgba = cv2.cvtColor(
            cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
        )
        rgba = (
            cv2.resize(
                rgba, (self.width, self.height), interpolation=cv2.INTER_AREA
            ).astype(np.float32)
            / 255.0
        )
        rgb = rgba[..., :3]
        self.rgb: Float[Tensor, "1 H W 3"] = (
            torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank)
        )
        self.mask: Float[Tensor, "1 H W 1"] = (
            torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank)
        )
        print(
            f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}"
        )

        # load depth (expected next to the image, with _rgba.png -> _depth.png)
        if self.cfg.requires_depth:
            depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png")
            assert os.path.exists(depth_path)
            depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
            depth = cv2.resize(
                depth, (self.width, self.height), interpolation=cv2.INTER_AREA
            )
            self.depth: Float[Tensor, "1 H W 1"] = (
                torch.from_numpy(depth.astype(np.float32) / 255.0)
                .unsqueeze(0)
                .to(self.rank)
            )
            print(
                f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}"
            )
        else:
            self.depth = None

        # load normal (expected next to the image, with _rgba.png -> _normal.png)
        if self.cfg.requires_normal:
            normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png")
            assert os.path.exists(normal_path)
            normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
            normal = cv2.resize(
                normal, (self.width, self.height), interpolation=cv2.INTER_AREA
            )
            self.normal: Float[Tensor, "1 H W 3"] = (
                torch.from_numpy(normal.astype(np.float32) / 255.0)
                .unsqueeze(0)
                .to(self.rank)
            )
            print(
                f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}"
            )
        else:
            self.normal = None
    def get_all_images(self):
        return self.rgb

    def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False):
        size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
        self.height = self.heights[size_ind]
        if self.height == self.prev_height:
            return  # resolution unchanged; keep cached rays and images

        self.prev_height = self.height
        self.width = self.widths[size_ind]
        self.directions_unit_focal = self.directions_unit_focals[size_ind]
        self.focal_length = self.focal_lengths[size_ind]
        threestudio.debug(f"Training height: {self.height}, width: {self.width}")
        self.set_rays()
        self.load_images()

class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.setup(cfg, split)

    def collate(self, batch) -> Dict[str, Any]:
        # the reference view is fixed, so the batch is assembled from cached
        # tensors rather than from the (empty) items produced by __iter__
        batch = {
            "rays_o": self.rays_o,
            "rays_d": self.rays_d,
            "mvp_mtx": self.mvp_mtx,
            "camera_positions": self.camera_position,
            "light_positions": self.light_position,
            "elevation": self.elevation_deg,
            "azimuth": self.azimuth_deg,
            "camera_distances": self.camera_distance,
            "rgb": self.rgb,
            "ref_depth": self.depth,
            "ref_normal": self.normal,
            "mask": self.mask,
            "height": self.cfg.height,
            "width": self.cfg.width,
        }
        if self.cfg.use_random_camera:
            batch["random_camera"] = self.random_pose_generator.collate(None)
        return batch

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        self.update_step_(epoch, global_step, on_load_weights)
        if self.cfg.use_random_camera:
            self.random_pose_generator.update_step(epoch, global_step, on_load_weights)

    def __iter__(self):
        while True:
            yield {}

class SingleImageDataset(Dataset, SingleImageDataBase):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.setup(cfg, split)

    def __len__(self):
        return len(self.random_pose_generator)

    def __getitem__(self, index):
        return self.random_pose_generator[index]

# register the data module so it can be referenced by name from configs
@register("single-image-datamodule")
class SingleImageDataModule(pl.LightningDataModule):
    cfg: SingleImageDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(SingleImageDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        if stage in [None, "fit"]:
            self.train_dataset = SingleImageIterableDataset(self.cfg, "train")
        if stage in [None, "fit", "validate"]:
            self.val_dataset = SingleImageDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = SingleImageDataset(self.cfg, "test")

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        return DataLoader(
            dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.train_dataset,
            batch_size=self.cfg.batch_size,
            collate_fn=self.train_dataset.collate,
        )

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)
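
# Minimal usage sketch (hypothetical values; assumes an RGBA image exists at
# the given image_path, and that threestudio is installed):
#
#   dm = SingleImageDataModule(
#       {"image_path": "load/images/hamburger_rgba.png", "height": 128, "width": 128}
#   )
#   dm.setup("fit")
#   batch = next(iter(dm.train_dataloader()))
#   # batch["rgb"] is the [1, H, W, 3] reference image; batch["rays_o"] and
#   # batch["rays_d"] are its camera rays; batch["random_camera"] holds a
#   # randomly sampled novel-view camera batch (since use_random_camera=True).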