Spaces:
Runtime error
Runtime error
| # Copyright Forge 2024 | |
| import time | |
| import torch | |
| import contextlib | |
| from backend import stream, memory_management | |
# Module-level registry of (weight, bias, finished_event) tuples keyed by
# id(finished_event). main_stream_worker adds an entry after each streamed
# compute and drops entries whose event reports completion via .query();
# cleanup_cache empties it after synchronizing both streams. Holding the tuple
# presumably keeps the copied tensors referenced until the stream that used
# them is done — TODO confirm.
stash: dict = dict()
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
    """Copy ``layer.weight`` / ``layer.bias`` onto the device (and dtype) of ``x``.

    Args:
        layer: object exposing ``weight`` and ``bias`` attributes; either may be None.
        x: input tensor whose ``device`` (and ``dtype``) are the copy targets.
        skip_weight_dtype: when True the weight is moved to ``x.device`` but keeps
            its own dtype.
        skip_bias_dtype: when True the bias is moved to ``x.device`` but keeps its
            own dtype.

    Returns:
        ``(weight, bias, signal)``. ``weight``/``bias`` are None when the layer's
        attribute is None. ``signal`` is the event recorded on
        ``stream.mover_stream`` right after the copies when
        ``stream.should_use_stream()`` is true, otherwise None.
    """
    device = x.device
    # Non-blocking transfers are turned off when the target device type is 'mps'.
    async_copy = getattr(device, 'type', None) != 'mps'

    common = {'device': device, 'non_blocking': async_copy}
    weight_kwargs = common if skip_weight_dtype else {**common, 'dtype': x.dtype}
    bias_kwargs = common if skip_bias_dtype else {**common, 'dtype': x.dtype}

    use_stream = stream.should_use_stream()
    # Issue the copies on the mover stream when streaming is on; otherwise run inline.
    copy_ctx = stream.stream_context()(stream.mover_stream) if use_stream else contextlib.nullcontext()

    with copy_ctx:
        weight = None if layer.weight is None else layer.weight.to(**weight_kwargs)
        bias = None if layer.bias is None else layer.bias.to(**bias_kwargs)
        signal = stream.mover_stream.record_event() if use_stream else None

    return weight, bias, signal
@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
    """Context manager that runs the enclosed compute after the weight copy is ready.

    Callers use it as ``with main_stream_worker(weight, bias, signal): ...``. The
    ``@contextlib.contextmanager`` decorator is required for that: a plain
    generator does not implement ``__enter__``/``__exit__``.

    When ``signal`` is None or ``stream.should_use_stream()`` is false, this is a
    pass-through. Otherwise:

    1. The body runs inside ``stream.current_stream`` after that stream waits on
       ``signal``.
    2. An event is recorded on ``stream.current_stream``, and
       ``(weight, bias, event)`` is stored in ``stash``.
    3. Every ``stash`` entry whose event reports completion is removed.

    Args:
        weight: tensor returned by ``weights_manual_cast`` (may be None).
        bias: tensor returned by ``weights_manual_cast`` (may be None).
        signal: event returned by ``weights_manual_cast``; None when streams are off.
    """
    if signal is None or not stream.should_use_stream():
        yield
        return

    with stream.stream_context()(stream.current_stream):
        stream.current_stream.wait_event(signal)
        yield
        finished_signal = stream.current_stream.record_event()
        stash[id(finished_signal)] = (weight, bias, finished_signal)

    # Collect keys first: a dict cannot change size while it is being iterated.
    finished = [key for key, (_, _, event) in stash.items() if event.query()]
    for key in finished:
        del stash[key]
def cleanup_cache():
    """Synchronize the compute and mover streams, then empty ``stash``.

    Does nothing when ``stream.should_use_stream()`` is false.
    """
    if stream.should_use_stream():
        # Same order as before: compute stream first, then mover stream.
        for stream_name in ('current_stream', 'mover_stream'):
            getattr(stream, stream_name).synchronize()
        stash.clear()
