Spaces:
Running
Running
minor fixes to train flow runner
Browse files
msma.py
CHANGED
|
@@ -189,7 +189,9 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
|
|
| 189 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 190 |
refimg, reflabel = dsobj[0]
|
| 191 |
print(f"Loading dataset from {dataset_path}")
|
| 192 |
-
print(
|
|
|
|
|
|
|
| 193 |
dsloader = torch.utils.data.DataLoader(
|
| 194 |
dsobj, batch_size=48, num_workers=4, prefetch_factor=2
|
| 195 |
)
|
|
@@ -211,7 +213,7 @@ def cache_score_norms(preset, dataset_path, outdir, device="cpu"):
|
|
| 211 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
| 212 |
|
| 213 |
|
| 214 |
-
def train_flow(dataset_path, preset, device="cuda"):
|
| 215 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 216 |
refimg, reflabel = dsobj[0]
|
| 217 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
|
@@ -252,6 +254,7 @@ def train_flow(dataset_path, preset, device="cuda"):
|
|
| 252 |
device=device,
|
| 253 |
)
|
| 254 |
|
|
|
|
| 255 |
pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
|
| 256 |
step = 0
|
| 257 |
|
|
@@ -280,8 +283,8 @@ def train_flow(dataset_path, preset, device="cuda"):
|
|
| 280 |
f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
|
| 281 |
)
|
| 282 |
step += 1
|
| 283 |
-
|
| 284 |
-
torch.save(model.flow.state_dict(), f"
|
| 285 |
|
| 286 |
|
| 287 |
@torch.inference_mode
|
|
@@ -327,22 +330,43 @@ def test_flow_runner(preset, device="cpu", load_weights=None):
|
|
| 327 |
@click.command()
|
| 328 |
|
| 329 |
# Main options.
|
| 330 |
-
@click.option(
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
)
|
| 333 |
-
@click.option('--outdir', help='Where to load/save the results', metavar='DIR', type=str, required=True)
|
| 334 |
-
@click.option('--preset', help='Configuration preset', metavar='STR', type=str, default='edm2-img64-s-fid', show_default=True)
|
| 335 |
-
@click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, default=None)
|
| 336 |
def cmdline(run, outdir, **opts):
|
| 337 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 338 |
-
preset = opts[
|
| 339 |
-
dataset_path = opts[
|
| 340 |
-
|
| 341 |
-
if run in ['cache-scores', 'train-flow']:
|
| 342 |
-
assert opts['data'] is not None, "Provide path to dataset"
|
| 343 |
|
|
|
|
|
|
|
|
|
|
| 344 |
if run == "cache-scores":
|
| 345 |
-
cache_score_norms(
|
|
|
|
|
|
|
| 346 |
|
| 347 |
if run == "train-gmm":
|
| 348 |
train_gmm(
|
|
@@ -350,8 +374,11 @@ def cmdline(run, outdir, **opts):
|
|
| 350 |
outdir=f"{outdir}/{preset}",
|
| 351 |
grid_search=True,
|
| 352 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
-
# test_flow_runner("cuda", f"out/msma/{preset}/flow.pt")
|
| 355 |
# train_flow(imagenette_path, preset, device)
|
| 356 |
|
| 357 |
# cache_score_norms(
|
|
@@ -368,5 +395,6 @@ def cmdline(run, outdir, **opts):
|
|
| 368 |
# nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
|
| 369 |
# print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
|
| 370 |
|
|
|
|
| 371 |
if __name__ == "__main__":
|
| 372 |
cmdline()
|
|
|
|
| 189 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 190 |
refimg, reflabel = dsobj[0]
|
| 191 |
print(f"Loading dataset from {dataset_path}")
|
| 192 |
+
print(
|
| 193 |
+
f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
|
| 194 |
+
)
|
| 195 |
dsloader = torch.utils.data.DataLoader(
|
| 196 |
dsobj, batch_size=48, num_workers=4, prefetch_factor=2
|
| 197 |
)
|
|
|
|
| 213 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
| 214 |
|
| 215 |
|
| 216 |
+
def train_flow(dataset_path, preset, outdir, device="cuda"):
|
| 217 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 218 |
refimg, reflabel = dsobj[0]
|
| 219 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
|
|
|
| 254 |
device=device,
|
| 255 |
)
|
| 256 |
|
| 257 |
+
os.makedirs(f"{outdir}/{preset}", exist_ok=True)
|
| 258 |
pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
|
| 259 |
step = 0
|
| 260 |
|
|
|
|
| 283 |
f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
|
| 284 |
)
|
| 285 |
step += 1
|
| 286 |
+
|
| 287 |
+
torch.save(model.flow.state_dict(), f"{outdir}/{preset}/flow.pt")
|
| 288 |
|
| 289 |
|
| 290 |
@torch.inference_mode
|
|
|
|
| 330 |
@click.command()
|
| 331 |
|
| 332 |
# Main options.
|
| 333 |
+
@click.option(
|
| 334 |
+
"--run",
|
| 335 |
+
help="Which function to run",
|
| 336 |
+
type=click.Choice(
|
| 337 |
+
["cache-scores", "train-flow", "train-gmm"], case_sensitive=False
|
| 338 |
+
),
|
| 339 |
+
)
|
| 340 |
+
@click.option(
|
| 341 |
+
"--outdir",
|
| 342 |
+
help="Where to load/save the results",
|
| 343 |
+
metavar="DIR",
|
| 344 |
+
type=str,
|
| 345 |
+
required=True,
|
| 346 |
+
)
|
| 347 |
+
@click.option(
|
| 348 |
+
"--preset",
|
| 349 |
+
help="Configuration preset",
|
| 350 |
+
metavar="STR",
|
| 351 |
+
type=str,
|
| 352 |
+
default="edm2-img64-s-fid",
|
| 353 |
+
show_default=True,
|
| 354 |
+
)
|
| 355 |
+
@click.option(
|
| 356 |
+
"--data", help="Path to the dataset", metavar="ZIP|DIR", type=str, default=None
|
| 357 |
)
|
|
|
|
|
|
|
|
|
|
| 358 |
def cmdline(run, outdir, **opts):
|
| 359 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 360 |
+
preset = opts["preset"]
|
| 361 |
+
dataset_path = opts["data"]
|
|
|
|
|
|
|
|
|
|
| 362 |
|
| 363 |
+
if run in ["cache-scores", "train-flow"]:
|
| 364 |
+
assert opts["data"] is not None, "Provide path to dataset"
|
| 365 |
+
|
| 366 |
if run == "cache-scores":
|
| 367 |
+
cache_score_norms(
|
| 368 |
+
preset=preset, dataset_path=dataset_path, outdir=outdir, device=device
|
| 369 |
+
)
|
| 370 |
|
| 371 |
if run == "train-gmm":
|
| 372 |
train_gmm(
|
|
|
|
| 374 |
outdir=f"{outdir}/{preset}",
|
| 375 |
grid_search=True,
|
| 376 |
)
|
| 377 |
+
|
| 378 |
+
if run == "train-flow":
|
| 379 |
+
train_flow(dataset_path, outdir=outdir, preset=preset, device=device)
|
| 380 |
+
test_flow_runner(preset, device=device, load_weights=f"{outdir}/{preset}/flow.pt")
|
| 381 |
|
|
|
|
| 382 |
# train_flow(imagenette_path, preset, device)
|
| 383 |
|
| 384 |
# cache_score_norms(
|
|
|
|
| 395 |
# nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
|
| 396 |
# print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
|
| 397 |
|
| 398 |
+
|
| 399 |
if __name__ == "__main__":
|
| 400 |
cmdline()
|