Spaces:
Running
Running
Added tensorboard logging
Browse files- flowutils.py +7 -0
- msma.py +42 -30
flowutils.py
CHANGED
|
@@ -25,6 +25,13 @@ def build_flows(
|
|
| 25 |
flows += [nf.flows.LULinearPermute(latent_size)]
|
| 26 |
|
| 27 |
# Set base distribution
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
q0 = nf.distributions.DiagGaussian(latent_size, trainable=True)
|
| 29 |
|
| 30 |
# Construct flow model
|
|
|
|
| 25 |
flows += [nf.flows.LULinearPermute(latent_size)]
|
| 26 |
|
| 27 |
# Set base distribution
|
| 28 |
+
|
| 29 |
+
# context_encoder = nn.Sequential([
|
| 30 |
+
# nn.Linear(context_size, context_size),
|
| 31 |
+
# nn.SiLU(),
|
| 32 |
+
# nn.Linear(context_size, context_size)
|
| 33 |
+
# ])
|
| 34 |
+
|
| 35 |
q0 = nf.distributions.DiagGaussian(latent_size, trainable=True)
|
| 36 |
|
| 37 |
# Construct flow model
|
msma.py
CHANGED
|
@@ -12,6 +12,7 @@ from sklearn.model_selection import GridSearchCV
|
|
| 12 |
from sklearn.pipeline import Pipeline
|
| 13 |
from sklearn.preprocessing import StandardScaler
|
| 14 |
from torch.utils.data import Subset
|
|
|
|
| 15 |
from tqdm import tqdm
|
| 16 |
|
| 17 |
import dnnlib
|
|
@@ -213,7 +214,7 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
|
|
| 213 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
| 214 |
|
| 215 |
|
| 216 |
-
def train_flow(dataset_path, preset, outdir, device="cuda"):
|
| 217 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 218 |
refimg, reflabel = dsobj[0]
|
| 219 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
|
@@ -230,10 +231,10 @@ def train_flow(dataset_path, preset, outdir, device="cuda"):
|
|
| 230 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
| 231 |
|
| 232 |
trainiter = torch.utils.data.DataLoader(
|
| 233 |
-
train_ds, batch_size=
|
| 234 |
)
|
| 235 |
testiter = torch.utils.data.DataLoader(
|
| 236 |
-
val_ds, batch_size=
|
| 237 |
)
|
| 238 |
|
| 239 |
model = ScoreFlow(preset, device=device)
|
|
@@ -243,48 +244,59 @@ def train_flow(dataset_path, preset, outdir, device="cuda"):
|
|
| 243 |
flow_model=model.flow,
|
| 244 |
opt=opt,
|
| 245 |
train=True,
|
| 246 |
-
n_patches=
|
| 247 |
device=device,
|
| 248 |
)
|
| 249 |
eval_step = partial(
|
| 250 |
PatchFlow.stochastic_step,
|
| 251 |
flow_model=model.flow,
|
| 252 |
train=False,
|
| 253 |
-
n_patches=
|
| 254 |
device=device,
|
| 255 |
)
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
|
|
|
| 268 |
|
| 269 |
-
|
|
|
|
|
|
|
| 270 |
|
| 271 |
-
|
| 272 |
|
| 273 |
-
|
|
|
|
| 274 |
val_loss = 0.0
|
| 275 |
-
|
| 276 |
-
x
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
|
| 290 |
@torch.inference_mode
|
|
|
|
| 12 |
from sklearn.pipeline import Pipeline
|
| 13 |
from sklearn.preprocessing import StandardScaler
|
| 14 |
from torch.utils.data import Subset
|
| 15 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
from tqdm import tqdm
|
| 17 |
|
| 18 |
import dnnlib
|
|
|
|
| 214 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
| 215 |
|
| 216 |
|
| 217 |
+
def train_flow(dataset_path, preset, outdir, epochs=10, device="cuda"):
|
| 218 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 219 |
refimg, reflabel = dsobj[0]
|
| 220 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
|
|
|
| 231 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
| 232 |
|
| 233 |
trainiter = torch.utils.data.DataLoader(
|
| 234 |
+
train_ds, batch_size=64, num_workers=4, prefetch_factor=2
|
| 235 |
)
|
| 236 |
testiter = torch.utils.data.DataLoader(
|
| 237 |
+
val_ds, batch_size=128, num_workers=4, prefetch_factor=2
|
| 238 |
)
|
| 239 |
|
| 240 |
model = ScoreFlow(preset, device=device)
|
|
|
|
| 244 |
flow_model=model.flow,
|
| 245 |
opt=opt,
|
| 246 |
train=True,
|
| 247 |
+
n_patches=128,
|
| 248 |
device=device,
|
| 249 |
)
|
| 250 |
eval_step = partial(
|
| 251 |
PatchFlow.stochastic_step,
|
| 252 |
flow_model=model.flow,
|
| 253 |
train=False,
|
| 254 |
+
n_patches=256,
|
| 255 |
device=device,
|
| 256 |
)
|
| 257 |
|
| 258 |
+
experiment_dir = f"{outdir}/{preset}"
|
| 259 |
+
os.makedirs(experiment_dir, exist_ok=True)
|
| 260 |
+
writer = SummaryWriter(f"{experiment_dir}/logs/")
|
| 261 |
|
| 262 |
+
# totaliters = int(epochs * train_len)
|
| 263 |
+
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
| 264 |
+
step = 0
|
| 265 |
|
| 266 |
+
for e in pbar:
|
| 267 |
+
for x, _ in trainiter:
|
| 268 |
+
x = x.to(device)
|
| 269 |
+
scores = model.scorenet(x)
|
| 270 |
|
| 271 |
+
if step == 0:
|
| 272 |
+
with torch.inference_mode():
|
| 273 |
+
val_loss = eval_step(scores, x)
|
| 274 |
|
| 275 |
+
train_loss = train_step(scores, x)
|
| 276 |
|
| 277 |
+
if (step + 1) % 10 == 0:
|
| 278 |
+
prev_val_loss = val_loss
|
| 279 |
val_loss = 0.0
|
| 280 |
+
with torch.inference_mode():
|
| 281 |
+
for i, (x, _) in enumerate(testiter):
|
| 282 |
+
x = x.to(device)
|
| 283 |
+
scores = model.scorenet(x)
|
| 284 |
+
val_loss += eval_step(scores, x)
|
| 285 |
+
break
|
| 286 |
+
val_loss /= i + 1
|
| 287 |
+
writer.add_scalar("loss/val", train_loss, step)
|
| 288 |
+
|
| 289 |
+
if val_loss < prev_val_loss:
|
| 290 |
+
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
| 291 |
+
|
| 292 |
+
writer.add_scalar("loss/train", train_loss, step)
|
| 293 |
+
pbar.set_description(
|
| 294 |
+
f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
|
| 295 |
+
)
|
| 296 |
+
step += 1
|
| 297 |
+
|
| 298 |
+
# torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
| 299 |
+
writer.close()
|
| 300 |
|
| 301 |
|
| 302 |
@torch.inference_mode
|