Spaces:
Running
Running
+ HF models now built with config not pickle
Browse files- app.py +8 -4
- msma.py +22 -12
- networks_edm2.py +318 -0
app.py
CHANGED
|
@@ -11,7 +11,12 @@ import torch
|
|
| 11 |
from huggingface_hub import hf_hub_download
|
| 12 |
from safetensors.torch import load_file
|
| 13 |
|
| 14 |
-
from msma import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
@cache
|
|
@@ -32,7 +37,6 @@ def load_model_from_hub(preset, device):
|
|
| 32 |
if 'DNNLIB_CACHE_DIR' in os.environ:
|
| 33 |
cache_dir = os.environ["DNNLIB_CACHE_DIR"]
|
| 34 |
|
| 35 |
-
scorenet = build_model_from_pickle(preset)
|
| 36 |
|
| 37 |
for fname in ['config.json', 'gmm.pkl', 'refscores.npz', 'model.safetensors' ]:
|
| 38 |
cached_fname = hf_hub_download(
|
|
@@ -49,10 +53,10 @@ def load_model_from_hub(preset, device):
|
|
| 49 |
print("Loaded:", model_params)
|
| 50 |
|
| 51 |
hf_checkpoint = f"{modeldir}/model.safetensors"
|
| 52 |
-
model =
|
| 53 |
model.load_state_dict(load_file(hf_checkpoint), strict=True)
|
| 54 |
model = model.eval().requires_grad_(False)
|
| 55 |
-
|
| 56 |
return model, modeldir
|
| 57 |
|
| 58 |
|
|
|
|
| 11 |
from huggingface_hub import hf_hub_download
|
| 12 |
from safetensors.torch import load_file
|
| 13 |
|
| 14 |
+
from msma import (
|
| 15 |
+
ScoreFlow,
|
| 16 |
+
build_model_from_config,
|
| 17 |
+
build_model_from_pickle,
|
| 18 |
+
config_presets,
|
| 19 |
+
)
|
| 20 |
|
| 21 |
|
| 22 |
@cache
|
|
|
|
| 37 |
if 'DNNLIB_CACHE_DIR' in os.environ:
|
| 38 |
cache_dir = os.environ["DNNLIB_CACHE_DIR"]
|
| 39 |
|
|
|
|
| 40 |
|
| 41 |
for fname in ['config.json', 'gmm.pkl', 'refscores.npz', 'model.safetensors' ]:
|
| 42 |
cached_fname = hf_hub_download(
|
|
|
|
| 53 |
print("Loaded:", model_params)
|
| 54 |
|
| 55 |
hf_checkpoint = f"{modeldir}/model.safetensors"
|
| 56 |
+
model = build_model_from_config(model_params)
|
| 57 |
model.load_state_dict(load_file(hf_checkpoint), strict=True)
|
| 58 |
model = model.eval().requires_grad_(False)
|
| 59 |
+
model.to(device)
|
| 60 |
return model, modeldir
|
| 61 |
|
| 62 |
|
msma.py
CHANGED
|
@@ -21,6 +21,7 @@ from tqdm import tqdm
|
|
| 21 |
import dnnlib
|
| 22 |
from dataset import ImageFolderDataset
|
| 23 |
from flowutils import PatchFlow, sanitize_locals
|
|
|
|
| 24 |
|
| 25 |
DEVICE: Literal["cuda", "cpu"] = "cpu"
|
| 26 |
model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
|
|
@@ -122,6 +123,14 @@ class ScoreFlow(torch.nn.Module):
|
|
| 122 |
return self.flow(x_scores)
|
| 123 |
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"):
|
| 126 |
netpath = config_presets[preset]
|
| 127 |
with dnnlib.util.open_url(netpath, verbose=1) as f:
|
|
@@ -196,13 +205,13 @@ def cmdline():
|
|
| 196 |
def common_args(func):
|
| 197 |
@wraps(func)
|
| 198 |
@click.option(
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
)
|
| 206 |
@click.option(
|
| 207 |
"--dataset_path",
|
| 208 |
help="Path to the dataset",
|
|
@@ -222,7 +231,8 @@ def common_args(func):
|
|
| 222 |
|
| 223 |
return wrapper
|
| 224 |
|
| 225 |
-
|
|
|
|
| 226 |
@click.option(
|
| 227 |
"--gridsearch",
|
| 228 |
help="Whether to use a grid search on a number of components to find the best fit",
|
|
@@ -365,7 +375,7 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
|
|
| 365 |
train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
|
| 366 |
)
|
| 367 |
testiter = torch.utils.data.DataLoader(
|
| 368 |
-
val_ds, batch_size=batch_size*2, num_workers=4, prefetch_factor=2
|
| 369 |
)
|
| 370 |
|
| 371 |
scorenet = build_model_from_pickle(preset)
|
|
@@ -392,10 +402,10 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
|
|
| 392 |
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
|
| 393 |
writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
|
| 394 |
|
| 395 |
-
with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
|
| 396 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 397 |
|
| 398 |
-
with open(f"{experiment_dir}/config.json", "w") as f:
|
| 399 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 400 |
|
| 401 |
# totaliters = int(epochs * train_len)
|
|
@@ -463,7 +473,7 @@ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
|
|
| 463 |
|
| 464 |
# Save final model
|
| 465 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
| 466 |
-
|
| 467 |
writer.close()
|
| 468 |
|
| 469 |
|
|
|
|
| 21 |
import dnnlib
|
| 22 |
from dataset import ImageFolderDataset
|
| 23 |
from flowutils import PatchFlow, sanitize_locals
|
| 24 |
+
from networks_edm2 import Precond
|
| 25 |
|
| 26 |
DEVICE: Literal["cuda", "cpu"] = "cpu"
|
| 27 |
model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
|
|
|
|
| 123 |
return self.flow(x_scores)
|
| 124 |
|
| 125 |
|
| 126 |
+
def build_model_from_config(model_params):
    """Assemble a ScoreFlow model from a configuration mapping.

    Args:
        model_params: Mapping with the keys "EDMNet", "EDMScorer" and
            "PatchFlow"; each value is expanded as keyword arguments for
            Precond, EDMScorer and ScoreFlow respectively.

    Returns:
        The ScoreFlow instance wrapping the freshly constructed scorer.
        No weights are loaded here.
    """
    edm_kwargs = model_params["EDMNet"]
    scorer_kwargs = model_params["EDMScorer"]
    flow_kwargs = model_params["PatchFlow"]

    # Build from the inside out: denoiser -> scorer -> flow wrapper.
    backbone = Precond(**edm_kwargs)
    scorer = EDMScorer(net=backbone, **scorer_kwargs)
    model = ScoreFlow(scorenet=scorer, **flow_kwargs)

    print("Built model from config")
    return model
|
| 132 |
+
|
| 133 |
+
|
| 134 |
def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"):
|
| 135 |
netpath = config_presets[preset]
|
| 136 |
with dnnlib.util.open_url(netpath, verbose=1) as f:
|
|
|
|
| 205 |
def common_args(func):
|
| 206 |
@wraps(func)
|
| 207 |
@click.option(
|
| 208 |
+
"--preset",
|
| 209 |
+
help="Configuration preset",
|
| 210 |
+
metavar="STR",
|
| 211 |
+
type=str,
|
| 212 |
+
default="edm2-img64-s-fid",
|
| 213 |
+
show_default=True,
|
| 214 |
+
)
|
| 215 |
@click.option(
|
| 216 |
"--dataset_path",
|
| 217 |
help="Path to the dataset",
|
|
|
|
| 231 |
|
| 232 |
return wrapper
|
| 233 |
|
| 234 |
+
|
| 235 |
+
@cmdline.command("train-gmm")
|
| 236 |
@click.option(
|
| 237 |
"--gridsearch",
|
| 238 |
help="Whether to use a grid search on a number of components to find the best fit",
|
|
|
|
| 375 |
train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
|
| 376 |
)
|
| 377 |
testiter = torch.utils.data.DataLoader(
|
| 378 |
+
val_ds, batch_size=batch_size * 2, num_workers=4, prefetch_factor=2
|
| 379 |
)
|
| 380 |
|
| 381 |
scorenet = build_model_from_pickle(preset)
|
|
|
|
| 402 |
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
|
| 403 |
writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
|
| 404 |
|
| 405 |
+
with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
|
| 406 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 407 |
|
| 408 |
+
with open(f"{experiment_dir}/config.json", "w") as f:
|
| 409 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 410 |
|
| 411 |
# totaliters = int(epochs * train_len)
|
|
|
|
| 473 |
|
| 474 |
# Save final model
|
| 475 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
| 476 |
+
|
| 477 |
writer.close()
|
| 478 |
|
| 479 |
|
networks_edm2.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Improved diffusion model architecture proposed in the paper
|
| 9 |
+
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from torch_utils import persistence
|
| 14 |
+
from torch_utils import misc
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
# Normalize given tensor to unit magnitude with respect to the given
|
| 18 |
+
# dimensions. Default = all dimensions except the first.
|
| 19 |
+
|
| 20 |
+
def normalize(x, dim=None, eps=1e-4):
    """Normalize a tensor to unit magnitude over the given dimensions.

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce over. None means every dimension
            except the first.
        eps: Constant added to the scaled norm before dividing.

    Returns:
        `x` divided by the (scaled, eps-offset) norm, in `x`'s dtype.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # The norm itself is always accumulated in float32.
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # Rescale by sqrt(#norm entries / #input entries) so the result has unit
    # magnitude per element rather than unit vector length.
    scale = np.sqrt(magnitude.numel() / x.numel())
    denom = torch.add(eps, magnitude, alpha=scale)
    return x / denom.to(x.dtype)
|
| 26 |
+
|
| 27 |
+
#----------------------------------------------------------------------------
|
| 28 |
+
# Upsample or downsample the given tensor with the given filter,
|
| 29 |
+
# or keep it as is.
|
| 30 |
+
|
| 31 |
+
def resample(x, f=[1,1], mode='keep'):
    """Upsample or downsample a tensor with a separable filter, or pass it through.

    Args:
        x: Input tensor; channels are read from `x.shape[1]`.
        f: 1-D filter taps with an even number of entries.
        mode: 'keep' returns `x` unchanged, 'down' applies a stride-2
            grouped convolution, 'up' applies a stride-2 grouped
            transposed convolution.

    Returns:
        The resampled tensor, or `x` itself when `mode == 'keep'`.
    """
    if mode == 'keep':
        return x

    taps = np.float32(f)
    assert taps.ndim == 1 and len(taps) % 2 == 0
    pad = (len(taps) - 1) // 2

    # Normalize the taps to sum to one, then form the 2-D kernel as an outer product.
    taps = taps / taps.sum()
    kernel = misc.const_like(x, np.outer(taps, taps)[np.newaxis, np.newaxis, :, :])
    channels = x.shape[1]

    if mode != 'down':
        assert mode == 'up'
        # The kernel is multiplied by 4 on the upsampling path.
        up_kernel = (kernel * 4).tile([channels, 1, 1, 1])
        return torch.nn.functional.conv_transpose2d(x, up_kernel, groups=channels, stride=2, padding=(pad,))

    down_kernel = kernel.tile([channels, 1, 1, 1])
    return torch.nn.functional.conv2d(x, down_kernel, groups=channels, stride=2, padding=(pad,))
|
| 45 |
+
|
| 46 |
+
#----------------------------------------------------------------------------
|
| 47 |
+
# Magnitude-preserving SiLU (Equation 81).
|
| 48 |
+
|
| 49 |
+
def mp_silu(x):
    """Apply SiLU and divide the result by the constant 0.596.

    The surrounding file labels this "magnitude-preserving SiLU".
    """
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
|
| 51 |
+
|
| 52 |
+
#----------------------------------------------------------------------------
|
| 53 |
+
# Magnitude-preserving sum (Equation 88).
|
| 54 |
+
|
| 55 |
+
def mp_sum(a, b, t=0.5):
    """Blend two tensors and rescale by the blend-weight magnitude.

    Args:
        a: First tensor (weight `1 - t`).
        b: Second tensor (weight `t`).
        t: Interpolation factor passed to `lerp`.

    Returns:
        `a.lerp(b, t)` divided by `sqrt((1 - t)**2 + t**2)`.
    """
    blended = a.lerp(b, t)
    weight_norm = np.sqrt((1 - t) ** 2 + t ** 2)
    return blended / weight_norm
|
| 57 |
+
|
| 58 |
+
#----------------------------------------------------------------------------
|
| 59 |
+
# Magnitude-preserving concatenation (Equation 103).
|
| 60 |
+
|
| 61 |
+
def mp_cat(a, b, dim=1, t=0.5):
    """Concatenate two tensors after weighting each by its size and `t`.

    Args:
        a: First tensor (weighted by `1 - t`).
        b: Second tensor (weighted by `t`).
        dim: Dimension to concatenate along.
        t: Balance between `a` (0) and `b` (1).

    Returns:
        `torch.cat([wa * a, wb * b], dim=dim)` where the weights depend on
        the sizes of `a` and `b` along `dim`.
    """
    size_a = a.shape[dim]
    size_b = b.shape[dim]
    # Shared factor so that both weights are derived from the same constant.
    common = np.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    weight_a = common / np.sqrt(size_a) * (1 - t)
    weight_b = common / np.sqrt(size_b) * t
    pieces = [weight_a * a, weight_b * b]
    return torch.cat(pieces, dim=dim)
|
| 68 |
+
|
| 69 |
+
#----------------------------------------------------------------------------
|
| 70 |
+
# Magnitude-preserving Fourier features (Equation 75).
|
| 71 |
+
|
| 72 |
+
@persistence.persistent_class
class MPFourier(torch.nn.Module):
    """Fourier feature embedding with random frequencies and phases.

    The surrounding file labels this "magnitude-preserving Fourier features".

    Args:
        num_channels: Number of output features per input value.
        bandwidth: Multiplier applied to the random frequencies.
    """

    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        # Frequencies are drawn before phases; both are stored as buffers, not parameters.
        random_freqs = 2 * np.pi * torch.randn(num_channels) * bandwidth
        self.register_buffer('freqs', random_freqs)
        random_phases = 2 * np.pi * torch.rand(num_channels)
        self.register_buffer('phases', random_phases)

    def forward(self, x):
        """Return `sqrt(2) * cos(outer(x, freqs) + phases)` in `x`'s dtype."""
        # All intermediate math runs in float32 regardless of the input dtype.
        angles = x.to(torch.float32).ger(self.freqs.to(torch.float32))
        angles = angles + self.phases.to(torch.float32)
        features = angles.cos() * np.sqrt(2)
        return features.to(x.dtype)
|
| 85 |
+
|
| 86 |
+
#----------------------------------------------------------------------------
|
| 87 |
+
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
|
| 88 |
+
# with force weight normalization (Equation 66).
|
| 89 |
+
|
| 90 |
+
@persistence.persistent_class
class MPConv(torch.nn.Module):
    """Convolution or fully-connected layer with normalized weights.

    The weight has shape `(out_channels, in_channels, *kernel)`. An empty
    `kernel` gives a 2-D weight (matrix multiply); a two-element `kernel`
    gives a 4-D weight (2-D convolution).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel: Kernel shape appended to the weight dimensions.
    """

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        initial_weight = torch.randn(out_channels, in_channels, *kernel)
        self.weight = torch.nn.Parameter(initial_weight)

    def forward(self, x, gain=1):
        """Apply the layer to `x`, scaling the normalized weight by `gain`."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # Forced weight normalization: overwrite the stored parameter in place.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        # Traditional weight normalization on the value used for this call.
        weight = normalize(weight)
        # Scale by gain / sqrt(elements per output channel).
        fan_in = weight[0].numel()
        weight = (weight * (gain / np.sqrt(fan_in))).to(x.dtype)

        if weight.ndim != 2:
            assert weight.ndim == 4
            return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1]//2,))
        return x @ weight.t()
|
| 109 |
+
|
| 110 |
+
#----------------------------------------------------------------------------
|
| 111 |
+
# U-Net encoder/decoder block with optional self-attention (Figure 21).
|
| 112 |
+
|
| 113 |
+
@persistence.persistent_class
class Block(torch.nn.Module):
    """U-Net encoder/decoder block with optional self-attention.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        emb_channels: Number of embedding channels.
        flavor: 'enc' or 'dec'.
        resample_mode: 'keep', 'up', or 'down'.
        resample_filter: Resampling filter taps.
        attention: Whether to include self-attention.
        channels_per_head: Number of channels per attention head.
        dropout: Dropout probability.
        res_balance: Balance between main branch (0) and residual branch (1).
        attn_balance: Balance between main branch (0) and self-attention (1).
        clip_act: Clip output activations to +/- this value. None = do not clip.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        emb_channels,
        flavor='enc',
        resample_mode='keep',
        resample_filter=[1,1],
        attention=False,
        channels_per_head=64,
        dropout=0,
        res_balance=0.3,
        attn_balance=0.3,
        clip_act=256,
    ):
        super().__init__()

        # Plain configuration attributes.
        self.out_channels = out_channels
        self.flavor = flavor
        self.resample_filter = resample_filter
        self.resample_mode = resample_mode
        if attention:
            self.num_heads = out_channels // channels_per_head
        else:
            self.num_heads = 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act

        # Learnable pieces, created in a fixed order.
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        # Encoder blocks apply the skip conv first, so conv_res0 already sees out_channels.
        res0_inputs = out_channels if flavor == 'enc' else in_channels
        self.conv_res0 = MPConv(res0_inputs, out_channels, kernel=[3,3])
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=[3,3])

        if in_channels != out_channels:
            self.conv_skip = MPConv(in_channels, out_channels, kernel=[1,1])
        else:
            self.conv_skip = None

        if self.num_heads != 0:
            self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=[1,1])
            self.attn_proj = MPConv(out_channels, out_channels, kernel=[1,1])
        else:
            self.attn_qkv = None
            self.attn_proj = None

    def forward(self, x, emb):
        """Run the block on activations `x` conditioned on embedding `emb`."""
        # Main branch: resample, then (encoder only) skip conv and pixel norm.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == 'enc':
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=1)

        # Residual branch, modulated per channel by the embedding.
        residual = self.conv_res0(mp_silu(x))
        modulation = self.emb_linear(emb, gain=self.emb_gain) + 1
        modulation = modulation.unsqueeze(2).unsqueeze(3).to(residual.dtype)
        residual = mp_silu(residual * modulation)
        if self.training and self.dropout != 0:
            residual = torch.nn.functional.dropout(residual, p=self.dropout)
        residual = self.conv_res1(residual)

        # Join the branches; decoder blocks apply the skip conv here instead.
        if self.flavor == 'dec' and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, residual, t=self.res_balance)

        # Self-attention, written out explicitly with einsum.
        if self.num_heads != 0:
            qkv = self.attn_qkv(x)
            qkv = qkv.reshape(qkv.shape[0], self.num_heads, -1, 3, qkv.shape[2] * qkv.shape[3])
            # Pixel norm over the per-head channel axis, then split into q, k, v.
            q, k, v = normalize(qkv, dim=2).unbind(3)
            scores = torch.einsum('nhcq,nhck->nhqk', q, k / np.sqrt(q.shape[2]))
            attn = scores.softmax(dim=3)
            attended = torch.einsum('nhqk,nhck->nhcq', attn, v)
            attended = self.attn_proj(attended.reshape(*x.shape))
            x = mp_sum(x, attended, t=self.attn_balance)

        # Clip activations in place when a limit is configured.
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
|
| 184 |
+
|
| 185 |
+
#----------------------------------------------------------------------------
|
| 186 |
+
# EDM2 U-Net model (Figure 21).
|
| 187 |
+
|
| 188 |
+
@persistence.persistent_class
class UNet(torch.nn.Module):
    """EDM2 U-Net model.

    Args:
        img_resolution: Image resolution.
        img_channels: Image channels.
        label_dim: Class label dimensionality. 0 = unconditional.
        model_channels: Base multiplier for the number of channels.
        channel_mult: Per-resolution multipliers for the number of channels.
        channel_mult_noise: Multiplier for noise embedding dimensionality.
            None = use the first entry of the per-level channel list.
        channel_mult_emb: Multiplier for final embedding dimensionality.
            None = use the largest entry of the per-level channel list.
        num_blocks: Number of residual blocks per resolution.
        attn_resolutions: List of resolutions with self-attention.
        label_balance: Balance between noise embedding (0) and class embedding (1).
        concat_balance: Balance between skip connections (0) and main path (1).
        **block_kwargs: Extra keyword arguments forwarded to every Block.
    """

    def __init__(
        self,
        img_resolution,
        img_channels,
        label_dim,
        model_channels=192,
        channel_mult=[1,2,3,4],
        channel_mult_noise=None,
        channel_mult_emb=None,
        num_blocks=3,
        attn_resolutions=[16,8],
        label_balance=0.5,
        concat_balance=0.5,
        **block_kwargs,
    ):
        super().__init__()
        cblock = [model_channels * x for x in channel_mult]
        if channel_mult_noise is not None:
            cnoise = model_channels * channel_mult_noise
        else:
            cnoise = cblock[0]
        if channel_mult_emb is not None:
            cemb = model_channels * channel_mult_emb
        else:
            cemb = max(cblock)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding.
        self.emb_fourier = MPFourier(cnoise)
        self.emb_noise = MPConv(cnoise, cemb, kernel=[])
        if label_dim != 0:
            self.emb_label = MPConv(label_dim, cemb, kernel=[])
        else:
            self.emb_label = None

        # Encoder. The input gets one extra constant channel (see forward).
        self.enc = torch.nn.ModuleDict()
        cout = img_channels + 1
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level != 0:
                self.enc[f'{res}x{res}_down'] = Block(cout, cout, cemb, flavor='enc', resample_mode='down', **block_kwargs)
            else:
                cin = cout
                cout = channels
                self.enc[f'{res}x{res}_conv'] = MPConv(cin, cout, kernel=[3,3])
            use_attn = res in attn_resolutions
            for idx in range(num_blocks):
                cin = cout
                cout = channels
                self.enc[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='enc', attention=use_attn, **block_kwargs)

        # Decoder. Skip widths are consumed in reverse order of the encoder layers.
        self.dec = torch.nn.ModuleDict()
        skips = [block.out_channels for block in self.enc.values()]
        deepest = len(cblock) - 1
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level != deepest:
                self.dec[f'{res}x{res}_up'] = Block(cout, cout, cemb, flavor='dec', resample_mode='up', **block_kwargs)
            else:
                self.dec[f'{res}x{res}_in0'] = Block(cout, cout, cemb, flavor='dec', attention=True, **block_kwargs)
                self.dec[f'{res}x{res}_in1'] = Block(cout, cout, cemb, flavor='dec', **block_kwargs)
            use_attn = res in attn_resolutions
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = channels
                self.dec[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='dec', attention=use_attn, **block_kwargs)
        self.out_conv = MPConv(cout, img_channels, kernel=[3,3])

    def forward(self, x, noise_labels, class_labels):
        """Run the U-Net on `x` given noise labels and (optional) class labels."""
        # Embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            scaled_labels = class_labels * np.sqrt(class_labels.shape[1])
            emb = mp_sum(emb, self.emb_label(scaled_labels), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder: append a constant channel of ones, then record every layer output.
        ones_channel = torch.ones_like(x[:, :1])
        x = torch.cat([x, ones_channel], dim=1)
        skips = []
        for name, block in self.enc.items():
            if 'conv' in name:
                x = block(x)
            else:
                x = block(x, emb)
            skips.append(x)

        # Decoder: only layers named '...block...' consume a skip connection.
        for name, block in self.dec.items():
            if 'block' in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = block(x, emb)

        return self.out_conv(x, gain=self.out_gain)
|
| 270 |
+
|
| 271 |
+
#----------------------------------------------------------------------------
|
| 272 |
+
# Preconditioning and uncertainty estimation.
|
| 273 |
+
|
| 274 |
+
@persistence.persistent_class
class Precond(torch.nn.Module):
    """Preconditioning wrapper around UNet with uncertainty estimation.

    Args:
        img_resolution: Image resolution.
        img_channels: Image channels.
        label_dim: Class label dimensionality. 0 = unconditional.
        use_fp16: Run the model at FP16 precision?
        sigma_data: Expected standard deviation of the training data.
        logvar_channels: Intermediate dimensionality for uncertainty estimation.
        **unet_kwargs: Keyword arguments forwarded to UNet.
    """

    def __init__(
        self,
        img_resolution,
        img_channels,
        label_dim,
        use_fp16=True,
        sigma_data=0.5,
        logvar_channels=128,
        **unet_kwargs,
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.unet = UNet(
            img_resolution=img_resolution,
            img_channels=img_channels,
            label_dim=label_dim,
            **unet_kwargs,
        )
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])

    def forward(self, x, sigma, class_labels=None, force_fp32=False, return_logvar=False, **unet_kwargs):
        """Denoise `x` at noise level `sigma`.

        Returns `D_x`, or the tuple `(D_x, logvar)` when `return_logvar` is true.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        # Resolve class labels: none for unconditional models, zeros when omitted.
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)

        # FP16 is only used on CUDA, when enabled and not overridden.
        half_precision = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if half_precision else torch.float32

        # Preconditioning weights.
        total_var = sigma ** 2 + self.sigma_data ** 2
        total_std = total_var.sqrt()
        c_skip = self.sigma_data ** 2 / total_var
        c_out = sigma * self.sigma_data / total_std
        c_in = 1 / total_std
        c_noise = sigma.flatten().log() / 4

        # Run the model and combine its output with the skip path in float32.
        x_in = (c_in * x).to(dtype)
        F_x = self.unet(x_in, c_noise, class_labels, **unet_kwargs)
        D_x = c_skip * x + c_out * F_x.to(torch.float32)

        if not return_logvar:
            return D_x

        # Estimate uncertainty: u(sigma) in Equation 21.
        logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
        return D_x, logvar
|
| 317 |
+
|
| 318 |
+
#----------------------------------------------------------------------------
|