# Construction-time settings read by the ForgeOperations layer constructors
# (device, dtype, manual-cast flag, bitsandbytes quant type).
# using_forge_operations() assigns all four.
current_device, current_dtype = None, None
current_manual_cast_enabled = False
current_bnb_dtype = None
class ForgeOperations:
    """Drop-in replacements for common ``torch.nn`` layers.

    Every layer reads the module-level ``current_*`` settings when it is built
    and stores ``current_manual_cast_enabled`` as ``parameters_manual_cast``.

    When that flag is true, ``forward`` copies the layer's weight/bias to the
    input's device with ``weights_manual_cast`` and runs the op inside
    ``main_stream_worker``. The copy also converts to the input's dtype, except
    for ``Embedding``. When the flag is false, the stored parameters are used
    directly.

    The subclasses of torch layers override ``reset_parameters`` so that torch's
    default initialisation is skipped.
    """

    class Linear(torch.nn.Module):
        """Linear layer whose ``weight``/``bias`` are created only from a loaded state dict.

        Constructor arguments are accepted for signature compatibility and ignored.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()
            # One-element placeholder: its device/dtype is the conversion target
            # for the first state-dict load.
            placeholder = torch.empty(1, device=current_device, dtype=current_dtype)
            self.dummy = torch.nn.Parameter(placeholder)
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs):
            if not hasattr(self, 'dummy'):
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                              missing_keys, unexpected_keys, error_msgs)
                return
            # First load: adopt whichever of weight/bias is present, converted to
            # the placeholder's device and dtype, then drop the placeholder.
            for name in ('weight', 'bias'):
                key = prefix + name
                if key in state_dict:
                    setattr(self, name, torch.nn.Parameter(state_dict[key].to(self.dummy)))
            del self.dummy

        def forward(self, x):
            if not self.parameters_manual_cast:
                return torch.nn.functional.linear(x, self.weight, self.bias)
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.linear(x, weight, bias)

    class Conv2d(torch.nn.Conv2d):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x):
            if not self.parameters_manual_cast:
                return super().forward(x)
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return self._conv_forward(x, weight, bias)

    class Conv3d(torch.nn.Conv3d):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x):
            if not self.parameters_manual_cast:
                return super().forward(x)
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return self._conv_forward(x, weight, bias)

    class Conv1d(torch.nn.Conv1d):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x):
            if not self.parameters_manual_cast:
                return super().forward(x)
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return self._conv_forward(x, weight, bias)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x, output_size=None):
            if not self.parameters_manual_cast:
                return super().forward(x, output_size)
            output_padding = self._output_padding(
                x, output_size, self.stride, self.padding, self.kernel_size,
                2,  # num_spatial_dims
                self.dilation,
            )
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.conv_transpose2d(
                    x, weight, bias, self.stride, self.padding,
                    output_padding, self.groups, self.dilation,
                )

    class ConvTranspose1d(torch.nn.ConvTranspose1d):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x, output_size=None):
            if not self.parameters_manual_cast:
                return super().forward(x, output_size)
            output_padding = self._output_padding(
                x, output_size, self.stride, self.padding, self.kernel_size,
                1,  # num_spatial_dims
                self.dilation,
            )
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.conv_transpose1d(
                    x, weight, bias, self.stride, self.padding,
                    output_padding, self.groups, self.dilation,
                )

    class ConvTranspose3d(torch.nn.ConvTranspose3d):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x, output_size=None):
            if not self.parameters_manual_cast:
                return super().forward(x, output_size)
            output_padding = self._output_padding(
                x, output_size, self.stride, self.padding, self.kernel_size,
                3,  # num_spatial_dims
                self.dilation,
            )
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.conv_transpose3d(
                    x, weight, bias, self.stride, self.padding,
                    output_padding, self.groups, self.dilation,
                )

    class GroupNorm(torch.nn.GroupNorm):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x):
            if not self.parameters_manual_cast:
                return super().forward(x)
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.group_norm(x, self.num_groups, weight, bias, self.eps)

    class LayerNorm(torch.nn.LayerNorm):
        def __init__(self, *args, **kwargs):
            kwargs.update(device=current_device, dtype=current_dtype)
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            """No-op: skip torch's default initialisation."""

        def forward(self, x):
            if not self.parameters_manual_cast:
                return super().forward(x)
            weight, bias, signal = weights_manual_cast(self, x)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)

    class Embedding(torch.nn.Embedding):
        def __init__(self, *args, **kwargs):
            # Only the device is forced; any dtype argument is left as the caller passed it.
            kwargs['device'] = current_device
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled
            # weights_manual_cast reads ``layer.bias``; torch's Embedding has none.
            self.bias = None

        def reset_parameters(self):
            """Skip torch's default initialisation; only ensure ``bias`` exists (as None)."""
            self.bias = None

        def forward(self, x):
            if not self.parameters_manual_cast:
                return super().forward(x)
            # x holds indices, so the copied weight keeps its own dtype.
            weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.embedding(
                    x, weight, self.padding_idx, self.max_norm,
                    self.norm_type, self.scale_grad_by_freq, self.sparse,
                )
