# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch BERT model."""

import logging
import math

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from modules.until_config import PretrainedConfig

logger = logging.getLogger(__name__)

def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
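
# Note: recent PyTorch versions ship an exact (erf-based) GELU as
# torch.nn.functional.gelu; the hand-rolled version above is kept for parity
# with the original BERT port. A quick sanity check (sketch, not part of the
# original module):
#
#   x = torch.randn(4, 8)
#   assert torch.allclose(gelu(x), F.gelu(x), atol=1e-6)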

def swish(x):
    return x * torch.sigmoid(x)

def get_dual_matrix(sim_matrix):
    if not torch.is_tensor(sim_matrix):
        sim_matrix = torch.tensor(sim_matrix)
    temp = 1
    # sim_matrix = sim_matrix * F.softmax(sim_matrix / temp, dim=0) * len(sim_matrix)
    alpha = F.softmax(sim_matrix / temp, dim=0)
    beta = F.softmax(sim_matrix / temp, dim=1)
    sim_matrix = sim_matrix * alpha * beta
    return sim_matrix
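
# Usage sketch (illustrative, not from the original module): `get_dual_matrix`
# applies a dual-softmax re-weighting to a square text-video similarity matrix,
# scaling each score by its column-wise (dim=0) and row-wise (dim=1) softmax
# before a contrastive loss is computed. With hypothetical feature tensors:
#
#   sim = text_feat @ video_feat.t()   # (N, N), matched pairs on the diagonal
#   sim = get_dual_matrix(sim)         # sim * softmax(sim, dim=0) * softmax(sim, dim=1)
#   loss = CrossEn()(sim)              # see the loss classes below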
| ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |

class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
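
# Note (assumption): this module matches torch.nn.LayerNorm(hidden_size, eps=1e-12)
# numerically; it is kept as a separate class mainly so that checkpoints using the
# legacy `gamma`/`beta` parameter names can be remapped in `init_preweight` below.
#
#   ln = LayerNorm(16)
#   ref = nn.LayerNorm(16, eps=1e-12)
#   ref.load_state_dict(ln.state_dict())   # both use `weight`/`bias` names
#   x = torch.randn(2, 4, 16)
#   assert torch.allclose(ln(x), ref(x), atol=1e-6)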

class PreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        # Rename legacy TF-style parameter names (gamma/beta -> weight/bias).
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                old_keys.append(key)
                new_keys.append(prefix + key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')

        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
            if len(error_msgs) > 0:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))

        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = cls.init_preweight(model, state_dict)
        return model
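
# Usage sketch (hypothetical subclass name, for illustration only): a model built
# on PreTrainedModel is normally constructed from a config plus an optional
# checkpoint, e.g.
#
#   state_dict = torch.load("pytorch_model.bin", map_location="cpu")
#   model = MyBertModel.from_pretrained(config, state_dict=state_dict)
#
# `init_preweight` then remaps legacy `gamma`/`beta` keys and logs any
# missing/unexpected weights on the rank-0 process.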

##################################
###### LOSS FUNCTION #############
##################################

class CrossEn(nn.Module):
    def __init__(self):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
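
# CrossEn is an InfoNCE-style contrastive loss: for an (N, N) similarity matrix
# whose diagonal holds the matched text-video pairs, it takes a row-wise
# log-softmax and averages the negative log-probability of the diagonal entries.
# A common symmetric usage (sketch, assumed shapes):
#
#   loss_fct = CrossEn()
#   sim = text_feat @ video_feat.t()            # (N, N)
#   loss = (loss_fct(sim) + loss_fct(sim.t())) / 2.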

class Dual_CrossEn(nn.Module):
    def __init__(self):
        super(Dual_CrossEn, self).__init__()

    def forward(self, sim_matrix):
        sim_matrix = get_dual_matrix(sim_matrix)
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss

class MILNCELoss(nn.Module):
    def __init__(self, batch_size=1, n_pair=1):
        super(MILNCELoss, self).__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair
        torch_v = float(".".join(torch.__version__.split(".")[:2]))
        self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8

    def forward(self, sim_matrix):
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12

        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair // 2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
        return sim_loss
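
# MILNCELoss follows the MIL-NCE idea: the `n_pair` captions belonging to the same
# video are treated as a bag of positives whose scores are combined with a
# log-sum-exp against the in-batch negatives. Sketch (assumed batch layout: rows
# and columns grouped per video):
#
#   loss_fct = MILNCELoss(batch_size=B, n_pair=P)
#   loss = loss_fct(sim_matrix)                 # sim_matrix: (B*P, B*P)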

class MaxMarginRankingLoss(nn.Module):
    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        if n_pair > 1 and batch_size > 1:
            alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
            mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
            self.mm_mask = mm_mask.float()

    def forward(self, x):
        d = torch.diag(x)
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
                     F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()
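
# MaxMarginRankingLoss is a hinge-style ranking loss: each diagonal (positive)
# score must exceed every row/column negative by `margin`, with an optional
# re-weighting mask controlled by hard_negative_rate when n_pair and batch_size
# are both greater than 1. Illustrative use (hypothetical hyperparameters):
#
#   loss_fct = MaxMarginRankingLoss(margin=0.2)
#   loss = loss_fct(sim_matrix)                 # sim_matrix: (N, N)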

class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, args):
        output = [torch.empty_like(tensor) for _ in range(args.world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        return (
            grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
            None,
        )
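
# Usage sketch (assumes torch.distributed is initialized and `args` exposes
# `world_size` and `rank`, as in the training scripts that import this module):
#
#   allgather = AllGather.apply
#   text_feat = allgather(text_feat, args)      # gather features from every rank
#   video_feat = allgather(video_feat, args)    # before building the similarity matrix
#
# The backward pass slices the incoming gradient so each rank only keeps the
# gradient for its own shard of the gathered batch.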