# Copyright (C) 2019-2024, François-Guillaume Fernandez.# This program is licensed under the Apache License 2.0.# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.importmathfromtypingimportCallable,List,OptionalimporttorchfromtorchimportTensorfromtorch.optimimportAdam__all__=["AdaBelief","adabelief"]
class AdaBelief(Adam):
    r"""Implements the AdaBelief optimizer from `"AdaBelief Optimizer: Adapting Stepsizes by the
    Belief in Observed Gradients" <https://arxiv.org/pdf/2010.07468.pdf>`_.

    The running moment estimates are updated as follows, :math:`\forall t \geq 1`:

    .. math::
        m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
        s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon

    where :math:`g_t` is the gradient of :math:`\theta_t`,
    :math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
    :math:`m_0 = 0,\ s_0 = 0`, :math:`\epsilon > 0`.

    Their biases are then corrected with:

    .. math::
        \hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\
        \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}

    and the parameter update is performed using the rule:

    .. math::
        \theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m_t}}{\sqrt{\hat{s_t}} + \epsilon}

    where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the
    initialization value), :math:`\alpha` is the learning rate, :math:`\epsilon > 0`.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate
        betas (Tuple[float, float], optional): coefficients used for running averages (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): whether to use the AMSGrad variant (default: False)
    """

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            # Re-enable autograd so the closure can recompute gradients
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            first_moments = []
            second_moments = []
            max_second_moments = []
            steps = []

            for p in group["params"]:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization the first time this parameter is seen
                if not state:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient deviations
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                first_moments.append(state["exp_avg"])
                second_moments.append(state["exp_avg_sq"])
                if group["amsgrad"]:
                    max_second_moments.append(state["max_exp_avg_sq"])

                # Advance the per-parameter step counter, then record it for the update
                state["step"] += 1
                steps.append(state["step"])

            beta1, beta2 = group["betas"]
            adabelief(
                params_with_grad,
                grads,
                first_moments,
                second_moments,
                max_second_moments,
                steps,
                group["amsgrad"],
                beta1,
                beta2,
                group["lr"],
                group["weight_decay"],
                group["eps"],
            )

        return loss
def adabelief(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[int],
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
) -> None:
    r"""Functional API that performs AdaBelief algorithm computation.

    See :class:`~holocron.optim.AdaBelief` for details.
    """
    param_iter = zip(params, grads, exp_avgs, exp_avg_sqs, state_steps)
    for idx, (param, grad, first_moment, second_moment, step) in enumerate(param_iter):
        correction1 = 1 - beta1**step
        correction2 = 1 - beta2**step

        if weight_decay != 0:
            # L2 penalty folded into the gradient (non-decoupled weight decay)
            grad = grad.add(param, alpha=weight_decay)

        # First moment: EMA of the gradient
        first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
        # Second moment: EMA of the squared deviation of the gradient from its EMA
        residual = grad - first_moment
        second_moment.mul_(beta2).addcmul_(residual, residual, value=1 - beta2)

        if amsgrad:
            # Track the running maximum of the second moment and normalize with it
            running_max = max_exp_avg_sqs[idx]
            torch.maximum(running_max, second_moment, out=running_max)
            denom = (running_max.sqrt() / math.sqrt(correction2)).add_(eps)
        else:
            denom = (second_moment.sqrt() / math.sqrt(correction2)).add_(eps)

        # theta <- theta - (lr / correction1) * m_t / denom
        param.addcdiv_(first_moment, denom, value=-(lr / correction1))