# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch

from vlm_eval.attacks.attack import Attack


class PGD0(Attack):
    def __init__(self, model, *args, img_range=(0, 1), k=5000, n_restarts=1,
                 targeted=False, iters=200, stepsize=120000 / 255.0,
                 eps=4. / 255., ver=False, mask_out='none', **kwargs):
        '''
        Implementation of the PGD_0 attack: https://arxiv.org/pdf/1909.05040
        Authors' implementation: https://github.com/fra31/sparse-imperceivable-attacks/tree/master
        Adapted from: https://github.com/wagnermoritz/GSE/tree/main

        args:
            model: Callable, PyTorch classifier.
            img_range: Tuple of ints/floats, lower and upper bound of image
                entries.
            targeted: Bool, given label is used as a target label if True.
            k: Int, sparsity parameter; at most k pixels are perturbed.
            n_restarts: Int, number of restarts from random perturbation.
            iters: Int, number of gradient descent steps per restart.
            stepsize: Float, step size for gradient descent.
            eps: Float, L-inf bound on the perturbation entries.
            ver: Bool, print the loss every 20 iterations if True.
            mask_out: 'none', 'context', 'query', or Int; selects which
                images are excluded from the perturbation (see _set_mask).
        '''
        super().__init__(model, img_range=img_range, targeted=targeted)
        self.k = k
        self.n_restarts = n_restarts
        self.eps = eps
        self.iters = iters
        self.stepsize = stepsize
        self.mask_out = mask_out if mask_out != 'none' else None
        self.ver = ver
    def _set_mask(self, data):
        '''
        Build a {0, 1} mask selecting which images along the media dimension
        (dim 1) receive a perturbation.
        '''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is not None:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask
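
    # _set_mask example (illustrative): for x of shape [1, 3, 1, C, H, W],
    # mask_out='context' zeroes the perturbation on the first two media slots
    # so only the last (query) image is attacked; mask_out='query' keeps the
    # context images attackable instead; an Int zeroes exactly that slot.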

    def __call__(self, x, *args, **kwargs):
        '''
        Perform the PGD_0 attack on a batch of images x.

        args:
            x: Tensor of shape [B, N, T, C, H, W]; the batch and frame
               dimensions B and T are assumed to be of size 1 (see the
               reshape below). Labels are assumed to be handled inside the
               model wrapper, so any extra args are ignored.
        Returns a tensor of the same shape as x containing adversarial
        examples.
        '''
        # Freeze the model weights; only the perturbation receives gradients.
        for param in self.model.model.parameters():
            param.requires_grad = False
        mask_out = self._set_mask(x)
        x = x.to(self.device)
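        # Note: despite the names, B below is the number of media images N
        # (x.shape[1]); the true batch dimension x.shape[0] is assumed to be
        # 1, which makes the 6-D tensor viewable as [N, C, H, W].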
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        for _ in range(self.n_restarts):
            if not len(x):
                break
            eps = torch.full_like(x, self.eps)
            # Box constraints: stay within eps in L-inf and within img_range.
            lb = torch.maximum(-eps, self.img_range[0] - x)
            ub = torch.minimum(eps, self.img_range[1] - x)
            # Random start inside the box, masked to the attacked images.
            pert = (torch.clamp(x + (ub - lb) * torch.rand_like(x) + lb,
                                *self.img_range) - x)
            pert = pert.view(B, C, H, W) * mask_out.view(B, C, H, W)
            pert = self.project_L0(pert, lb, ub)  # shape [B, C, H, W]
            for it in range(self.iters):
                pert.requires_grad = True
                loss = self.lossfn(x=x, pert=pert.view(*x.shape), mask_out=mask_out)
                loss.backward()
                if self.ver and it % 20 == 0:
                    print(f"Loss: {loss.item()}, Iter: {it}")
                grad = pert.grad.data.view(B, C, H, W) * mask_out.view(B, C, H, W)
                with torch.no_grad():
                    # l1-normalize the gradient, take a descent step with a
                    # tiny random tie-breaking perturbation, then re-project.
                    grad /= grad.abs().sum(dim=(1, 2, 3), keepdim=True) + 1e-10
                    pert += (torch.rand_like(x) - .5).view(B, C, H, W) * 1e-12 \
                        - self.stepsize * grad
                    pert = self.project_L0(pert, lb, ub)
        return (x + pert.view(*x.shape) * mask_out).detach()

    def project_L0_sigma(self, pert, sigma, kappa, x_orig):
        '''
        Sigma-projection from the authors' sparse-imperceivable-attacks
        repository. Not used by __call__; kept for reference from the adapted
        source. The final projection x_orig * (1 + lmbd * sigma) follows the
        reference implementation and should be treated as a sketch, not a
        verified port.
        '''
        p1 = (1.0 / torch.clamp(sigma, min=1e-12)) * (x_orig > 0).float() + \
             1e12 * (x_orig == 0).float()
        p2 = (1.0 / torch.clamp(sigma, min=1e-12)) * \
             (1.0 / torch.clamp(x_orig, min=1e-12) - 1) * \
             (x_orig > 0).float() + 1e12 * (x_orig == 0).float() + \
             1e12 * (sigma == 0).float()
        lmbd_l = torch.maximum(-kappa, torch.amax(-p1, dim=1, keepdim=True))
        lmbd_u = torch.minimum(kappa, torch.amin(p2, dim=1, keepdim=True))
        lmbd_unconstr = torch.sum((pert - x_orig) * sigma * x_orig, dim=1, keepdim=True) / \
            torch.clamp(torch.sum((sigma * x_orig) ** 2, dim=1, keepdim=True), min=1e-12)
        lmbd = torch.maximum(lmbd_l, torch.minimum(lmbd_unconstr, lmbd_u))
        return x_orig * (1 + lmbd * sigma)

    def project_L0(self, pert, lb, ub):
        '''
        Project a batch of perturbations such that at most self.k pixels
        are perturbed and componentwise there holds lb <= pert <= ub.
        '''
        B, C, H, W = pert.shape
        # Squared l2 mass of each pixel's perturbation across channels ...
        p1 = torch.sum(pert ** 2, dim=1)
        # ... minus the part that would be clipped away by the box [lb, ub].
        p2 = torch.clamp(torch.minimum(ub.view(B, C, H, W) - pert,
                                       pert - lb.view(B, C, H, W)), 0)
        p2 = torch.sum(p2 ** 2, dim=1)
        # Indices of the H*W - k lowest-scoring pixels; these are zeroed out
        # to enforce the sparsity constraint.
        p3 = torch.topk(-1 * (p1 - p2).view(p1.size(0), -1),
                        k=H * W - self.k, dim=-1)[1]
        pert = torch.maximum(torch.minimum(pert, ub.view(B, C, H, W)),
                             lb.view(B, C, H, W))
        # Flattened index -> (row, col) with row = idx // W, col = idx % W.
        pert[torch.arange(0, B).view(-1, 1), :, p3 // W, p3 % W] = 0
        return pert
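
    # Worked example for project_L0 (assumed numbers): with H = W = 224 and
    # k = 5000, topk selects the 224 * 224 - 5000 = 45176 pixel locations
    # with the smallest score p1 - p2 and zeroes them across all channels,
    # so at most 5000 pixels per image remain perturbed.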

    def lossfn(self, x, pert, mask_out):
        '''
        Compute the loss at x + pert * mask_out. Assuming self.model returns
        a per-sample loss, the sign flip makes the descent step in __call__
        maximize the loss for untargeted attacks and minimize it for
        targeted ones.
        '''
        return (2 * self.targeted - 1) * self.model(x + pert * mask_out).sum()
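
# Usage sketch (illustrative, not part of the adapted source): `vlm_loss` is
# a hypothetical wrapper that maps perturbed inputs to a scalar loss and
# exposes its underlying network as `.model`, as __call__ and lossfn above
# expect.
#
#   attack = PGD0(vlm_loss, k=5000, eps=4 / 255, iters=200,
#                 stepsize=120000 / 255, mask_out='context', ver=True)
#   x = images.view(1, n_media, 1, 3, 224, 224)  # [B, N, T, C, H, W], B = T = 1
#   x_adv = attack(x)                            # same shape as x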