MSRNet / utils /io /params.py
linaa98's picture
Update utils/io/params.py
b7c3fbd verified
raw
history blame contribute delete
950 Bytes
# -*- coding: utf-8 -*-
#Author: Lart Pang (https://github.com/lartpang)
import os
import torch
def save_weight(save_path, model):
    """Write a model's weights to ``save_path`` using ``torch.save``.

    Args:
        save_path: Destination handed directly to ``torch.save``.
        model: Either a state dict, which is saved unchanged, or an object
            with ``state_dict()``. If that object has a ``module`` attribute
            (presumably a DataParallel-style wrapper), the state dict of
            ``model.module`` is saved instead of the wrapper's own.
    """
    print(f"Saving weight '{save_path}'")
    if isinstance(model, dict):
        state = model
    else:
        # Unwrap a wrapper that exposes the real network as ``.module``.
        net = model.module if hasattr(model, "module") else model
        state = net.state_dict()
    torch.save(state, save_path)
    print(f"Saved weight '{save_path}' " f"(only contain the net's weight)")
def load_weight(load_path, model, *, strict=True, skip_unmatched_shape=False):
    """Load a checkpoint from ``load_path`` into ``model``.

    The checkpoint is loaded onto the CPU. Its entries are merged over the
    model's current state dict, and the result is applied with
    ``model.load_state_dict``. Entries not present in the checkpoint keep the
    model's current values.

    Args:
        load_path: Path to a checkpoint containing a mapping of parameter names
            to tensors (for example one written by ``save_weight``).
        model: Module providing ``state_dict()`` and ``load_state_dict()``.
        strict: Forwarded to ``model.load_state_dict``. When True, checkpoint
            keys that the model does not have raise an error.
        skip_unmatched_shape: If True, a checkpoint entry whose shape differs
            from the model's entry of the same name is ignored, and the model's
            current value is kept.

    Raises:
        AssertionError: If ``load_path`` does not exist.
    """
    assert os.path.exists(load_path), load_path
    # Keys saved from a wrapped model look like "module.conv.weight"; the
    # prefix must be matched at the START of the key, not the end.
    prefix = "module."
    model_params = model.state_dict()
    for k, v in torch.load(load_path, map_location="cpu").items():
        if k.startswith(prefix):
            k = k[len(prefix):]
        if skip_unmatched_shape and k in model_params and v.shape != model_params[k].shape:
            continue
        model_params[k] = v
    model.load_state_dict(model_params, strict=strict)