炼丹技巧-EMA(指数移动平均)

原理

来自于参考文献[1]

基本假设:模型权重在最后的$n$步内,会在实际的最优点处抖动,因此取最后$n$步的平均,能使模型更加鲁棒。

权重参数更新公式:$v_{t} = \beta * v_{t-1} + (1 - \beta) * v_{t}$

代码

来自于参考文献[1],注意只在验证和测试时使用ema的平均参数,训练时不用。

class EMA:
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()
    
    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]
    
    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

实验效果

在微博情绪识别任务[2]中,有微弱提升。

method performance
without ema 96.83
with ema 96.99
$\Delta$ 0.16

参考文献

[1] 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现

[2] 疫情微博情绪识别挑战赛

Chuanbo Zhu
Chuanbo Zhu
PhD Candidate of Computer Science and Technology

My research interests include multimodal intelligence, sentiment analysis, emotion recognition and sarcasm detection

Related