try:
    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits

    class ForgeOperationsBNB4bits(ForgeOperations):
        """``ForgeOperations`` variant whose ``Linear`` runs through ``functional_linear_4bits``."""

        class Linear(ForgeLoader4Bit):
            def __init__(self, *args, **kwargs):
                # Constructor arguments are ignored; settings come from the current_* globals.
                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
                self.parameters_manual_cast = current_manual_cast_enabled

            def forward(self, x):
                self.weight.quant_state = self.quant_state

                if self.bias is not None and self.bias.dtype != x.dtype:
                    # Cast the stored bias to the input dtype in place. This could also
                    # be done for the non-bnb ops since the cost is very low: it runs
                    # once per layer, and most linear layers have no bias.
                    self.bias.data = self.bias.data.to(x.dtype)

                if not self.parameters_manual_cast:
                    return functional_linear_4bits(x, self.weight, self.bias)

                if not self.weight.bnb_quantized:
                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                    layer_original_device = self.weight.device
                    self.weight = self.weight._quantize(x.device)
                    try:
                        bias = self.bias.to(x.device) if self.bias is not None else None
                        return functional_linear_4bits(x, self.weight, bias)
                    finally:
                        # Move the weight back even when the matmul raises, so a failed
                        # call does not leave it on the compute device.
                        self.weight = self.weight.to(layer_original_device)

                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return functional_linear_4bits(x, weight, bias)

    # Name kept with its original spelling: using_forge_operations reads it.
    bnb_avaliable = True
except Exception:
    # Optional dependency, best effort: any failure importing or defining the 4-bit
    # ops disables that path. Catching Exception (not a bare except) lets
    # KeyboardInterrupt and SystemExit propagate.
    bnb_avaliable = False
@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
    """Temporarily replace common ``torch.nn`` layer classes with ``operations``.

    Intended for ``with using_forge_operations(...):``. The
    ``@contextlib.contextmanager`` decorator is required because a bare generator
    cannot be used in a ``with`` statement.

    Args:
        operations: namespace providing ``Linear``, ``Conv1d``/``2d``/``3d``,
            ``ConvTranspose1d``/``2d``/``3d``, ``GroupNorm``, ``LayerNorm`` and
            ``Embedding``. When None, ``ForgeOperationsBNB4bits`` is used if
            ``bnb_avaliable`` is true and ``bnb_dtype`` is 'nf4' or 'fp4';
            otherwise ``ForgeOperations`` is used.
        device: stored in ``current_device`` for the replacement layers to read
            when they are constructed.
        dtype: stored in ``current_dtype``.
        manual_cast_enabled: stored in ``current_manual_cast_enabled``.
        bnb_dtype: stored in ``current_bnb_dtype``.

    The ``torch.nn`` attributes are restored on exit, including when the body
    raises. The ``current_*`` settings are left as set.
    """
    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype

    if operations is None:
        if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
            operations = ForgeOperationsBNB4bits
        else:
            operations = ForgeOperations

    op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

    try:
        # Inside the try, so a missing attribute on `operations` still restores
        # anything already swapped.
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))
        yield
    finally:
        for op_name, original in backups.items():
            setattr(torch.nn, op_name, original)
def shift_manual_cast(model, enabled):
    """Set ``parameters_manual_cast = enabled`` on every module in ``model.modules()``
    (``model`` itself included) that already has that attribute.

    Modules without the attribute are left untouched.
    """
    targets = (module for module in model.modules() if hasattr(module, 'parameters_manual_cast'))
    for module in targets:
        module.parameters_manual_cast = enabled
@contextlib.contextmanager
def automatic_memory_management():
    """Record modules built or moved inside the block and offload them to CPU afterwards.

    Intended for ``with automatic_memory_management():``. The
    ``@contextlib.contextmanager`` decorator is required because a bare generator
    cannot be used in a ``with`` statement.

    On entry, calls ``memory_management.free_memory`` asking for 3 GiB on
    ``memory_management.get_torch_device()``. While active,
    ``torch.nn.Module.__init__`` and ``torch.nn.Module.to`` are patched to record
    every module they are called on. The patches are always removed on exit.

    After a normal exit:

    - each recorded module is moved to CPU once (duplicates removed);
    - ``memory_management.soft_empty_cache()`` is called;
    - a timing line is printed.
    """
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    module_list = []
    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

    start = time.perf_counter()
    # A module may have been recorded several times (constructed and then moved).
    unique_modules = set(module_list)
    for module in unique_modules:
        module.cpu()
    memory_management.soft_empty_cache()
    end = time.perf_counter()

    print(f'Automatic Memory Management: {len(unique_modules)} Modules in {(end - start):.2f} seconds.')