Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| import json | |
| import safetensors.torch | |
| import backend.misc.checkpoint_pickle | |
def read_arbitrary_config(directory):
    """Parse and return the JSON contents of ``config.json`` inside *directory*.

    Args:
        directory: Folder expected to contain a ``config.json`` file.

    Returns:
        Whatever ``json.load`` produces for that file (UTF-8 decoded).

    Raises:
        FileNotFoundError: If ``<directory>/config.json`` does not exist.
    """
    path = os.path.join(directory, 'config.json')
    if os.path.exists(path):
        with open(path, 'rt', encoding='utf-8') as handle:
            return json.load(handle)
    raise FileNotFoundError(f"No config.json file found in the directory: {directory}")
def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its state dict.

    Files whose name ends in ``.safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file``; every other file goes through ``torch.load``.

    Args:
        ckpt: Path to the checkpoint (``str`` or ``os.PathLike``).
        safe_load: If True, request ``weights_only=True`` from ``torch.load``. If the
            installed PyTorch does not accept that argument, a warning is printed and
            an unrestricted load is used instead. Ignored for ``.safetensors`` files.
        device: Target device for the loaded tensors, either a ``torch.device`` or
            anything ``torch.device()`` accepts (e.g. ``"cpu"``, ``"cuda:1"``).
            Defaults to CPU.

    Returns:
        For ``.safetensors`` files, the dict returned by ``load_file``. Otherwise, the
        ``"state_dict"`` entry if ``torch.load`` returned a dict containing one, else
        the loaded object itself.
    """
    import inspect

    ckpt = os.fspath(ckpt)
    # Normalise so callers may pass strings as well as torch.device objects.
    device = torch.device("cpu") if device is None else torch.device(device)

    if ckpt.lower().endswith(".safetensors"):
        # Pass the full device string: ``device.type`` would drop the index and
        # send a "cuda:1" request to the default CUDA device instead.
        return safetensors.torch.load_file(ckpt, device=str(device))

    if safe_load:
        # inspect.signature follows functools.wraps' __wrapped__, so a wrapped
        # torch.load is still detected correctly (unlike __code__.co_varnames).
        try:
            supports_weights_only = "weights_only" in inspect.signature(torch.load).parameters
        except (TypeError, ValueError):
            supports_weights_only = False
        if not supports_weights_only:
            print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False

    if safe_load:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
    else:
        # NOTE(review): this path does not pass weights_only=True, so it is a pickle
        # load restricted only by whatever backend.misc.checkpoint_pickle allows.
        # Only use it on trusted checkpoint files — TODO confirm that module's policy.
        pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)

    # Non-dict checkpoints (e.g. a bare tensor) have no metadata to inspect;
    # `str in tensor` would raise, so return them unchanged.
    if not isinstance(pl_sd, dict):
        return pl_sd

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # Lightning-style checkpoints nest the weights under "state_dict".
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]
    return pl_sd
def set_attr(obj, attr, value):
    """Replace the attribute at dotted path *attr* under *obj* with a frozen Parameter.

    Each component except the last is resolved with ``getattr``. The last one is
    then assigned ``torch.nn.Parameter(value, requires_grad=False)``.

    Args:
        obj: Root object to start resolving from.
        attr: Dotted attribute path, e.g. ``"block.0.weight"``.
        value: Tensor wrapped into the new non-trainable Parameter.
    """
    *path, leaf = attr.split(".")
    target = obj
    for part in path:
        target = getattr(target, part)
    frozen = torch.nn.Parameter(value, requires_grad=False)
    setattr(target, leaf, frozen)
def set_attr_raw(obj, attr, value):
    """Assign *value* unchanged to the attribute at dotted path *attr* under *obj*.

    Unlike ``set_attr``, no Parameter wrapping is applied.

    Args:
        obj: Root object to start resolving from.
        attr: Dotted attribute path; all but the last component are read with ``getattr``.
        value: Object stored as-is on the final attribute.
    """
    parent_path, sep, leaf = attr.rpartition(".")
    target = obj
    if sep:
        for part in parent_path.split("."):
            target = getattr(target, part)
    setattr(target, leaf, value)
def copy_to_param(obj, attr, value):
    """Overwrite, in place, the tensor stored at dotted path *attr* under *obj*.

    The existing attribute object is kept. Only its ``.data`` is updated via
    ``Tensor.copy_``, so *value* must be copy-compatible with it.

    Args:
        obj: Root object to start resolving from.
        attr: Dotted attribute path to an existing tensor/Parameter.
        value: Source tensor whose contents are copied in.
    """
    *parents, leaf = attr.split(".")
    owner = obj
    for part in parents:
        owner = getattr(owner, part)
    destination = getattr(owner, leaf)
    destination.data.copy_(value)
def get_attr(obj, attr):
    """Return the object reached by following dotted path *attr* from *obj*.

    Args:
        obj: Root object to start resolving from.
        attr: Dotted attribute path, e.g. ``"a.b.c"``.

    Returns:
        The final attribute value; ``AttributeError`` propagates if any step is missing.
    """
    node = obj
    remaining = attr
    while True:
        head, sep, remaining = remaining.partition(".")
        node = getattr(node, head)
        if not sep:
            return node
def calculate_parameters(sd, prefix=""):
    """Total the element counts of entries in *sd* whose key starts with *prefix*.

    Args:
        sd: Mapping of names to objects exposing ``nelement()`` (e.g. tensors).
        prefix: Only keys beginning with this string are counted; ``""`` counts all.

    Returns:
        Sum of ``nelement()`` over the matching entries (0 if none match).
    """
    return sum(sd[key].nelement() for key in sd.keys() if key.startswith(prefix))