adversarial_loss.py
# ====================================================
# Copyright (C) 2021 All rights reserved.
#
# Author : Xinyu Zhu
# Email : [email protected]
# File Name : AdversarialLoss.py
# Last Modified : 2021-11-07 19:24
# Describe : from https://github.com/SihengLi99/AAAI-SDU-Task1
#
# ====================================================
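# Summary: AdversarialLoss perturbs the input word embeddings with small
# random noise, refines that noise by gradient ascent on the divergence
# between the clean and perturbed predictions, and returns the resulting
# divergence so it can be added to the task loss as an adversarial
# regularizer (in the spirit of SMART-style virtual adversarial training).
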
import torch
import torch.nn.functional as F


class AdversarialLoss(object):

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        # * divergence function: resolves args.divergence ("kl", "sym_kl" or
        # * "js") to the corresponding static method below
        self.divergence = getattr(self, args.divergence)

    def __call__(self, model, logits, train_inputs):
        # * unpack the batch (assumes train_inputs is a dict-like batch
        # * containing these keys)
        input_ids = train_inputs["input_ids"]
        attention_mask = train_inputs["attention_mask"]
        token_type_ids = train_inputs["token_type_ids"]

        # * get disturbed inputs: clean word embeddings plus a small random
        # * noise that will be optimized below
        inputs_embeds = model.bert.embeddings.word_embeddings(input_ids)
        noise = inputs_embeds.clone().detach().normal_(
            0, 1).requires_grad_(True) * self.args.noise_var

        # * adv loop
        for i in range(self.args.adv_nloop):
            inputs_embeds = inputs_embeds.detach() + noise
            adv_logits = model(attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               inputs_embeds=inputs_embeds)
            adv_loss = self.divergence(adv_logits,
                                       logits.detach(),
                                       reduction='batchmean')
            # * now we need to find the noise that maximizes the divergence;
            # * theoretically we need the exact maximum, but to be more
            # * efficient we approximate it with a gradient-ascent step
            noise_grad = torch.autograd.grad(outputs=adv_loss,
                                             inputs=noise,
                                             retain_graph=True)[0]
            noise = noise + noise_grad * self.args.adv_step_size
            # * normalization: the noise does not seem to be used after this
            # * point, so the projection is commented out
            # noise = self.adv_project(noise,
            #                          norm_type=self.args.project_norm_type,
            #                          eps=self.args.noise_gamma)

        adv_loss = self.divergence(adv_logits, logits)
        return adv_loss

    @staticmethod
    def adv_project(grad, norm_type='inf', eps=1e-6):
        # * normalize the noise gradient into a step direction under the
        # * chosen norm (l2, l1, or l-inf by default)
        if norm_type == 'l2':
            direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps)
        elif norm_type == 'l1':
            direction = grad.sign()
        else:
            direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps)
        return direction

    @staticmethod
    def kl(input, target, reduction="sum"):
        # * KL divergence computed from raw logits, with softmax(target) as
        # * the reference distribution
        input = input.float()
        target = target.float()
        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32),
                        F.softmax(target, dim=-1, dtype=torch.float32),
                        reduction=reduction)
        return loss

    @staticmethod
    def sym_kl(input, target, reduction="sum"):
        # * symmetric KL: the two KL terms are summed, and the probability
        # * (reference) argument of each term is detached
        input = input.float()
        target = target.float()
        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32),
                        F.softmax(target.detach(), dim=-1, dtype=torch.float32),
                        reduction=reduction) + \
               F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32),
                        F.softmax(input.detach(), dim=-1, dtype=torch.float32),
                        reduction=reduction)
        return loss

    @staticmethod
    def js(input, target, reduction="sum"):
        # * Jensen-Shannon-style divergence: KL terms against the detached
        # * mixture m = (softmax(input) + softmax(target)) / 2
        input = input.float()
        target = target.float()
        m = F.softmax(target.detach(), dim=-1, dtype=torch.float32) + \
            F.softmax(input.detach(), dim=-1, dtype=torch.float32)
        m = 0.5 * m
        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32),
                        m, reduction=reduction) + \
               F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32),
                        m, reduction=reduction)
        return loss
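

# --------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, self-contained
# example of how AdversarialLoss might be wired into a training step. It
# assumes (a) an args namespace exposing the attributes read above
# (divergence, noise_var, adv_nloop, adv_step_size, project_norm_type,
# noise_gamma) and (b) a BERT-like classifier that exposes
# model.bert.embeddings.word_embeddings and returns raw logits from forward().
# ToyBertClassifier is a hypothetical stand-in used only to keep the sketch
# runnable; the hyper-parameter values are illustrative.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    import torch.nn as nn

    class ToyBertClassifier(nn.Module):
        """Hypothetical stand-in for the project's BERT-based classifier."""

        def __init__(self, vocab_size=100, hidden=32, num_labels=3):
            super().__init__()
            self.bert = nn.Module()
            self.bert.embeddings = nn.Module()
            self.bert.embeddings.word_embeddings = nn.Embedding(vocab_size, hidden)
            self.classifier = nn.Linear(hidden, num_labels)

        def forward(self, input_ids=None, attention_mask=None,
                    token_type_ids=None, inputs_embeds=None):
            if inputs_embeds is None:
                inputs_embeds = self.bert.embeddings.word_embeddings(input_ids)
            pooled = inputs_embeds.mean(dim=1)  # crude pooling, sketch only
            return self.classifier(pooled)

    args = Namespace(divergence="js", noise_var=1e-5, adv_nloop=1,
                     adv_step_size=1e-3, project_norm_type="inf",
                     noise_gamma=1e-6)

    model = ToyBertClassifier()
    batch = {
        "input_ids": torch.randint(0, 100, (4, 16)),
        "attention_mask": torch.ones(4, 16, dtype=torch.long),
        "token_type_ids": torch.zeros(4, 16, dtype=torch.long),
    }

    # Clean forward pass, then the adversarial regularization term; it is
    # typically added to the task loss with a weight, e.g.
    # total_loss = task_loss + adv_alpha * adv_loss.
    logits = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"],
                   token_type_ids=batch["token_type_ids"])
    adv_loss_fn = AdversarialLoss(args)
    adv_loss = adv_loss_fn(model, logits, batch)
    print(adv_loss.item())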