炼丹技巧-对抗训练

FGM

class FGM:
    """Fast Gradient Method (FGM) adversarial perturbation of embedding weights.

    Adapted from: https://blog.csdn.net/qq_40176087/article/details/121512229

    FGM update: ``r_adv = eps * grad / ||grad||_2`` (L2-normalised gradient),
    added in place to every trainable parameter whose name contains
    ``embedding_name``.
    """

    def __init__(self, model, eps=1.) -> None:
        self.model = model
        # Step size of the L2-normalised perturbation.
        self.eps = eps
        # Parameter name -> clone of its data taken right before perturbing.
        self.backup = {}

    # only attack word embedding
    def attack(self, embedding_name="word_embeddings"):
        """Perturb matching parameters along their normalised gradient.

        Each matching parameter is backed up first so ``restore`` can undo the
        change. Parameters with no gradient, or whose gradient norm is zero or
        non-finite, are backed up but left unchanged.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and embedding_name in name):
                continue
            self.backup[name] = param.data.clone()
            if param.grad is None:
                # No backward pass reached this parameter; nothing to attack.
                continue
            norm = torch.norm(param.grad)
            # A zero, NaN or inf norm would turn the division into NaNs.
            if norm != 0 and torch.isfinite(norm):
                param.data.add_(self.eps * param.grad / norm)

    def restore(self, embedding_name="word_embeddings"):
        """Write the values saved by ``attack`` back and clear the backup.

        Every matching parameter must have been backed up by a prior ``attack``
        call with the same ``embedding_name`` (checked by the assert).
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and embedding_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

FGSM

class FGSM:
    """Fast Gradient Sign Method (FGSM) adversarial perturbation of embeddings.

    Adapted from: https://blog.csdn.net/qq_40176087/article/details/121512229

    FGSM update: ``r_adv = eps * sign(grad)``, added in place to every
    trainable parameter whose name contains ``embedding_name``. Because of the
    sign, every element with a nonzero gradient moves by exactly ``eps``, so
    ``eps`` has to be chosen relative to the scale of the embedding values.
    """

    def __init__(self, model, eps=1.) -> None:
        self.model = model
        # Per-element perturbation magnitude.
        self.eps = eps
        # Parameter name -> clone of its data taken right before perturbing.
        self.backup = {}

    # only attack word embedding
    def attack(self, embedding_name="word_embeddings"):
        """Perturb matching parameters by ``eps`` times the sign of their gradient.

        Each matching parameter is backed up first so ``restore`` can undo the
        change. Parameters with no gradient, or whose gradient norm is zero or
        NaN, are backed up but left unchanged.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and embedding_name in name):
                continue
            self.backup[name] = param.data.clone()
            if param.grad is None:
                # No backward pass reached this parameter; nothing to attack.
                continue
            norm = torch.norm(param.grad)
            if norm != 0 and not torch.isnan(norm):
                param.data.add_(self.eps * torch.sign(param.grad))

    def restore(self, embedding_name="word_embeddings"):
        """Write the values saved by ``attack`` back and clear the backup.

        Every matching parameter must have been backed up by a prior ``attack``
        call with the same ``embedding_name`` (checked by the assert).
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and embedding_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

PGD

class PGD:
    """Projected Gradient Descent (PGD) adversarial perturbation of embeddings.

    Each ``attack`` step adds ``alpha * grad / ||grad||_2`` to every trainable
    parameter whose name contains ``embedding_name``. It then projects the
    accumulated perturbation back into the L2 ball of radius ``eps`` around the
    weights backed up on the first attack step.
    """

    def __init__(self, model, eps=1., alpha=0.3) -> None:
        self.model = model
        # Radius of the L2 ball the total perturbation is projected into.
        self.eps = eps
        # Step size of each normalised-gradient attack step.
        self.alpha = alpha
        # Parameter name -> weights saved on the first attack step.
        self.emb_backup = {}
        # Parameter name -> gradient saved by backup_grad().
        self.grad_backup = {}

    def attack(self, embedding_name="word_embeddings", is_first_attack=False):
        """Take one PGD step on the matching parameters.

        With ``is_first_attack=True`` the current weights are backed up first.
        Later steps project back around that backup, so the first call must
        pass ``True``. Parameters with no gradient, or whose gradient norm is
        zero or non-finite, are not moved.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and embedding_name in name):
                continue
            if is_first_attack:
                self.emb_backup[name] = param.data.clone()
            if param.grad is None:
                # No backward pass reached this parameter; nothing to attack.
                continue
            norm = torch.norm(param.grad)
            # A zero, NaN or inf norm would turn the division into NaNs.
            if norm != 0 and torch.isfinite(norm):
                # Step along the normalised gradient direction.
                param.data.add_(self.alpha * param.grad / norm)
                param.data = self.project(name, param.data)

    def restore(self, embedding_name="word_embeddings"):
        """Write the weights saved on the first attack back and clear the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and embedding_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data):
        """Clip ``param_data - backup`` to L2 norm ``eps`` and return backup + clipped."""
        r = param_data - self.emb_backup[param_name]
        r_norm = torch.norm(r)
        if r_norm > self.eps:
            r = self.eps * r / r_norm
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        """Save a copy of every trainable parameter's current gradient.

        Any previous backup is discarded, so stale entries from an earlier step
        cannot be restored later.
        """
        self.grad_backup = {
            name: param.grad.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad and param.grad is not None
        }

    def restore_grad(self):
        """Put the gradients saved by ``backup_grad`` back onto their parameters.

        This also restores parameters whose ``.grad`` is currently ``None``,
        e.g. after ``zero_grad(set_to_none=True)``. Parameters that were not
        backed up are left untouched.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.grad_backup:
                param.grad = self.grad_backup[name]

实验效果

在微博情绪识别任务[2]中，使用 FGM 对抗训练后效果有微弱提升。

| method | performance |
| --- | --- |
| without adv | 96.99 |
| with fgm | 97.12 |
| $\Delta$ | +0.13 |
| with fgsm | 训练失败，验证集准确率下降到 52，推测是梯度问题 |

参考文献

[1] 对抗训练fgm、fgsm和pgd原理和源码分析

[2] 疫情微博情绪识别挑战赛

Chuanbo Zhu
Chuanbo Zhu
PhD Candidate of Computer Science and Technology

My research interests include multimodal intelligence, sentiment analysis, emotion recognition and sarcasm detection

Related