# Copyright (C) 2019-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import math
from typing import Callable, Iterable, Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer

class RaLars(Optimizer):
    """Implements the RAdam optimizer from `"On the Variance of the Adaptive Learning Rate and Beyond"
    <https://arxiv.org/pdf/1908.03265.pdf>`_ with optional Layer-wise Adaptive Rate Scaling (LARS) from
    `"Large Batch Training of Convolutional Networks" <https://arxiv.org/pdf/1708.03888.pdf>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for running averages (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        force_adaptive_momentum (bool, optional): use adaptive momentum even when the variance is not tractable
            (default: False)
        scale_clip (Tuple[float, float], optional): lower and upper bounds used to clamp the parameter norm when
            computing the LARS scale factor (default: (0, 10))
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        force_adaptive_momentum: bool = False,
        scale_clip: Optional[Tuple[float, float]] = None,
    ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
        super(RaLars, self).__init__(params, defaults)
        # RAdam tweaks
        self.force_adaptive_momentum = force_adaptive_momentum
        # LARS arguments
        self.scale_clip = scale_clip
        if self.scale_clip is None:
            self.scale_clip = (0, 10)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
"""loss=NoneifclosureisnotNone:withtorch.enable_grad():loss=closure()forgroupinself.param_groups:# Get group-shared variablesbeta1,beta2=group["betas"]# Compute max length of SMA on first stepifnotisinstance(group.get("sma_inf"),float):group["sma_inf"]=2/(1-beta2)-1sma_inf=group["sma_inf"]forpingroup["params"]:ifp.gradisNone:continuegrad=p.grad.dataifgrad.is_sparse:raiseRuntimeError(f"{self.__class__.__name__} does not support sparse gradients")state=self.state[p]# State initializationiflen(state)==0:state["step"]=0# Exponential moving average of gradient valuesstate["exp_avg"]=torch.zeros_like(p.data)# Exponential moving average of squared gradient valuesstate["exp_avg_sq"]=torch.zeros_like(p.data)exp_avg,exp_avg_sq=state["exp_avg"],state["exp_avg_sq"]state["step"]+=1# Decay the first and second moment running average coefficientexp_avg.mul_(beta1).add_(grad,alpha=1-beta1)exp_avg_sq.mul_(beta2).addcmul_(grad,grad,value=1-beta2)# Bias correctionbias_correction1=1-beta1**state["step"]bias_correction2=1-beta2**state["step"]# Compute length of SMAsma_t=sma_inf-2*state["step"]*(1-bias_correction2)/bias_correction2update=torch.zeros_like(p.data)ifsma_t>4:# Variance rectification termr_t=math.sqrt((sma_t-4)*(sma_t-2)*sma_inf/((sma_inf-4)*(sma_inf-2)*sma_t))# Adaptive momentumupdate.addcdiv_(exp_avg/bias_correction1,(exp_avg_sq/bias_correction2).sqrt().add_(group["eps"]),value=r_t)else:ifself.force_adaptive_momentum:# Adaptive momentum without variance rectification (Adam)update.addcdiv_(exp_avg/bias_correction1,(exp_avg_sq/bias_correction2).sqrt().add_(group["eps"]))else:# Unadapted momentumupdate.add_(exp_avg/bias_correction1)# Weight decayifgroup["weight_decay"]!=0:update.add_(p.data,alpha=group["weight_decay"])# LARSp_norm=p.data.pow(2).sum().sqrt()update_norm=update.pow(2).sum().sqrt()phi_p=p_norm.clamp(*self.scale_clip)# Compute the local LRlocal_lr=1ifphi_p==0orupdate_norm==0elsephi_p/update_normstate["local_lr"]=local_lrp.data.add_(update,alpha=-group["lr"]*local_lr)returnloss