saving model configs

- flowutils.py +1 -1
- msma.py +25 -17
flowutils.py CHANGED

@@ -221,7 +221,7 @@ class PatchFlow(torch.nn.Module):
     ):
         super().__init__()
 
-        self.config = sanitize_locals(locals(), ignore_keys=input_size)
+        self.config = sanitize_locals(locals(), ignore_keys="input_size")
 
         num_sigmas, c, h, w = input_size
         self.local_pooler = SpatialNormer(
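The flowutils.py fix quotes the key name: `ignore_keys` evidently expects the names of locals to exclude, so passing the `input_size` tuple itself rather than the string `"input_size"` would not drop that entry from the stored config as intended. The real `sanitize_locals` lives in flowutils.py and is not shown in this diff; the sketch below is only an assumption about how such a helper might behave, not the repository's implementation.

```python
# Hypothetical sketch of a sanitize_locals-style helper; the real one in
# flowutils.py is not shown in this diff, so treat every detail here as assumed.
def sanitize_locals(local_vars, ignore_keys=None):
    # Accept either a single key name or a list of key names.
    if ignore_keys is None:
        ignore_keys = []
    elif isinstance(ignore_keys, str):
        ignore_keys = [ignore_keys]
    skipped = {"self", "__class__", *ignore_keys}
    # Keep only simple, JSON-serializable constructor arguments.
    return {
        k: v
        for k, v in local_vars.items()
        if k not in skipped
        and isinstance(v, (bool, int, float, str, list, tuple, dict, type(None)))
    }
```

Under that reading, `ignore_keys=input_size` hands the helper the `(num_sigmas, c, h, w)` tuple instead of a key name, which is what the quoted `"input_size"` corrects.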
msma.py CHANGED

@@ -1,4 +1,5 @@
 import datetime
+import json
 import os
 import pickle
 from functools import partial
@@ -21,7 +22,7 @@ import dnnlib
 from dataset import ImageFolderDataset
 from flowutils import PatchFlow, sanitize_locals
 
-DEVICE: Literal["cuda", "cpu"] =
+DEVICE: Literal["cuda", "cpu"] = "cpu"
 model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
 
 config_presets = {
@@ -56,8 +57,8 @@ class EDMScorer(torch.nn.Module):
     ):
         super().__init__()
 
-        self.config = sanitize_locals(locals(), ignore_keys=
-        self.config[
+        self.config = sanitize_locals(locals(), ignore_keys="net")
+        self.config["EDMNet"] = dict(net.init_kwargs)
 
         self.use_fp16 = use_fp16
         self.sigma_min = sigma_min
@@ -100,19 +101,13 @@ class EDMScorer(torch.nn.Module):
 
 
 class ScoreFlow(torch.nn.Module):
-    def __init__(
-        self,
-        scorenet,
-        device="cpu",
-        **flow_kwargs
-    ):
+    def __init__(self, scorenet, device="cpu", **flow_kwargs):
         super().__init__()
 
         h = w = scorenet.net.img_resolution
         c = scorenet.net.img_channels
         num_sigmas = len(scorenet.sigma_steps)
         self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
-
 
         self.flow = self.flow.to(device)
         self.scorenet = scorenet.to(device).requires_grad_(False)
@@ -265,7 +260,6 @@ def cmdline():
         type=str,
         required=True,
     )
-
 def cache_score_norms(preset, dataset_path, outdir):
     device = DEVICE
     dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
@@ -353,7 +347,7 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
     val_ds = Subset(dsobj, range(train_len, train_len + val_len))
 
     trainiter = torch.utils.data.DataLoader(
-        train_ds, batch_size=64, num_workers=4, prefetch_factor=2
+        train_ds, batch_size=64, num_workers=4, prefetch_factor=2, shuffle=True
     )
     testiter = torch.utils.data.DataLoader(
         val_ds, batch_size=128, num_workers=4, prefetch_factor=2
@@ -383,6 +377,9 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
     timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
     writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}")
 
+    with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
+        json.dump(model.config, f, sort_keys=True, indent=4)
+
     # totaliters = int(epochs * train_len)
     pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
     step = 0
@@ -398,8 +395,17 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
             val_loss = eval_step(scores, x)
 
             # Log details about model
-            writer.add_graph(
-
+            writer.add_graph(
+                model.flow.flows,
+                (
+                    torch.zeros(1, scores.shape[1], device=device),
+                    torch.zeros(
+                        1,
+                        model.flow.position_encoding.cached_penc.shape[-1],
+                        device=device,
+                    ),
+                ),
+            )
 
             train_loss = train_step(scores, x)
 
@@ -433,12 +439,14 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
         scores = model.scorenet(x)
         train_loss = train_step(scores, x)
         writer.add_scalar("loss/train", train_loss, step)
-        pbar.set_description(
-            f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}"
-        )
+        pbar.set_description(f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}")
         step += 1
 
+    # Save final model
     torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
+    with open(f"{experiment_dir}/config.json", "w") as f:
+        json.dump(model.config, f, sort_keys=True, indent=4)
+
     writer.close()
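Both new `json.dump` calls in msma.py persist `model.config`, once next to the TensorBoard run and once next to the final `flow.pt`. A downstream script could use those two artifacts to rebuild the flow without re-running training. The sketch below is a hedged illustration only: `load_flow`, the `"PatchFlow"` key, and the kwargs layout are assumptions, since the actual schema produced by `sanitize_locals` is not visible in this diff.

```python
# Hypothetical loader for the artifacts this commit writes:
#   {experiment_dir}/config.json  and  {experiment_dir}/flow.pt
# Assumes the saved JSON carries PatchFlow's constructor kwargs under a
# "PatchFlow" key; adjust to whatever sanitize_locals actually produces.
import json

import torch

from flowutils import PatchFlow


def load_flow(experiment_dir, scorenet, device="cpu"):
    with open(f"{experiment_dir}/config.json") as f:
        config = json.load(f)

    # Recompute the flow's input shape the same way ScoreFlow.__init__ does.
    h = w = scorenet.net.img_resolution
    c = scorenet.net.img_channels
    num_sigmas = len(scorenet.sigma_steps)

    flow_kwargs = config.get("PatchFlow", {})  # assumed key name
    flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs)
    flow.load_state_dict(torch.load(f"{experiment_dir}/flow.pt", map_location=device))
    return flow.to(device).eval()
```

Writing the config alongside the weights is what makes a loader like this possible; before this commit only the `state_dict` was saved, so the flow's hyperparameters had to be tracked out of band.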