Spaces:
Running
Running
uploading gmm and score norms
Browse files- push_to_hf.py +12 -6
push_to_hf.py
CHANGED
|
@@ -26,9 +26,13 @@ def main(basedir, preset):
|
|
| 26 |
modeldir = basedir / preset
|
| 27 |
|
| 28 |
net = build_model_from_pickle(preset)
|
|
|
|
|
|
|
|
|
|
| 29 |
model = ScoreFlow(
|
| 30 |
net,
|
| 31 |
-
|
|
|
|
| 32 |
)
|
| 33 |
model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
|
| 34 |
|
|
@@ -46,10 +50,13 @@ def main(basedir, preset):
|
|
| 46 |
save_file(model.state_dict(), tmpdir / "model.safetensors")
|
| 47 |
|
| 48 |
# save config
|
| 49 |
-
(tmpdir / "config.json").write_text(
|
| 50 |
-
|
| 51 |
-
|
| 52 |
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# Generate model card
|
| 55 |
# card = generate_model_card(model)
|
|
@@ -57,8 +64,7 @@ def main(basedir, preset):
|
|
| 57 |
|
| 58 |
# Save logs
|
| 59 |
shutil.copytree(modeldir / "logs", tmpdir / "logs")
|
| 60 |
-
|
| 61 |
-
|
| 62 |
# Push to hub
|
| 63 |
api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
|
| 64 |
|
|
|
|
| 26 |
modeldir = basedir / preset
|
| 27 |
|
| 28 |
net = build_model_from_pickle(preset)
|
| 29 |
+
with open(modeldir / "config.json", "rb") as f:
|
| 30 |
+
model_params = json.load(f)
|
| 31 |
+
|
| 32 |
model = ScoreFlow(
|
| 33 |
net,
|
| 34 |
+
device="cpu",
|
| 35 |
+
**model_params["PatchFlow"],
|
| 36 |
)
|
| 37 |
model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
|
| 38 |
|
|
|
|
| 50 |
save_file(model.state_dict(), tmpdir / "model.safetensors")
|
| 51 |
|
| 52 |
# save config
|
| 53 |
+
(tmpdir / "config.json").write_text(
|
| 54 |
+
json.dumps(model.config, sort_keys=True, indent=4)
|
| 55 |
+
)
|
| 56 |
|
| 57 |
+
# save gmm and cached score norms
|
| 58 |
+
shutil.copyfile(modeldir / "gmm.pkl", tmpdir / "gmm.pkl")
|
| 59 |
+
shutil.copyfile(modeldir / "refscores.npz", tmpdir / "refscores.npz")
|
| 60 |
|
| 61 |
# Generate model card
|
| 62 |
# card = generate_model_card(model)
|
|
|
|
| 64 |
|
| 65 |
# Save logs
|
| 66 |
shutil.copytree(modeldir / "logs", tmpdir / "logs")
|
| 67 |
+
|
|
|
|
| 68 |
# Push to hub
|
| 69 |
api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
|
| 70 |
|