炼丹技巧-对抗训练
FGM
class FGM:
    """Fast Gradient Method (FGM) adversarial training for NLP.

    Perturbs the word-embedding weights along the gradient direction,
    normalized by its L2 norm: r_adv = eps * grad / ||grad||.
    (The original docstring wrongly quoted FGSM's rule eps * sign(grad);
    this class implements FGM.)
    Adapted from: https://blog.csdn.net/qq_40176087/article/details/121512229
    """

    def __init__(self, model, eps=1.) -> None:
        self.model = model  # model whose embedding weights get perturbed
        self.eps = eps      # perturbation magnitude
        self.backup = {}    # parameter name -> clean weight tensor

    def attack(self, embedding_name="word_embeddings"):
        """Add an adversarial perturbation to every trainable parameter
        whose name contains `embedding_name` (only the word embedding is
        attacked)."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and embedding_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                # Skip zero or NaN gradient norms to avoid division by zero.
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.eps * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, embedding_name="word_embeddings"):
        """Put back the clean weights saved by attack() and clear the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and embedding_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
FGSM
class FGSM:
    """Fast Gradient Sign Method (FGSM) adversarial training.

    Update rule: r_adv = eps * sign(grad), applied only to the
    word-embedding weights.
    Adapted from: https://blog.csdn.net/qq_40176087/article/details/121512229
    """

    def __init__(self, model, eps=1.) -> None:
        self.model = model
        self.eps = eps
        self.backup = {}  # saved clean weights, keyed by parameter name

    def attack(self, embedding_name="word_embeddings"):
        """Perturb each trainable parameter whose name matches the embedding."""
        targets = (
            (name, p)
            for name, p in self.model.named_parameters()
            if p.requires_grad and embedding_name in name
        )
        for name, p in targets:
            self.backup[name] = p.data.clone()
            grad_norm = torch.norm(p.grad)
            # Guard against zero / NaN gradients before stepping.
            if grad_norm and not torch.isnan(grad_norm):
                p.data.add_(self.eps * torch.sign(p.grad))

    def restore(self, embedding_name="word_embeddings"):
        """Restore the clean weights saved by attack() and drop the backup."""
        for name, p in self.model.named_parameters():
            if p.requires_grad and embedding_name in name:
                assert name in self.backup
                p.data = self.backup[name]
        self.backup = {}
PGD
class PGD:
    """Projected Gradient Descent (PGD) adversarial training.

    Takes several small gradient-ascent steps on the word-embedding
    weights (step size `alpha`), projecting the accumulated perturbation
    back onto the L2 ball of radius `eps` after each step.
    """

    def __init__(self, model, eps=1., alpha=0.3) -> None:
        self.model = model
        self.eps = eps            # radius of the perturbation ball
        self.alpha = alpha        # per-step perturbation size
        self.emb_backup = {}      # clean embedding weights
        self.grad_backup = {}     # gradients saved before the attack steps

    def attack(self, embedding_name="word_embeddings", is_first_attack=False):
        """Take one ascent step on the embedding weights.

        Pass is_first_attack=True on the first of the K steps so the clean
        weights are saved for restore()/project().
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and embedding_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    # BUG FIX: step along the normalized gradient
                    # (original code used `param.data / norm`, which
                    # perturbs along the weights, not the gradient).
                    r_at = self.alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data)

    def restore(self, embedding_name="word_embeddings"):
        """Restore the clean embedding weights and clear the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and embedding_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data):
        """Clip the accumulated perturbation to the L2 ball of radius eps."""
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > self.eps:
            r = self.eps * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        """Save all gradients before the inner attack loop clobbers them."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        """Restore the gradients saved by backup_grad()."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]
实验效果
在微博情绪识别任务[2]中,有微弱提升。
| method | performance |
| --- | --- |
| without adv | 96.99 |
| with fgm | 97.12 |
| $\Delta$ | 0.13 |
| with fgsm | 训练失败,验证集准确率下降到52,考虑梯度问题 |
参考文献
[2] 疫情微博情绪识别挑战赛