holocron.optim#

To use holocron.optim, you first construct an optimizer object that holds the current state and updates the parameters based on the computed gradients.
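
Since these optimizers follow the usual torch.optim.Optimizer interface, a typical optimization step would look like this (a minimal sketch; model and loss are placeholders):

>>> from holocron.optim import LAMB
>>> model = ...
>>> optimizer = LAMB(model.parameters(), lr=3e-4)
>>> optimizer.zero_grad()
>>> loss = ...
>>> loss.backward()
>>> optimizer.step()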

Optimizers#

Implementations of recent parameter optimizers for PyTorch modules.

class holocron.optim.LARS(params: Iterable[Parameter], lr: float = 0.001, momentum: float = 0.0, dampening: float = 0.0, weight_decay: float = 0.0, nesterov: bool = False, scale_clip: Tuple[float, float] | None = None)[source]#

Implements the LARS optimizer from “Large batch training of convolutional networks”.

The estimation of global and local learning rates is described as follows, \(\forall t \geq 1\):

\[\begin{split}\alpha_t \leftarrow \alpha (1 - t / T)^2 \\ \gamma_t \leftarrow \frac{\lVert \theta_t \rVert}{\lVert g_t \rVert + \lambda \lVert \theta_t \rVert}\end{split}\]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(g_t\) is the gradient of \(\theta_t\), \(T\) is the total number of steps, \(\alpha\) is the learning rate, and \(\lambda \geq 0\) is the weight decay.

Then we estimate the momentum using:

\[v_t \leftarrow m v_{t-1} + \alpha_t \gamma_t (g_t + \lambda \theta_t)\]

where \(m\) is the momentum and \(v_0 = 0\).

And finally the update step is performed using the following rule:

\[\theta_t \leftarrow \theta_{t-1} - v_t\]
Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • momentum (float, optional) – momentum factor (default: 0)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dampening (float, optional) – dampening for momentum (default: 0)

  • nesterov (bool, optional) – enables Nesterov momentum (default: False)

  • scale_clip (tuple, optional) – the lower and upper bounds for the weight norm in local LR of LARS
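
A minimal construction sketch following the signature above (the model is a placeholder and the hyper-parameter values are purely illustrative):

>>> from holocron.optim import LARS
>>> model = ...
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)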

class holocron.optim.LAMB(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, scale_clip: Tuple[float, float] | None = None)[source]#

Implements the LAMB optimizer from “Large batch optimization for deep learning: training BERT in 76 minutes”.

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\end{split}\]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]\) are the exponential average smoothing coefficients, and \(m_0 = 0,\ v_0 = 0\).

Then we correct their biases using:

\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t}\end{split}\]

And finally the update step is performed using the following rule:

\[\begin{split}r_t \leftarrow \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon} \\ \theta_t \leftarrow \theta_{t-1} - \alpha \phi(\lVert \theta_t \rVert) \frac{r_t + \lambda \theta_t}{\lVert r_t + \lambda \theta_t \rVert}\end{split}\]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\phi\) is a clipping function, \(\alpha\) is the learning rate, \(\lambda \geq 0\) is the weight decay, \(\epsilon > 0\).

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – beta coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • scale_clip (tuple, optional) – the lower and upper bounds for the weight norm in local LR of LARS
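
As the params argument also accepts dicts defining parameter groups, per-group options can be set at construction time (a sketch; the encoder/head split is illustrative and depends on your model):

>>> from holocron.optim import LAMB
>>> model = ...
>>> optimizer = LAMB([
...     {"params": model.encoder.parameters(), "weight_decay": 1e-2},
...     {"params": model.head.parameters(), "lr": 1e-3},
... ], lr=3e-4)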

class holocron.optim.RaLars(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, force_adaptive_momentum: bool = False, scale_clip: Tuple[float, float] | None = None)[source]#

Implements the RAdam optimizer from “On the Variance of the Adaptive Learning Rate and Beyond”, with optional layer-wise adaptive scaling from “Large Batch Training of Convolutional Networks”.

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • force_adaptive_momentum (bool, optional) – use adaptive momentum if variance is not tractable (default: False)

  • scale_clip (tuple, optional) – the lower and upper bounds for the scale factor of LARS
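
A minimal construction sketch following the signature above (model is a placeholder; the hyper-parameter values are illustrative):

>>> from holocron.optim import RaLars
>>> model = ...
>>> optimizer = RaLars(model.parameters(), lr=1e-3, weight_decay=1e-4)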

class holocron.optim.TAdam(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, dof: float | None = None)[source]#

Implements the TAdam optimizer from “TAdam: A Robust Stochastic Gradient Optimizer”.

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[\begin{split}w_t \leftarrow (\nu + d) \Big(\nu + \sum\limits_{j} \frac{(g_t^j - m_{t-1}^j)^2}{v_{t-1}^j + \epsilon} \Big)^{-1} \\ m_t \leftarrow \frac{W_{t-1}}{W_{t-1} + w_t} m_{t-1} + \frac{w_t}{W_{t-1} + w_t} g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\end{split}\]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]\) are the exponential average smoothing coefficients, \(m_0 = 0,\ v_0 = 0,\ W_0 = \frac{\beta_1}{1 - \beta_1}\); \(\nu\) is the degrees of freedom and \(d\) is the number of dimensions of the parameter gradient.

Then we correct their biases using:

\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t}\end{split}\]

And finally the update step is performed using the following rule:

\[\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon}\]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\epsilon > 0\).

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dof (float, optional) – degrees of freedom
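
A minimal construction sketch following the signature above (model is a placeholder; leaving dof unset falls back to its default):

>>> from holocron.optim import TAdam
>>> model = ...
>>> optimizer = TAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))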

class holocron.optim.AdaBelief(params: Iterable[Tensor] | Iterable[Dict[str, Any]], lr: float | Tensor = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, *, foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, fused: bool | None = None)[source]#

Implements the AdaBelief optimizer from “AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients”.

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\end{split}\]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]\) are the exponential average smoothing coefficients, \(m_0 = 0,\ s_0 = 0\), \(\epsilon > 0\).

Then we correct their biases using:

\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}\end{split}\]

And finally the update step is performed using the following rule:

\[\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m_t}}{\sqrt{\hat{s_t}} + \epsilon}\]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\epsilon > 0\).

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
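
A minimal construction sketch following the signature above (model is a placeholder; the hyper-parameter values are illustrative):

>>> from holocron.optim import AdaBelief
>>> model = ...
>>> optimizer = AdaBelief(model.parameters(), lr=1e-3, weight_decay=1e-2)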

class holocron.optim.AdamP(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, delta: float = 0.1)[source]#

Implements the AdamP optimizer from “AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights”.

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\end{split}\]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]\) are the exponential average smoothing coefficients, \(m_0 = 0,\ v_0 = 0\).

Then we correct their biases using:

\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t}\end{split}\]

And finally the update step is performed using the following rule:

\[\begin{split}p_t \leftarrow \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon} \\ q_t \leftarrow \begin{cases} \prod_{\theta_t}(p_t) & \text{if}\ cos(\theta_t, g_t) < \delta / \sqrt{dim(\theta)}\\ p_t & \text{otherwise}\\ \end{cases} \\ \theta_t \leftarrow \theta_{t-1} - \alpha q_t\end{split}\]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\prod_{\theta_t}(p_t)\) is the projection of \(p_t\) onto the tangent space of \(\theta_t\), \(cos(\theta_t, g_t)\) is the cosine similarity between \(\theta_t\) and \(g_t\), \(\alpha\) is the learning rate, \(\delta > 0\), \(\epsilon > 0\).

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)

  • delta (float, optional) – delta threshold for projection (default: 0.1)
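
A minimal construction sketch following the signature above (model is a placeholder; the hyper-parameter values are illustrative):

>>> from holocron.optim import AdamP
>>> model = ...
>>> optimizer = AdamP(model.parameters(), lr=1e-3, delta=0.1)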

class holocron.optim.Adan(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float, float] = (0.98, 0.92, 0.99), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False)[source]#

Implements the Adan optimizer from “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models”.

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1}) \\ n_t \leftarrow \beta_3 n_{t-1} + (1 - \beta_3) [g_t + \beta_2 (g_t - g_{t - 1})]^2\end{split}\]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2, \beta_3 \in [0, 1]^3\) are the exponential average smoothing coefficients, \(m_0 = g_0,\ v_0 = 0,\ n_0 = g_0^2\).

Then we correct their biases using:

\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t} \\ \hat{n_t} \leftarrow \frac{n_t}{1 - \beta_3^t}\end{split}\]

And finally the update step is performed using the following rule:

\[\begin{split}p_t \leftarrow \frac{\hat{m_t} + (1 - \beta_2) \hat{v_t}}{\sqrt{\hat{n_t} + \epsilon}} \\ \theta_t \leftarrow \frac{\theta_{t-1} - \alpha p_t}{1 + \lambda \alpha}\end{split}\]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\lambda \geq 0\) is the weight decay, \(\epsilon > 0\).

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float, float], optional) – coefficients used for running averages (default: (0.98, 0.92, 0.99))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
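
A minimal construction sketch following the signature above (model is a placeholder; the hyper-parameter values are illustrative):

>>> from holocron.optim import Adan
>>> model = ...
>>> optimizer = Adan(model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99), weight_decay=2e-2)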

Optimizer wrappers#

holocron.optim also implements optimizer wrappers.

A base optimizer should always be passed to the wrapper; e.g., you should write your code this way:

>>> optimizer = ...
>>> optimizer = wrapper(optimizer)
class holocron.optim.wrapper.Lookahead(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)[source]#

Implements the Lookahead optimizer wrapper from “Lookahead Optimizer: k steps forward, 1 step back”.

>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Lookahead
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Lookahead(opt)
Parameters:
  • base_optimizer (torch.optim.optimizer.Optimizer) – base parameter optimizer

  • sync_rate (float, optional) – rate of weight synchronization

  • sync_period (int, optional) – number of steps performed on fast weights before weight synchronization

class holocron.optim.wrapper.Scout(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)[source]#

Implements a new optimizer wrapper based on “Lookahead Optimizer: k steps forward, 1 step back”.

>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Scout
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Scout(opt)
Parameters:
  • base_optimizer (torch.optim.optimizer.Optimizer) – base parameter optimizer

  • sync_rate (float, optional) – rate of weight synchronization

  • sync_period (int, optional) – number of steps performed on fast weights before weight synchronization