# -*- coding: utf-8 -*-
# Author: Lart Pang (https://github.com/lartpang)
import os

import torch


def save_weight(save_path, model):
    """Save only the network weights of ``model`` to ``save_path``.

    Args:
        save_path: Destination path handed to ``torch.save``.
        model: Either a state dict (a ``dict``), which is saved as-is, or a
            module. If the module has a ``module`` attribute (presumably a
            (Distributed)DataParallel-style wrapper), the inner module's state
            dict is saved instead, so its keys carry no ``module.`` prefix.
    """
    print(f"Saving weight '{save_path}'")
    if isinstance(model, dict):
        model_state = model
    elif hasattr(model, "module"):
        # Unwrap so the saved keys do not depend on how the model was wrapped.
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    torch.save(model_state, save_path)
    print(f"Saved weight '{save_path}' " f"(only contain the net's weight)")


def _resolve_key(key, model_keys, prefix="module."):
    """Map a checkpoint key onto the model's naming by toggling one leading ``prefix``."""
    if key in model_keys:
        return key
    if key.startswith(prefix) and key[len(prefix):] in model_keys:
        # Checkpoint came from a wrapper's own state dict; the model is unwrapped.
        return key[len(prefix):]
    if prefix + key in model_keys:
        # Checkpoint is unwrapped (e.g. written by save_weight); the model is wrapped.
        return prefix + key
    # No match either way: keep the original name so a strict-load error
    # reports the key exactly as it appears in the checkpoint.
    return key


def load_weight(load_path, model, *, strict=True, skip_unmatched_shape=False):
    """Load weights from ``load_path`` into ``model``.

    The checkpoint is loaded onto the CPU. Each checkpoint key is matched to
    the model's own state-dict keys. A single leading ``module.`` is removed
    or added when that is what makes the key match. As a result, checkpoints
    saved from wrapped or unwrapped models load into either kind.

    Checkpoint entries are merged over the model's current state dict, so
    model keys absent from the checkpoint keep their current values. With
    ``strict=True``, only checkpoint keys unknown to the model raise an error.

    Args:
        load_path: Path of the checkpoint (a state dict) to load.
        model: Module to load the weights into.
        strict: Forwarded to ``model.load_state_dict``.
        skip_unmatched_shape: If True, checkpoint tensors whose shape differs
            from the model's tensor under the same (resolved) key are skipped.
            The model keeps its current value for that key.

    Raises:
        AssertionError: If ``load_path`` does not exist.
    """
    assert os.path.exists(load_path), load_path
    model_params = model.state_dict()
    model_keys = set(model_params)
    for k, v in torch.load(load_path, map_location="cpu").items():
        k = _resolve_key(k, model_keys)
        if skip_unmatched_shape and k in model_params and v.shape != model_params[k].shape:
            continue
        model_params[k] = v
    model.load_state_dict(model_params, strict=strict)