Spaces:
Sleeping
Sleeping
fixed gridsearch arg
Browse files
msma.py
CHANGED
|
@@ -223,16 +223,23 @@ def common_args(func):
|
|
| 223 |
return wrapper
|
| 224 |
|
| 225 |
@cmdline.command('train-gmm')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
@common_args
|
| 227 |
-
def train_gmm(
|
| 228 |
-
|
|
|
|
| 229 |
|
| 230 |
gm = GaussianMixture(
|
| 231 |
n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
|
| 232 |
)
|
| 233 |
clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
|
| 234 |
|
| 235 |
-
if
|
| 236 |
param_grid = dict(
|
| 237 |
GMM__n_components=range(2, 11, 1),
|
| 238 |
)
|
|
@@ -369,6 +376,9 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
| 369 |
with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
|
| 370 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 371 |
|
|
|
|
|
|
|
|
|
|
| 372 |
# totaliters = int(epochs * train_len)
|
| 373 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
| 374 |
step = 0
|
|
@@ -433,8 +443,6 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
| 433 |
|
| 434 |
# Save final model
|
| 435 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
| 436 |
-
with open(f"{experiment_dir}/config.json", "w") as f:
|
| 437 |
-
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 438 |
|
| 439 |
writer.close()
|
| 440 |
|
|
|
|
| 223 |
return wrapper
|
| 224 |
|
| 225 |
@cmdline.command('train-gmm')
|
| 226 |
+
@click.option(
|
| 227 |
+
"--gridsearch",
|
| 228 |
+
help="Whether to use a grid search on a number of components to find the best fit",
|
| 229 |
+
is_flag=True,
|
| 230 |
+
default=False,
|
| 231 |
+
)
|
| 232 |
@common_args
|
| 233 |
+
def train_gmm(preset, outdir, gridsearch=False, **kwargs):
|
| 234 |
+
score_path = f"{outdir}/{preset}/imagenette_score_norms.pt"
|
| 235 |
+
X = torch.load(score_path).numpy()
|
| 236 |
|
| 237 |
gm = GaussianMixture(
|
| 238 |
n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
|
| 239 |
)
|
| 240 |
clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
|
| 241 |
|
| 242 |
+
if gridsearch:
|
| 243 |
param_grid = dict(
|
| 244 |
GMM__n_components=range(2, 11, 1),
|
| 245 |
)
|
|
|
|
| 376 |
with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
|
| 377 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 378 |
|
| 379 |
+
with open(f"{experiment_dir}/config.json", "w") as f:
|
| 380 |
+
json.dump(model.config, f, sort_keys=True, indent=4)
|
| 381 |
+
|
| 382 |
# totaliters = int(epochs * train_len)
|
| 383 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
| 384 |
step = 0
|
|
|
|
| 443 |
|
| 444 |
# Save final model
|
| 445 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
|
|
|
|
|
|
| 446 |
|
| 447 |
writer.close()
|
| 448 |
|