炼丹技巧-EMA(指数移动平均)
原理
来自于参考文献[1]
基本假设:模型权重在最后的$n$步内,会在实际的最优点处抖动,因此取最后$n$步的平均,能使模型更加鲁棒。
权重参数更新公式:$v_{t} = \beta * v_{t-1} + (1 - \beta) * v_{t}$
代码
来自于参考文献[1],注意只在验证和测试时使用ema的平均参数,训练时不用。
class EMA:
def __init__(self, model, decay):
self.model = model
self.decay = decay
self.shadow = {}
self.backup = {}
def register(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def apply_shadow(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
实验效果
在微博情绪识别任务[2]中,有微弱提升。
method | performance |
---|---|
without ema | 96.83 |
with ema | 96.99 |
$\Delta$ | 0.16 |
参考文献
[1] 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现
[2] 疫情微博情绪识别挑战赛