# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import os
import re
import socket
import torch
import torch.distributed

from . import training_stats

_sync_device = None

#----------------------------------------------------------------------------
def init():
    global _sync_device
    if not torch.distributed.is_initialized():
        # Set up some reasonable defaults for env-based distributed init if
        # not set by the running environment.
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            # Pick an arbitrary free TCP port for the rendezvous.
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(('', 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            os.environ['MASTER_PORT'] = str(s.getsockname()[1])
            s.close()
        if 'RANK' not in os.environ:
            os.environ['RANK'] = '0'
        if 'LOCAL_RANK' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
        if 'WORLD_SIZE' not in os.environ:
            os.environ['WORLD_SIZE'] = '1'
        backend = 'gloo' if os.name == 'nt' else 'nccl' # NCCL is not available on Windows.
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
#----------------------------------------------------------------------------
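# Example usage (a minimal sketch; the import path and the reliance on the
# single-process fallback are assumptions on my part, not part of this module):
#
#   from torch_utils import distributed as dist
#   dist.init()                                  # init torch.distributed, or fall back to 1 process
#   dist.print0(f'rank {dist.get_rank()} of {dist.get_world_size()}')

#----------------------------------------------------------------------------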
def get_rank():
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

#----------------------------------------------------------------------------

def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1

#----------------------------------------------------------------------------

def should_stop():
    # Placeholder hook; this default implementation never requests a stop.
    return False

#----------------------------------------------------------------------------

def should_suspend():
    # Placeholder hook; this default implementation never requests a suspend.
    return False

#----------------------------------------------------------------------------

def request_suspend():
    # Placeholder hook; no-op in this default implementation.
    pass

#----------------------------------------------------------------------------

def update_progress(cur, total):
    # Placeholder hook for progress reporting; no-op in this default implementation.
    pass

#----------------------------------------------------------------------------

def print0(*args, **kwargs):
    # Print only on rank 0 to avoid duplicate output across processes.
    if get_rank() == 0:
        print(*args, **kwargs)

#----------------------------------------------------------------------------
class CheckpointIO:
    # Collects a set of named state objects (anything exposing state_dict()/
    # load_state_dict(), __getstate__()/__setstate__(), a plain dict, or
    # __dict__) and saves/restores them as a single .pt file.
    def __init__(self, **kwargs):
        self._state_objs = kwargs

    def save(self, pt_path, verbose=True):
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif hasattr(obj, '__getstate__'):
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if get_rank() == 0: # Only rank 0 writes the checkpoint file.
            torch.save(data, pt_path)
        if verbose:
            print0('done')

    def load(self, pt_path, verbose=True):
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        for name, obj in self._state_objs.items():
            if obj is None:
                pass
            elif isinstance(obj, dict):
                obj.clear()
                obj.update(data[name])
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                obj.__dict__.clear()
                obj.__dict__.update(data[name])
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')

    def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
        # Find the checkpoint with the highest index in run_dir and load it.
        # Returns the path of the loaded file, or None if no match was found.
        fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
        if len(fnames) == 0:
            return None
        pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
        self.load(pt_path, verbose=verbose)
        return pt_path
#----------------------------------------------------------------------------
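# Example usage of CheckpointIO (a minimal sketch; `net`, `optimizer`, and
# `run_dir` are illustrative placeholders, not names defined in this module):
#
#   checkpoint = CheckpointIO(net=net, optimizer=optimizer, stats=dict())
#   resumed_from = checkpoint.load_latest(run_dir)   # None if no checkpoint exists yet
#   ...
#   checkpoint.save(os.path.join(run_dir, 'training-state-0000001.pt'))

#----------------------------------------------------------------------------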