# Copyright (C) 2019-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Callable, Dict, Iterable, Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer

__all__ = ["LARS"]


class LARS(Optimizer):
    r"""Implements the LARS optimizer from `"Large batch training of convolutional networks"
    <https://arxiv.org/pdf/1708.03888.pdf>`_.

    The estimation of global and local learning rates is described as follows, :math:`\forall t \geq 1`:

    .. math::
        \alpha_t \leftarrow \alpha (1 - t / T)^2 \\
        \gamma_t \leftarrow \frac{\lVert \theta_t \rVert}{\lVert g_t \rVert + \lambda \lVert \theta_t \rVert}

    where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value),
    :math:`g_t` is the gradient of :math:`\theta_t`,
    :math:`T` is the total number of steps,
    :math:`\alpha` is the learning rate,
    and :math:`\lambda \geq 0` is the weight decay.

    Then we estimate the momentum using:

    .. math::
        v_t \leftarrow m v_{t-1} + \alpha_t \gamma_t (g_t + \lambda \theta_t)

    where :math:`m` is the momentum and :math:`v_0 = 0`.

    And finally the update step is performed using the following rule:

    .. math::
        \theta_t \leftarrow \theta_{t-1} - v_t

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        scale_clip (tuple, optional): the lower and upper bounds for the weight norm in local LR of LARS
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
        scale_clip: Optional[Tuple[float, float]] = None,
    ) -> None:
        if not isinstance(lr, float) or lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
            "nesterov": nesterov,
        }
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)
        # LARS arguments
        self.scale_clip = scale_clip
        if self.scale_clip is None:
            self.scale_clip = (0.0, 10.0)

    def __setstate__(self, state: Dict[str, torch.Tensor]) -> None:
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
"""loss=NoneifclosureisnotNone:withtorch.enable_grad():loss=closure()forgroupinself.param_groups:weight_decay=group["weight_decay"]momentum=group["momentum"]dampening=group["dampening"]nesterov=group["nesterov"]forpingroup["params"]:ifp.gradisNone:continued_p=p.grad.data# LARSp_norm=torch.norm(p.data)denom=torch.norm(d_p)ifweight_decay!=0:d_p.add_(p.data,alpha=weight_decay)denom.add_(p_norm,alpha=weight_decay)# Compute the local LRlocal_lr=1ifp_norm==0ordenom==0elsep_norm/denomifmomentum==0:p.data.add_(d_p,alpha=-group["lr"]*local_lr)else:param_state=self.state[p]if"momentum_buffer"notinparam_state:momentum_buffer=param_state["momentum_buffer"]=torch.clone(d_p).detach()else:momentum_buffer=param_state["momentum_buffer"]momentum_buffer.mul_(momentum).add_(d_p,alpha=1-dampening)d_p=d_p.add(momentum_buffer,alpha=momentum)ifnesterovelsemomentum_bufferp.data.add_(d_p,alpha=-group["lr"]*local_lr)self.state[p]["momentum_buffer"]=momentum_bufferreturnloss