class Lars(Optimizer):
    r"""Implements the LARS optimizer from `"Large batch training of convolutional
    networks" <https://arxiv.org/pdf/1708.03888.pdf>`_.

    LARS is SGD (with optional momentum/Nesterov/weight decay) where each
    parameter's step is additionally scaled by a *local* learning rate
    ``||w|| / ||update||``, which normalizes the update magnitude per layer.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        scale_clip (tuple, optional): the lower and upper bounds for the weight norm
            in local LR of LARS

    Raises:
        ValueError: if ``lr`` is not a non-negative real number, if ``momentum``
            or ``weight_decay`` is negative, or if Nesterov momentum is requested
            without momentum or with non-zero dampening.
    """

    def __init__(self, params, lr=1e-3, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, scale_clip=None):
        # Accept any non-negative real learning rate; the previous strict
        # `isinstance(lr, float)` check wrongly rejected integer values such as lr=1.
        if not isinstance(lr, (int, float)) or lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(Lars, self).__init__(params, defaults)

        # LARS arguments
        # NOTE(review): scale_clip is stored but not applied in step() below —
        # presumably a hook for clamping the local LR; confirm intended usage.
        self.scale_clip = scale_clip
        if self.scale_clip is None:
            self.scale_clip = (0, 10)

    def __setstate__(self, state):
        """Restore pickled state, defaulting 'nesterov' for groups saved by older versions."""
        super(Lars, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            the loss returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # NOTE: this in-place add folds weight decay directly into p.grad
                # (same behavior as torch's classic SGD implementation).
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: initialize the buffer with the raw gradient.
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # LARS: local learning rate = ||w|| / ||update||, falling back
                # to 1 when either norm is zero (e.g. freshly-zeroed weights or
                # an all-zero gradient) to avoid division by zero.
                p_norm = p.data.pow(2).sum().sqrt()
                update_norm = d_p.pow(2).sum().sqrt()

                if p_norm == 0 or update_norm == 0:
                    local_lr = 1
                else:
                    local_lr = p_norm / update_norm

                p.data.add_(d_p, alpha=-group['lr'] * local_lr)

        return loss