# Copyright (C) 2019-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Callable, Dict, Iterable, Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer
class Lars(Optimizer):
    r"""Implements the LARS optimizer from `"Large batch training of convolutional networks"
    <https://arxiv.org/pdf/1708.03888.pdf>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        scale_clip (tuple, optional): the lower and upper bounds for the weight norm in local LR of LARS
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],  # type: ignore[name-defined]
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
        scale_clip: Optional[Tuple[float, float]] = None,
    ) -> None:
        if not isinstance(lr, float) or lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(Lars, self).__init__(params, defaults)

        # LARS arguments
        self.scale_clip = scale_clip
        if self.scale_clip is None:
            self.scale_clip = (0.0, 10.0)

    def __setstate__(self, state: Dict[str, torch.Tensor]) -> None:
        super(Lars, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # LARS
                p_norm = p.data.pow(2).sum().sqrt()
                update_norm = d_p.pow(2).sum().sqrt()

                # Compute the local LR
                if p_norm == 0 or update_norm == 0:
                    local_lr = 1
                else:
                    local_lr = p_norm / update_norm

                p.data.add_(d_p, alpha=-group["lr"] * local_lr)

        return loss
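
# Minimal usage sketch (illustrative only, not part of the library): instantiates Lars on a
# throwaway model and runs a single optimization step. The linear model, random batch, and
# hyperparameter values below are placeholder assumptions chosen purely for demonstration.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)
    optimizer = Lars(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)

    # Dummy batch: 8 samples, 10 features, binary targets
    x = torch.randn(8, 10)
    target = torch.randint(0, 2, (8,))
    criterion = torch.nn.CrossEntropyLoss()

    # Standard training step: zero grads, forward, backward, then the LARS update,
    # which rescales each parameter's step by the layer-wise ratio ||w|| / ||update||
    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()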