File size: 5,580 Bytes
fc0ff8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Code taken and adapted from https://github.com/wagnermoritz/GSE

from vlm_eval.attacks.attack import Attack
import torch
import numpy as np

class PGD0(Attack):
    '''
    PGD0 sparse adversarial attack with an additional per-entry eps bound.

    Paper: https://arxiv.org/pdf/1909.05040
    Author's implementation: https://github.com/fra31/sparse-imperceivable-attacks/tree/master
    Adapted from: https://github.com/wagnermoritz/GSE/tree/main
    '''

    def __init__(self, model, *args, img_range=(0, 1), k=5000, n_restarts=1,
                 targeted=False, iters=200, stepsize=120000/255.0, eps=4./255.,
                 ver=False, mask_out='none', **kwargs):
        '''
        Store the attack hyper-parameters.

        args:
        model:         Callable wrapper; `model.model.parameters()` is frozen
                       in __call__ and `model(images)` must return a tensor
                       that is summed into the attack loss.
        img_range:     Tuple of ints/floats, lower and upper bound of image
                       entries.
        k:             Int, sparsity parameter (max. number of perturbed
                       pixels per image).
        n_restarts:    Int, number of restarts from random perturbation.
        targeted:      Bool, flips the sign of the loss (see lossfn).
        iters:         Int, number of gradient descent steps per restart.
        stepsize:      Float, step size for gradient descent.
        eps:           Float, bound on the absolute value of every
                       perturbation entry.
        ver:           Bool, print the loss every 20 iterations if True.
        mask_out:      'none', 'context', 'query' or an int index along
                       dim 1 of the input; selects the entries that must
                       NOT be perturbed (see _set_mask).

        Extra positional / keyword arguments are accepted and ignored.
        '''
        super().__init__(model, img_range=img_range, targeted=targeted)
        self.k = k
        self.n_restarts = n_restarts
        self.eps = eps
        self.iters = iters
        self.stepsize = stepsize
        # The string 'none' is the "no masking" sentinel used by callers.
        self.mask_out = None if mask_out == 'none' else mask_out
        self.ver = ver

    def _set_mask(self, data):
        '''
        Build a 0/1 tensor shaped like `data`; zeros mark entries that are
        excluded from the perturbation.

        'context' zeroes every index but the last along dim 1, 'query'
        zeroes only the last index along dim 1, an int zeroes that single
        index along dim 1 and None keeps everything.

        Raises NotImplementedError for any other `self.mask_out` value.
        '''
        mask = torch.ones_like(data)
        target = self.mask_out
        if target is None:
            return mask
        if target == 'context':
            mask[:, :-1, ...] = 0
        elif target == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(target, int):
            mask[:, target, ...] = 0
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x, *args, **kwargs):
        '''
        Perform the PGD_0 attack on the input x.

        args:
        x:   Tensor with (at least) 6 dims. Internally the perturbation is
             handled as a 4-D tensor of shape
             (x.shape[1], x.shape[3], x.shape[4], x.shape[5]), so
             x.shape[0] * x.shape[2] must be 1.
             NOTE(review): presumably the layout is
             (batch, num_images, frames, C, H, W) with the query image
             last along dim 1 - confirm against the caller.

        Returns a detached tensor of the same shape as x containing the
        adversarial input. If x is empty or n_restarts < 1, x is returned
        unperturbed.
        '''
        for param in self.model.model.parameters():
            param.requires_grad = False

        # Move x first so that the mask is created on the attack device too.
        x = x.to(self.device)
        if not len(x) or self.n_restarts < 1:
            return x.detach()

        mask_out = self._set_mask(x)
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        mask_4d = mask_out.view(B, C, H, W)

        # Per-entry bounds on the perturbation: stay within +-eps and keep
        # x + pert inside img_range. Both are independent of the restart.
        lo, hi = self.img_range
        eps = torch.full_like(x, self.eps)
        lb = torch.maximum(-eps, lo - x)
        ub = torch.minimum(eps, hi - x)

        # NOTE: every restart overwrites `pert`; only the result of the last
        # restart is returned (no best-of-restarts selection is done here).
        for _ in range(self.n_restarts):
            start = torch.clamp(x + (ub - lb) * torch.rand_like(x) + lb, lo, hi) - x
            pert = self.project_L0(start.view(B, C, H, W) * mask_4d, lb, ub)

            for it in range(self.iters):
                # Fresh leaf tensor so that .grad is populated by backward().
                pert = pert.detach().requires_grad_(True)
                loss = self.lossfn(x=x, pert=pert.view(*x.shape), mask_out=mask_out)
                loss.backward()

                if self.ver and it % 20 == 0:
                    print(f"Loss: {loss}, Iter: {it}")

                with torch.no_grad():
                    grad = pert.grad.view(B, C, H, W) * mask_4d
                    # L1-normalise the gradient per image.
                    grad = grad / (grad.abs().sum(dim=(1, 2, 3), keepdim=True) + 1e-10)
                    # Tiny random noise breaks ties between equal entries.
                    noise = (torch.rand_like(x) - .5).view(B, C, H, W) * 1e-12
                    pert = self.project_L0(pert + noise - self.stepsize * grad, lb, ub)

        return (x + pert.view(*x.shape) * mask_out).detach()

    def project_L0_sigma(self, pert, sigma, kappa, x_orig):
        '''
        Unfinished port of the sigma-map projection of the reference
        implementation.

        Computes the clipped scaling factor `lmbd` (reductions over dim 1)
        from pert, sigma, kappa and x_orig, but does not apply it yet.

        args:
        pert, sigma, x_orig:  Tensors that broadcast against each other.
        kappa:                Float or tensor, bound on |lmbd|.

        Always returns the placeholder value 0.
        TODO: finish the projection and return the projected perturbation.
        '''
        tiny = 1e-12
        is_pos = (x_orig > 0).float()
        is_zero = (x_orig == 0).float()
        # clamp(min=...) instead of torch.maximum(float, tensor), which
        # raises a TypeError because torch.maximum only accepts tensors.
        inv_sigma = 1.0 / torch.clamp(sigma, min=tiny)

        p1 = inv_sigma * is_pos + 1e12 * is_zero
        p2 = inv_sigma * (1.0 / torch.clamp(x_orig, min=tiny) - 1) * is_pos \
            + 1e12 * is_zero + 1e12 * (sigma == 0).float()

        kappa = torch.as_tensor(kappa, dtype=p1.dtype, device=p1.device)
        lmbd_l = torch.maximum(-kappa, torch.amax(-p1, dim=1, keepdim=True))
        lmbd_u = torch.minimum(kappa, torch.amin(p2, dim=1, keepdim=True))

        weighted = sigma * x_orig
        lmbd_unconstr = torch.sum((pert - x_orig) * weighted, dim=1, keepdim=True) \
            / torch.clamp(torch.sum(weighted ** 2, dim=1, keepdim=True), min=tiny)
        lmbd = torch.maximum(lmbd_l, torch.minimum(lmbd_unconstr, lmbd_u))  # not applied yet
        return 0

    def project_L0(self, pert, lb, ub):
        '''
        Project a batch of perturbations such that at most self.k pixels
        are perturbed and componentwise there holds lb <= pert <= ub.

        args:
        pert:    Tensor of shape (B, C, H, W).
        lb, ub:  Tensors with B*C*H*W entries (viewed as (B, C, H, W)).

        Returns a new tensor of shape (B, C, H, W).
        '''
        B, C, H, W = pert.shape
        lb = lb.view(B, C, H, W)
        ub = ub.view(B, C, H, W)
        n_pixels = H * W

        # Gain of keeping a pixel instead of zeroing it:
        #   ||pert_i||^2 - ||pert_i - clip(pert_i, lb_i, ub_i)||^2
        # The second term is the squared box violation, i.e. the negative
        # part of min(ub - pert, pert - lb) -> clamp with max=0.
        p1 = torch.sum(pert ** 2, dim=1)
        violation = torch.clamp(torch.minimum(ub - pert, pert - lb), max=0)
        p2 = torch.sum(violation ** 2, dim=1)

        # Zero the pixels with the smallest gain; guard against k > H*W
        # (nothing to drop) and k < 0 (drop everything).
        n_drop = min(max(n_pixels - self.k, 0), n_pixels)
        drop_idx = torch.topk(-(p1 - p2).view(B, -1), k=n_drop, dim=-1)[1]

        # Mark dropped pixels on the flattened H*W axis; this avoids
        # decoding flat indices into (row, col) pairs.
        drop = torch.zeros(B, n_pixels, dtype=torch.bool, device=pert.device)
        drop[torch.arange(B, device=pert.device).view(-1, 1), drop_idx] = True

        pert = torch.maximum(torch.minimum(pert, ub), lb)
        return pert.masked_fill(drop.view(B, 1, H, W), 0)

    def lossfn(self, x, pert, mask_out):
        '''
        Compute the loss at x + pert * mask_out.

        The summed model output is multiplied by +1 if self.targeted is
        true and by -1 otherwise; __call__ does gradient descent on it.
        '''
        sign = 2 * self.targeted - 1
        return sign * self.model(x + pert * mask_out).sum()