Spaces:
Runtime error
Runtime error
| # Code taken and adapted from https://github.com/wagnermoritz/GSE | |
| from vlm_eval.attacks.attack import Attack | |
| import torch | |
| import numpy as np | |
| class PGD0(Attack): | |
| def __init__(self, model, *args, img_range=(0, 1), k=5000, n_restarts=1, | |
| targeted=False, iters=200, stepsize=120000/255.0, eps=4./255.,ver=False,mask_out='none',**kwargs): | |
| ''' | |
| Implementation of the PGD0 attack https://arxiv.org/pdf/1909.05040 | |
| Author's implementation: https://github.com/fra31/sparse-imperceivable-attacks/tree/master | |
| Addapted from: https://github.com/wagnermoritz/GSE/tree/main | |
| args: | |
| model: Callable, PyTorch classifier. | |
| img_range: Tuple of ints/floats, lower and upper bound of image | |
| entries. | |
| targeted: Bool, given label is used as a target label if True. | |
| k: Int, sparsity parameter. | |
| n_restarts: Int, number of restarts from random perturbation. | |
| iters: Int, number of gradient descent steps per restart. | |
| stepsize: Float, step size for gradient descent. | |
| ''' | |
| super().__init__(model, img_range=img_range, targeted=targeted) | |
| self.k = k | |
| self.n_restarts = n_restarts | |
| self.eps = eps | |
| self.iters = iters | |
| self.stepsize = stepsize | |
| if mask_out != 'none': | |
| self.mask_out = mask_out | |
| else: | |
| self.mask_out = None | |
| self.ver = ver | |
| def _set_mask(self, data): | |
| mask = torch.ones_like(data) | |
| if self.mask_out == 'context': | |
| mask[:, :-1, ...] = 0 | |
| elif self.mask_out == 'query': | |
| mask[:, -1, ...] = 0 | |
| elif isinstance(self.mask_out, int): | |
| mask[:, self.mask_out, ...] = 0 | |
| elif self.mask_out is None: | |
| pass | |
| else: | |
| raise NotImplementedError(f'Unknown mask_out: {self.mask_out}') | |
| return mask | |
| def __call__(self, x, *args, **kwargs): | |
| ''' | |
| Perform the PGD_0 attack on a batch of images x. | |
| args: | |
| x: Tensor of shape [B, C, H, W], batch of images. | |
| y: Tensor of shape [B], batch of labels. | |
| Returns a tensor of the same shape as x containing adversarial examples | |
| ''' | |
| for param in self.model.model.parameters(): | |
| param.requires_grad = False | |
| mask_out = self._set_mask(x) | |
| x = x.to(self.device) | |
| B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5] | |
| for _ in range(self.n_restarts): | |
| if not len(x): | |
| break | |
| eps = torch.full_like(x, self.eps) | |
| lb, ub = torch.maximum(-eps, -x),torch.minimum(eps, 1.0 - x) #self.img_range[0] - x, self.img_range[1] - x | |
| pert = (torch.clamp(x + (ub - lb) * torch.rand_like(x) + lb, *self.img_range) - x).view(B, C, H, W) * mask_out.view(B, C, H, W) | |
| pert = self.project_L0(pert, lb, ub) # pert is of the shape (B, C, H, W) | |
| for _ in range(self.iters): | |
| pert.requires_grad = True | |
| loss = self.lossfn(x=x, pert=pert.view(*x.shape), mask_out=mask_out) | |
| loss.backward() | |
| if self.ver and _ % 20 == 0: | |
| print(f"Loss: {loss}, Iter: {_}") | |
| grad = pert.grad.data.view(B,C,H,W) * mask_out.view(B, C, H, W) # shape (B, C, H, W) | |
| with torch.no_grad(): | |
| grad /= grad.abs().sum(dim=(1,2,3), keepdim=True) + 1e-10 | |
| pert += (torch.rand_like(x) - .5).view(B, C, H, W) * 1e-12 - self.stepsize * grad | |
| pert = self.project_L0(pert, lb, ub) | |
| return (x + pert.view(*x.shape) * mask_out).detach() | |
| def project_L0_sigma(self, pert, sigma, kappa, x_orig): | |
| B, C, H, W = pert.shape | |
| x = torch.clone(pert) | |
| p1 = (1.0 / torch.maximum(1e-12, sigma) * (x_orig > 0).float()) + \ | |
| (1e12 * (x_orig == 0).float()) | |
| p2 = (1.0 / torch.maximum(torch.tensor(1e-12), sigma)) * \ | |
| (1.0 / torch.maximum(torch.tensor(1e-12), x_orig) - 1) * \ | |
| (x_orig > 0).float() + 1e12 * (x_orig == 0).float() + 1e12 * (sigma == 0).float() | |
| lmbd_l = torch.maximum(-kappa, torch.amax(-p1, dim=1, keepdim=True)) | |
| lmbd_u = torch.minimum(kappa, torch.amin(p2, dim=1, keepdim=True)) | |
| lmbd_unconstr = torch.sum((pert - x_orig) * sigma * x_orig, dim=1, keepdim=True) / torch.clamp(torch.sum((sigma * x_orig) ** 2, dim=1, keepdim=True), min=1e-12) | |
| lmbd = torch.maximum(lmbd_l, torch.minimum(lmbd_unconstr, lmbd_u)) | |
| return 0 | |
| def project_L0(self, pert, lb, ub): | |
| ''' | |
| Project a batch of perturbations such that at most self.k pixels | |
| are perturbed and componentwise there holds lb <= pert <= ub. | |
| ''' | |
| B, C, H, W = pert.shape # Here, pert is of the shape B, C, H, W | |
| p1 = torch.sum(pert ** 2, dim=1) | |
| p2 = torch.clamp(torch.minimum(ub.view(B, C, H, W) - pert, pert - lb.view(B, C, H, W)), 0) | |
| p2 = torch.sum(p2 ** 2, dim=1) | |
| p3 = torch.topk(-1 * (p1 - p2).view(p1.size(0), -1), k=H*W-self.k, dim=-1)[1] | |
| pert = torch.maximum(torch.minimum(pert, ub.view(B, C, H, W)), lb.view(B, C, H, W)) | |
| pert[torch.arange(0, B).view(-1, 1), :, p3//W, p3%H] = 0 | |
| return pert | |
| def lossfn(self, x, pert, mask_out): | |
| ''' | |
| Compute the loss at x. | |
| ''' | |
| return (2 * self.targeted - 1) * self.model(x + pert * mask_out).sum() | |