# -*- coding: utf-8 -*-

'''Rectified Adam optimizer'''

import math

import torch
from torch.optim.optimizer import Optimizer
class RaLars(Optimizer):
    """Implements the RAdam optimizer from https://arxiv.org/pdf/1908.03265.pdf
    with optional Layer-wise adaptive Scaling from https://arxiv.org/pdf/1708.03888.pdf

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate
        betas (Tuple[float, float], optional): coefficients used for running averages (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        force_adaptive_momentum (float, optional): use adaptive momentum if variance is not tractable
            (default: False)
        scale_clip (float, optional): the maximal upper bound for the scale factor of LARS
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 force_adaptive_momentum=False, scale_clip=None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(RaLars, self).__init__(params, defaults)
        # RAdam tweaks
        self.force_adaptive_momentum = force_adaptive_momentum
        # LARS arguments: clamp range applied to the parameter norm before
        # computing the layer-wise trust ratio
        self.scale_clip = scale_clip
        if self.scale_clip is None:
            self.scale_clip = (0, 10)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            the closure's loss if a closure was given, else ``None``
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # Get group-shared variables
            beta1, beta2 = group['betas']
            sma_inf = group.get('sma_inf')
            # Compute max length of the approximated SMA on first step
            if not isinstance(sma_inf, float):
                group['sma_inf'] = 2 / (1 - beta2) - 1
                sma_inf = group.get('sma_inf')

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                # Decay the first and second moment running average coefficient.
                # NOTE: the keyword forms (alpha=/value=) are required — the old
                # positional-scalar overloads (add_(Number, Tensor), etc.) were
                # removed from PyTorch and raise TypeError on current versions.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias corrections
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Compute length of SMA (rho_t in the RAdam paper;
                # (1 - bias_correction2) == beta2 ** step)
                sma_t = sma_inf - 2 * state['step'] * (1 - bias_correction2) / bias_correction2

                update = torch.zeros_like(p.data)
                if sma_t > 4:
                    # Variance of the adaptive term is tractable:
                    # variance rectification term
                    r_t = math.sqrt((sma_t - 4) * (sma_t - 2) * sma_inf /
                                    ((sma_inf - 4) * (sma_inf - 2) * sma_t))
                    # Adaptive momentum
                    update.addcdiv_(exp_avg / bias_correction1,
                                    (exp_avg_sq / bias_correction2).sqrt().add_(group['eps']),
                                    value=r_t)
                elif self.force_adaptive_momentum:
                    # Adaptive momentum without variance rectification (plain Adam)
                    update.addcdiv_(exp_avg / bias_correction1,
                                    (exp_avg_sq / bias_correction2).sqrt().add_(group['eps']))
                else:
                    # Unadapted momentum (bias-corrected first moment only)
                    update.add_(exp_avg / bias_correction1)

                # Weight decay (decoupled into the update, LARS-style)
                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                # LARS: layer-wise trust ratio between parameter and update norms
                p_norm = p.data.pow(2).sum().sqrt()
                update_norm = update.pow(2).sum().sqrt()
                phi_p = p_norm.clamp(*self.scale_clip)
                # Compute the local LR; fall back to 1 when either norm vanishes
                if phi_p == 0 or update_norm == 0:
                    local_lr = 1
                else:
                    local_lr = phi_p / update_norm

                state['local_lr'] = local_lr

                # float() because alpha must be a Python number, while local_lr
                # may be a 0-dim tensor
                p.data.add_(update, alpha=-group['lr'] * float(local_lr))

        return loss