import collections.abc as container_abcs

import torch
import torch.nn.functional as F

# Code taken from https://github.com/chs20/RobustVLM/tree/main
# some parts of this code are adapted from
# https://github.com/M4xim4l/InNOutRobustness/blob/main/utils/adversarial_attacks/utils.py

# Accepted spellings of the two supported norms, shared by
# project_perturbation and normalize_grad so they cannot drift apart.
_LINF_NAMES = ('inf', 'linf', 'Linf')
_L2_NAMES = (2, 2.0, 'l2', 'L2', '2')


def _flatten_per_sample(x):
    """Return ``x`` reshaped to ``(x.shape[0], -1)``.

    ``reshape`` is used rather than ``view`` so that non-contiguous inputs
    (e.g. permuted tensors) are accepted instead of raising.
    """
    return x.reshape(x.shape[0], -1)


def _restore_trailing_dims(z, x):
    """Reshape per-sample values ``z`` to ``(-1, 1, ..., 1)`` with ``x``'s rank."""
    return z.view(-1, *[1] * (len(x.shape) - 1))


def project_perturbation(perturbation, eps, norm):
    """Project ``perturbation`` onto the ``norm``-ball of radius ``eps``.

    Args:
        perturbation: tensor to project; for the L2 case dimension 0 is
            treated as the sample dimension.
        eps: radius of the ball.
        norm: one of ``'inf'``, ``'linf'``, ``'Linf'`` (element-wise clamp to
            ``[-eps, eps]``) or ``2``, ``2.0``, ``'l2'``, ``'L2'``, ``'2'``
            (each ``perturbation[i]`` is rescaled so its L2 norm is at most
            ``eps``).

    Returns:
        The projected tensor (a new tensor; the input is not modified).

    Raises:
        NotImplementedError: if ``norm`` is not one of the values above.
    """
    if norm in _LINF_NAMES:
        return torch.clamp(perturbation, -eps, eps)
    if norm in _L2_NAMES:
        # renorm only shrinks sub-tensors whose norm exceeds maxnorm.
        return torch.renorm(perturbation, p=2, dim=0, maxnorm=eps)
    raise NotImplementedError(f'Norm {norm} not supported')


def normalize_grad(grad, p):
    """Normalize a gradient for a steepest-ascent step under norm ``p``.

    Args:
        grad: gradient tensor; dimension 0 is treated as the sample dimension.
        p: one of ``'inf'``, ``'linf'``, ``'Linf'`` (returns ``grad.sign()``)
            or ``2``, ``2.0``, ``'l2'``, ``'L2'``, ``'2'`` (each sample's
            flattened gradient is divided by its L2 norm).

    Returns:
        A tensor with the same shape as ``grad``.

    Raises:
        NotImplementedError: if ``p`` is not one of the values above. This
            matches ``project_perturbation`` instead of silently returning
            ``None``.
    """
    if p in _LINF_NAMES:
        return grad.sign()
    if p in _L2_NAMES:
        grad_flat = _flatten_per_sample(grad)
        # F.normalize clamps the denominator with a small epsilon, so an
        # all-zero gradient stays zero instead of becoming NaN.
        grad_normalized = F.normalize(grad_flat, p=2, dim=1)
        return grad_normalized.reshape(grad.shape)
    raise NotImplementedError(f'Norm {p} not supported')


def L1_norm(x, keepdim=False):
    """Return the per-sample L1 norm of ``x`` (dimension 0 indexes samples).

    Args:
        x: input tensor with at least one dimension.
        keepdim: if true, the result has shape ``(-1, 1, ..., 1)`` with the
            same number of dimensions as ``x``; otherwise ``(x.shape[0],)``.
    """
    z = _flatten_per_sample(x.abs()).sum(-1)
    if keepdim:
        z = _restore_trailing_dims(z, x)
    return z


def L2_norm(x, keepdim=False):
    """Return the per-sample L2 norm of ``x`` (dimension 0 indexes samples).

    Args:
        x: input tensor with at least one dimension.
        keepdim: if true, the result has shape ``(-1, 1, ..., 1)`` with the
            same number of dimensions as ``x``; otherwise ``(x.shape[0],)``.
    """
    z = _flatten_per_sample(x ** 2).sum(-1).sqrt()
    if keepdim:
        z = _restore_trailing_dims(z, x)
    return z


def L0_norm(x):
    """Return the per-sample count of non-zero entries of ``x``."""
    return _flatten_per_sample(x != 0.).sum(-1)


def zero_gradients(x):
    """Zero the ``.grad`` of a tensor, or of every tensor in a nested iterable.

    Tensors whose ``.grad`` is ``None`` are left untouched. Non-tensor,
    non-iterable objects are ignored. Mappings are iterated like any other
    iterable, i.e. over their keys.
    """
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            # Detach first so zeroing is not recorded in any autograd graph.
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, str):
        # A str iterates into further (one-character) strs; recursing into it
        # would never terminate.
        return
    elif isinstance(x, container_abcs.Iterable):
        for elem in x:
            zero_gradients(elem)