Spaces:
Running
Running
added option for batch size to caching func
Browse files
msma.py
CHANGED
|
@@ -276,8 +276,16 @@ def train_gmm(preset, outdir, gridsearch=False, **kwargs):
|
|
| 276 |
|
| 277 |
|
| 278 |
@cmdline.command(name="cache-scores")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
@common_args
|
| 280 |
-
def cache_score_norms(preset, dataset_path, outdir):
|
| 281 |
device = DEVICE
|
| 282 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 283 |
refimg, reflabel = dsobj[0]
|
|
@@ -286,7 +294,7 @@ def cache_score_norms(preset, dataset_path, outdir):
|
|
| 286 |
f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
|
| 287 |
)
|
| 288 |
dsloader = torch.utils.data.DataLoader(
|
| 289 |
-
dsobj, batch_size=
|
| 290 |
)
|
| 291 |
|
| 292 |
model = build_model_from_pickle(preset=preset, device=device)
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
@cmdline.command(name="cache-scores")
|
| 279 |
+
@click.option(
|
| 280 |
+
"--batch_size",
|
| 281 |
+
help="Number of samples per batch",
|
| 282 |
+
metavar="INT",
|
| 283 |
+
type=int,
|
| 284 |
+
default=64,
|
| 285 |
+
show_default=True,
|
| 286 |
+
)
|
| 287 |
@common_args
|
| 288 |
+
def cache_score_norms(preset, dataset_path, outdir, batch_size):
|
| 289 |
device = DEVICE
|
| 290 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
| 291 |
refimg, reflabel = dsobj[0]
|
|
|
|
| 294 |
f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
|
| 295 |
)
|
| 296 |
dsloader = torch.utils.data.DataLoader(
|
| 297 |
+
dsobj, batch_size=batch_size, num_workers=4, prefetch_factor=2
|
| 298 |
)
|
| 299 |
|
| 300 |
model = build_model_from_pickle(preset=preset, device=device)
|