Spaces:
Build error
Build error
| import random | |
| from dataclasses import dataclass, field | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import threestudio | |
| from threestudio.models.materials.base import BaseMaterial | |
| from threestudio.utils.typing import * | |
| class StableDiffusionLatentAdapterMaterial(BaseMaterial): | |
| class Config(BaseMaterial.Config): | |
| pass | |
| cfg: Config | |
| def configure(self) -> None: | |
| adapter = nn.Parameter( | |
| torch.as_tensor( | |
| [ | |
| # R G B | |
| [0.298, 0.207, 0.208], # L1 | |
| [0.187, 0.286, 0.173], # L2 | |
| [-0.158, 0.189, 0.264], # L3 | |
| [-0.184, -0.271, -0.473], # L4 | |
| ] | |
| ) | |
| ) | |
| self.register_parameter("adapter", adapter) | |
| def forward( | |
| self, features: Float[Tensor, "B ... 4"], **kwargs | |
| ) -> Float[Tensor, "B ... 3"]: | |
| assert features.shape[-1] == 4 | |
| color = features @ self.adapter | |
| color = (color + 1) / 2 | |
| color = color.clamp(0.0, 1.0) | |
| return color | |