# -*- coding: utf-8 -*-
"""LAMB optimizer (layer-wise adaptive moments for large-batch training)."""
import torch
from torch.optim.optimizer import Optimizer
class Lamb(Optimizer):
    """Implements the Lamb optimizer from `"Large batch optimization for deep learning:
    training BERT in 76 minutes" <https://arxiv.org/pdf/1904.00962v3.pdf>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate
        betas (Tuple[float, float], optional): beta coefficients used for running
            averages (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical
            stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_clip (tuple, optional): the lower and upper bounds for the weight norm
            in local LR of LARS
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, scale_clip=None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(Lamb, self).__init__(params, defaults)
        # LARS arguments: bounds used to clamp the parameter norm when
        # computing the layer-wise trust ratio
        self.scale_clip = scale_clip
        if self.scale_clip is None:
            self.scale_clip = (0, 10)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    # Fixed: message previously referred to "RAdam"
                    raise RuntimeError('Lamb does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # (modern keyword signatures replace the deprecated
                # add_(Number, Tensor) / addcmul_(Number, Tensor, Tensor) forms)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Adam-style update direction: m / (sqrt(v) + eps)
                denom = exp_avg_sq.sqrt().add_(group['eps'])
                update = exp_avg / denom

                # Decoupled weight decay added to the update direction
                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                # LARS trust ratio: ||w|| (clamped) over ||update||
                p_norm = p.data.pow(2).sum().sqrt()
                update_norm = update.pow(2).sum().sqrt()
                phi_p = p_norm.clamp(*self.scale_clip)
                # Compute the local LR; fall back to 1 when either norm vanishes
                if phi_p == 0 or update_norm == 0:
                    local_lr = 1
                else:
                    local_lr = phi_p / update_norm

                state['local_lr'] = local_lr

                p.data.add_(update, alpha=-group['lr'] * local_lr)

        return loss