# Copyright (C) 2019-2022, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Callable, Iterable, Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer
class Lamb(Optimizer):
    """Implements the Lamb optimizer from `"Large batch optimization for deep learning: training BERT in 76 minutes"
    <https://arxiv.org/pdf/1904.00962v3.pdf>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate
        betas (Tuple[float, float], optional): beta coefficients used for running averages (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_clip (tuple, optional): the lower and upper bounds for the weight norm in local LR of LARS
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],  # type: ignore[name-defined]
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        scale_clip: Optional[Tuple[float, float]] = None,
    ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(Lamb, self).__init__(params, defaults)
        # LARS arguments
        self.scale_clip = scale_clip
        if self.scale_clip is None:
            self.scale_clip = (0.0, 10.0)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Gradient term correction
                update = torch.zeros_like(p.data)
                denom = exp_avg_sq.sqrt().add_(group["eps"])
                update.addcdiv_(exp_avg, denom)

                # Weight decay
                if group["weight_decay"] != 0:
                    update.add_(p.data, alpha=group["weight_decay"])

                # LARS
                p_norm = p.data.pow(2).sum().sqrt()
                update_norm = update.pow(2).sum().sqrt()
                phi_p = p_norm.clamp(*self.scale_clip)
                # Compute the local LR
                if phi_p == 0 or update_norm == 0:
                    local_lr = 1
                else:
                    local_lr = phi_p / update_norm

                state["local_lr"] = local_lr

                p.data.add_(update, alpha=-group["lr"] * local_lr)

        return loss
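
# Minimal usage sketch (illustrative only, not part of the library): `Lamb` is the class
# defined above; the toy model, batch, and hyperparameter values below are assumptions
# chosen purely to show how the optimizer plugs into a standard PyTorch training step.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)  # hypothetical toy model
    optimizer = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)
    criterion = torch.nn.CrossEntropyLoss()

    x = torch.randn(4, 10)  # hypothetical input batch
    target = torch.randint(0, 2, (4,))  # hypothetical labels

    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()  # applies the Adam-style update scaled by the layer-wise local LR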