Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import threading | |
| from torch._utils import ExceptionWrapper | |
| import logging | |
| import torch.nn.functional as F | |
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found inside ``obj``, or ``None``.

    ``obj`` may be a tensor itself or a nested combination of lists, tuples
    and dicts. Containers are searched depth-first in iteration order. For a
    dict, each ``(key, value)`` item is searched as a tuple, so keys are
    examined before their values. Any other type yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    # Pick the sequence of children to search; non-containers have none.
    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None
def parallel_apply(fct, model, inputs, device_ids):
    """Apply ``fct`` to replicas of ``model``, one replica per input.

    ``model`` is replicated with ``nn.parallel.replicate(model, device_ids)``
    and ``fct(replica, *input)`` is called for every ``(replica, input)``
    pair. With more than one replica each call runs in its own thread;
    with a single replica it runs in the calling thread.

    Args:
        fct: Callable invoked as ``fct(module, *input)``.
        model: Module handed to ``nn.parallel.replicate``.
        inputs: One entry per replica. An entry that is not a list or tuple
            is wrapped in a 1-tuple before being unpacked into ``fct``.
            Each entry must contain at least one tensor; the device of the
            first tensor found selects the CUDA device context of the call.
        device_ids: Devices handed to ``nn.parallel.replicate``.

    Returns:
        A list with the output of ``fct`` for every input, in input order.

    Raises:
        AssertionError: If the number of replicas differs from the number
            of inputs.
        Exception: Any exception raised inside a worker (including the
            ``ValueError`` raised when an input holds no tensor) is captured
            and re-raised here through ``ExceptionWrapper.reraise()``.
    """
    modules = nn.parallel.replicate(model, device_ids)
    assert len(modules) == len(inputs)
    lock = threading.Lock()
    results = {}
    # Captured in the calling thread and re-applied inside every worker.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, replica_input):
        torch.set_grad_enabled(grad_enabled)
        # Stays None if the device lookup itself fails, so the error report
        # below can still be formatted.
        device = None
        try:
            # The lookup lives inside the try block: a failure here must be
            # recorded in `results` like any other worker error, otherwise
            # the caller would hit a KeyError on `results[i]` instead of
            # seeing the real cause.
            tensor = get_a_var(replica_input)
            if tensor is None:
                raise ValueError(
                    "no tensor found in the input of replica {}; "
                    "cannot determine its device".format(i))
            device = tensor.get_device()
            with torch.cuda.device(device):
                # this also avoids accidental slicing of the input if it is
                # a Tensor
                if not isinstance(replica_input, (list, tuple)):
                    replica_input = (replica_input,)
                output = fct(module, *replica_input)
            with lock:
                results[i] = output
        except Exception:
            # ExceptionWrapper captures the active exception so that it can
            # be re-raised in the calling thread.
            with lock:
                results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker, args=(i, module, replica_input))
                   for i, (module, replica_input) in enumerate(zip(modules, inputs))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
def get_logger(filename=None):
    """Return the shared ``'logger'`` logger, optionally logging to a file.

    The logger named ``'logger'`` is set to ``DEBUG`` and the root logger is
    configured through ``logging.basicConfig`` with an ``INFO`` level and a
    timestamped format. When ``filename`` is given, a ``DEBUG``-level
    ``logging.FileHandler`` for it is attached to the root logger.

    Calling this function repeatedly with the same ``filename`` attaches the
    file handler only once, so records are not written to the file several
    times and no extra file handles are opened.

    Args:
        filename: Optional path of a log file. ``None`` skips the file
            handler.

    Returns:
        The ``logging.Logger`` named ``'logger'``.
    """
    import os

    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        root = logging.getLogger()
        # FileHandler stores the absolute path in `baseFilename`; compare
        # against it to detect a handler attached by an earlier call.
        target = os.path.abspath(os.fspath(filename))
        already_attached = any(
            isinstance(existing, logging.FileHandler)
            and existing.baseFilename == target
            for existing in root.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(filename)
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            root.addHandler(handler)
    return logger