holocron.optim
To use holocron.optim, you have to construct an optimizer object that will hold
the current state and update the parameters based on the computed gradients.
Optimizers
Implementations of recent parameter optimizers for PyTorch modules.
- class holocron.optim.LARS(params: Iterable[Parameter], lr: float = 0.001, momentum: float = 0.0, dampening: float = 0.0, weight_decay: float = 0.0, nesterov: bool = False, scale_clip: Tuple[float, float] | None = None)
Implements the LARS optimizer from “Large batch training of convolutional networks”.
The estimation of global and local learning rates is described as follows, \(\forall t \geq 1\):
\[\begin{split}\alpha_t \leftarrow \alpha (1 - t / T)^2 \\ \gamma_t \leftarrow \frac{\lVert \theta_t \rVert}{\lVert g_t \rVert + \lambda \lVert \theta_t \rVert}\end{split}\]where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(g_t\) is the gradient of \(\theta_t\), \(T\) is the total number of steps, \(\alpha\) is the learning rate, and \(\lambda \geq 0\) is the weight decay.
Then we estimate the momentum using:
\[v_t \leftarrow m v_{t-1} + \alpha_t \gamma_t (g_t + \lambda \theta_t)\]where \(m\) is the momentum and \(v_0 = 0\).
And finally the update step is performed using the following rule:
\[\theta_t \leftarrow \theta_{t-1} - v_t\]
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
momentum (float, optional) – momentum factor (default: 0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
dampening (float, optional) – dampening for momentum (default: 0)
nesterov (bool, optional) – enables Nesterov momentum (default: False)
scale_clip (tuple, optional) – the lower and upper bounds for the weight norm in local LR of LARS
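The LARS equations above can be traced by hand with a small pure-Python sketch. It is not the library's implementation: `lars_step`, its `step` / `total_steps` arguments, and the fallback local rate of 1.0 when the denominator is zero are assumptions made for illustration, and a single parameter tensor is represented as a flat list.

```python
import math


def _l2_norm(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in values))


def lars_step(theta, grad, velocity, lr, step, total_steps,
              momentum=0.0, weight_decay=0.0):
    """One LARS update for a single parameter tensor (hypothetical helper)."""
    # Global learning rate: alpha_t = alpha * (1 - t / T)^2
    global_lr = lr * (1 - step / total_steps) ** 2
    # Local learning rate: gamma_t = ||theta|| / (||g|| + lambda * ||theta||)
    theta_norm = _l2_norm(theta)
    denom = _l2_norm(grad) + weight_decay * theta_norm
    local_lr = theta_norm / denom if denom > 0 else 1.0  # assumed fallback
    # Momentum buffer: v_t = m * v_{t-1} + alpha_t * gamma_t * (g_t + lambda * theta_t)
    new_velocity = [
        momentum * v + global_lr * local_lr * (g + weight_decay * p)
        for v, g, p in zip(velocity, grad, theta)
    ]
    # Parameter update: theta_t = theta_{t-1} - v_t
    new_theta = [p - v for p, v in zip(theta, new_velocity)]
    return new_theta, new_velocity


theta, velocity = lars_step(
    [3.0, 4.0], [0.6, 0.8], [0.0, 0.0], lr=0.1, step=5, total_steps=10
)
```

With a weight norm of 5 and a gradient norm of 1, the local rate is 5, so the layer takes a step five times larger than the decayed global rate alone would give.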
- class holocron.optim.LAMB(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, scale_clip: Tuple[float, float] | None = None)
Implements the Lamb optimizer from “Large batch optimization for deep learning: training BERT in 76 minutes”.
The estimation of momentums is described as follows, \(\forall t \geq 1\):
\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\end{split}\]where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = 0,\ v_0 = 0\).
Then we correct their biases using:
\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t}\end{split}\]And finally the update step is performed using the following rule:
\[\begin{split}r_t \leftarrow \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon} \\ \theta_t \leftarrow \theta_{t-1} - \alpha \phi(\lVert \theta_t \rVert) \frac{r_t + \lambda \theta_t}{\lVert r_t + \lambda \theta_t \rVert}\end{split}\]where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\phi\) is a clipping function, \(\alpha\) is the learning rate, \(\lambda \geq 0\) is the weight decay, \(\epsilon > 0\).
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
betas (Tuple[float, float], optional) – beta coefficients used for running averages (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
scale_clip (tuple, optional) – the lower and upper bounds for the weight norm in local LR of LARS
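As a rough illustration of the LAMB rule, the sketch below performs one step on a flat list of floats. The helper name `lamb_step` is hypothetical, the clipping function \(\phi\) is assumed to clamp the weight norm to `scale_clip`, and the update is normalized by the norm of the full direction \(r_t + \lambda \theta_t\), as in the LAMB paper. It is not taken from the library's source.

```python
import math


def _l2_norm(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in values))


def lamb_step(theta, grad, m, v, step, lr=1e-3, betas=(0.9, 0.999),
              eps=1e-8, weight_decay=0.0, scale_clip=None):
    """One LAMB update for a single parameter tensor (hypothetical helper)."""
    beta1, beta2 = betas
    # Exponential moving averages of the gradient and its square
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    # Bias correction
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    v_hat = [vi / (1 - beta2 ** step) for vi in v]
    # Adam-style ratio r_t, then the full direction r_t + lambda * theta_t
    r = [mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    direction = [ri + weight_decay * p for ri, p in zip(r, theta)]
    # phi: clamp the weight norm to the scale_clip bounds (assumed behaviour)
    theta_norm = _l2_norm(theta)
    if scale_clip is not None:
        theta_norm = min(max(theta_norm, scale_clip[0]), scale_clip[1])
    direction_norm = _l2_norm(direction)
    trust_ratio = theta_norm / direction_norm if direction_norm > 0 else 1.0
    new_theta = [p - lr * trust_ratio * d for p, d in zip(theta, direction)]
    return new_theta, m, v


new_theta, m, v = lamb_step([3.0, 4.0], [2.0, -2.0], [0.0, 0.0], [0.0, 0.0],
                            step=1, lr=0.1)
```

Passing `scale_clip=(0.0, 1.0)` in the same call caps the weight norm at 1, which shrinks the step by a factor of 5 for this input.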
- class holocron.optim.RaLars(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, force_adaptive_momentum: bool = False, scale_clip: Tuple[float, float] | None = None)
Implements the RAdam optimizer from “On the variance of the Adaptive Learning Rate and Beyond” with optional Layer-wise adaptive Scaling from “Large Batch Training of Convolutional Networks”.
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
force_adaptive_momentum (bool, optional) – use adaptive momentum if variance is not tractable (default: False)
scale_clip (tuple, optional) – the lower and upper bounds for the weight norm in local LR of LARS
- class holocron.optim.TAdam(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, dof: float | None = None)
Implements the TAdam optimizer from “TAdam: A Robust Stochastic Gradient Optimizer”.
The estimation of momentums is described as follows, \(\forall t \geq 1\):
\[\begin{split}w_t \leftarrow (\nu + d) \Big(\nu + \sum\limits_{j} \frac{(g_t^j - m_{t-1}^j)^2}{v_{t-1} + \epsilon} \Big)^{-1} \\ m_t \leftarrow \frac{W_{t-1}}{W_{t-1} + w_t} m_{t-1} + \frac{w_t}{W_{t-1} + w_t} g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1})\end{split}\]where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = 0,\ v_0 = 0,\ W_0 = \frac{\beta_1}{1 - \beta_1}\); \(\nu\) is the degrees of freedom and \(d\) is the number of dimensions of the parameter gradient.
Then we correct their biases using:
\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t}\end{split}\]And finally the update step is performed using the following rule:
\[\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon}\]where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\epsilon > 0\).
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
dof (float, optional) – degrees of freedom
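The robust weighting in the first two TAdam equations is easier to see on numbers. The sketch below covers only \(w_t\) and \(m_t\), on flat lists of floats; `tadam_momentum` is a hypothetical helper, not the library's code. A gradient far from the running mean gets a small weight \(w_t\) and barely moves the momentum.

```python
def tadam_momentum(grad, m_prev, v_prev, w_sum_prev, dof, eps=1e-8):
    """Compute w_t and m_t for one parameter tensor (hypothetical helper).

    w_sum_prev plays the role of W_{t-1}; how W is updated afterwards is
    not covered by this sketch.
    """
    d = len(grad)  # number of dimensions of the parameter gradient
    # Variance-scaled squared distance between the gradient and the running mean
    distance = sum(
        (g - m) ** 2 / (v + eps) for g, m, v in zip(grad, m_prev, v_prev)
    )
    # w_t = (nu + d) / (nu + distance): small for outlier gradients
    w = (dof + d) / (dof + distance)
    # m_t is a weighted average of the previous momentum and the new gradient
    total = w_sum_prev + w
    m = [
        (w_sum_prev / total) * mi + (w / total) * g
        for mi, g in zip(m_prev, grad)
    ]
    return m, w


beta1 = 0.9
w_sum_0 = beta1 / (1 - beta1)  # W_0 as defined above
m_in, w_in = tadam_movement = tadam_momentum(
    [0.0, 0.0], [0.0, 0.0], [1.0, 1.0], w_sum_0, dof=2.0
)
m_out, w_out = tadam_momentum(
    [10.0, 10.0], [0.0, 0.0], [1.0, 1.0], w_sum_0, dof=2.0
)
```

For the outlier gradient `[10.0, 10.0]`, the momentum moves by far less than the `(1 - beta1) * 10 = 1.0` that a plain exponential average would give.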
- class holocron.optim.AdaBelief(params: Iterable[Tensor] | Iterable[Dict[str, Any]], lr: float | Tensor = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, *, foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, fused: bool | None = None)
Implements the AdaBelief optimizer from “AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients”.
The estimation of momentums is described as follows, \(\forall t \geq 1\):
\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\end{split}\]where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = 0,\ s_0 = 0\), \(\epsilon > 0\).
Then we correct their biases using:
\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}\end{split}\]And finally the update step is performed using the following rule:
\[\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m_t}}{\sqrt{\hat{s_t}} + \epsilon}\]where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\epsilon > 0\).
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
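A minimal sketch of one AdaBelief step, following the equations above on a flat list of floats. The helper `adabelief_step` is hypothetical; weight decay and the AMSGrad variant are left out, since the equations above do not describe them.

```python
import math


def adabelief_step(theta, grad, m, s, step, lr=1e-3,
                   betas=(0.9, 0.999), eps=1e-8):
    """One AdaBelief update for a single parameter tensor (hypothetical helper)."""
    beta1, beta2 = betas
    # First moment: exponential moving average of the gradient
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    # "Belief" term: moving average of the squared deviation (g_t - m_t)^2, plus eps
    s = [
        beta2 * si + (1 - beta2) * (g - mi) ** 2 + eps
        for si, g, mi in zip(s, grad, m)
    ]
    # Bias correction
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    s_hat = [si / (1 - beta2 ** step) for si in s]
    # Parameter update
    new_theta = [
        p - lr * mh / (math.sqrt(sh) + eps)
        for p, mh, sh in zip(theta, m_hat, s_hat)
    ]
    return new_theta, m, s


new_theta, m, s = adabelief_step([1.0], [1.0], [0.0], [0.0], step=1)
```

The denominator tracks how far the gradient is from its running mean rather than its raw magnitude, so a gradient that matches the prediction leads to a larger step.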
- class holocron.optim.AdamP(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, delta: float = 0.1)
Implements the AdamP optimizer from “AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights”.
The estimation of momentums is described as follows, \(\forall t \geq 1\):
\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\end{split}\]where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = g_0,\ v_0 = 0\).
Then we correct their biases using:
\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t}\end{split}\]And finally the update step is performed using the following rule:
\[\begin{split}p_t \leftarrow \frac{\hat{m_t}}{\sqrt{\hat{v_t} + \epsilon}} \\ q_t \leftarrow \begin{cases} \prod_{\theta_t}(p_t) & if\ cos(\theta_t, g_t) < \delta / \sqrt{dim(\theta)}\\ p_t & \text{otherwise}\\ \end{cases} \\ \theta_t \leftarrow \theta_{t-1} - \alpha q_t\end{split}\]where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\prod_{\theta_t}(p_t)\) is the projection of \(p_t\) onto the tangent space of \(\theta_t\), \(cos(\theta_t, g_t)\) is the cosine similarity between \(\theta_t\) and \(g_t\), \(\alpha\) is the learning rate, \(\delta > 0\), \(\epsilon > 0\).
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
delta (float, optional) – delta threshold for projection (default: 0.1)
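The projection step is the part of AdamP that differs from Adam. The sketch below shows it on flat lists of floats, with the projection written as the removal of the component of \(p_t\) along \(\theta_t\). The helper names are hypothetical and the test on the cosine follows the rule exactly as written above; the library may differ in details.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a, b):
    """Cosine similarity between two flat vectors (0.0 if either is zero)."""
    denom = math.sqrt(_dot(a, a)) * math.sqrt(_dot(b, b))
    return _dot(a, b) / denom if denom > 0 else 0.0


def project_onto_tangent(theta, p):
    """Remove from p its component along theta."""
    sq_norm = _dot(theta, theta)
    if sq_norm == 0:
        return list(p)
    coef = _dot(theta, p) / sq_norm
    return [pi - coef * ti for pi, ti in zip(p, theta)]


def adamp_direction(theta, grad, p, delta=0.1):
    """Select q_t: the projected update if the cosine test passes, else p_t."""
    threshold = delta / math.sqrt(len(theta))
    if cosine_similarity(theta, grad) < threshold:
        return project_onto_tangent(theta, p)
    return list(p)


# Gradient orthogonal to the weights: the radial part of p is removed
q_orthogonal = adamp_direction([1.0, 0.0], [0.0, 1.0], [1.0, 1.0])
# Gradient aligned with the weights: p is used unchanged
q_aligned = adamp_direction([1.0, 0.0], [1.0, 0.0], [1.0, 1.0])
```

When the gradient is close to orthogonal to the weights, the weights behave as scale-invariant, and dropping the radial component keeps the update from inflating the weight norm.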
- class holocron.optim.Adan(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float, float] = (0.98, 0.92, 0.99), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False)
Implements the Adan optimizer from “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models”.
The estimation of momentums is described as follows, \(\forall t \geq 1\):
\[\begin{split}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1}) \\ n_t \leftarrow \beta_3 n_{t-1} + (1 - \beta_3) [g_t + \beta_2 (g_t - g_{t - 1})]^2\end{split}\]where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2, \beta_3 \in [0, 1]^3\) are the exponential average smoothing coefficients, \(m_0 = g_0,\ v_0 = 0,\ n_0 = g_0^2\).
Then we correct their biases using:
\[\begin{split}\hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t} \\ \hat{n_t} \leftarrow \frac{n_t}{1 - \beta_3^t}\end{split}\]And finally the update step is performed using the following rule:
\[\begin{split}p_t \leftarrow \frac{\hat{m_t} + (1 - \beta_2) \hat{v_t}}{\sqrt{\hat{n_t} + \epsilon}} \\ \theta_t \leftarrow \frac{\theta_{t-1} - \alpha p_t}{1 + \lambda \alpha}\end{split}\]where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\lambda \geq 0\) is the weight decay, \(\epsilon > 0\).
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
betas (Tuple[float, float, float], optional) – coefficients used for running averages (default: (0.98, 0.92, 0.99))
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
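The Adan equations above can be written out as a single step on flat lists of floats. This is a sketch under assumptions: `adan_step` is a hypothetical helper, the caller supplies the state (`m`, `v`, `n`) and the previous gradient, and the AMSGrad variant is not covered.

```python
import math


def adan_step(theta, grad, prev_grad, m, v, n, step, lr=1e-3,
              betas=(0.98, 0.92, 0.99), eps=1e-8, weight_decay=0.0):
    """One Adan update for a single parameter tensor (hypothetical helper)."""
    beta1, beta2, beta3 = betas
    diff = [g - pg for g, pg in zip(grad, prev_grad)]
    # Moving averages of the gradient, the gradient difference,
    # and the squared Nesterov-style gradient estimate
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * d for vi, d in zip(v, diff)]
    n = [
        beta3 * ni + (1 - beta3) * (g + beta2 * d) ** 2
        for ni, g, d in zip(n, grad, diff)
    ]
    # Bias correction
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    v_hat = [vi / (1 - beta2 ** step) for vi in v]
    n_hat = [ni / (1 - beta3 ** step) for ni in n]
    # Update direction, then the step with weight decay in the denominator
    p = [
        (mh + (1 - beta2) * vh) / math.sqrt(nh + eps)
        for mh, vh, nh in zip(m_hat, v_hat, n_hat)
    ]
    new_theta = [
        (t - lr * pi) / (1 + weight_decay * lr) for t, pi in zip(theta, p)
    ]
    return new_theta, m, v, n


new_theta, m, v, n = adan_step(
    [1.0], [2.0], [2.0], [0.0], [0.0], [0.0], step=1, lr=0.1
)
```

With zero gradients the direction is zero and the update reduces to dividing the weights by \(1 + \lambda \alpha\), which is how the weight decay enters this rule.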
- class holocron.optim.AdEMAMix(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999), alpha: float = 5.0, eps: float = 1e-08, weight_decay: float = 0.0)
Implements the AdEMAMix optimizer from “The AdEMAMix Optimizer: Better, Faster, Older”.
The estimation of momentums is described as follows, \(\forall t \geq 1\):
\[\begin{split}m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\ m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\ s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) g_t^2\end{split}\]where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2, \beta_3 \in [0, 1]^3\) are the exponential average smoothing coefficients, \(m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0\).
Then we correct their biases using:
\[\begin{split}\hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\ \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}\end{split}\]And finally the update step is performed using the following rule:
\[\theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon}\]where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\eta\) is the learning rate, \(\alpha > 0\), \(\epsilon > 0\).
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional) – learning rate
betas (Tuple[float, float, float], optional) – coefficients used for running averages (default: (0.9, 0.999, 0.9999))
alpha (float, optional) – coefficient applied to the slow momentum \(m_{2,t}\) in the update step (default: 5.0)
eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
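To make the mixing of the two momentum terms concrete, here is a pure-Python sketch of one AdEMAMix step on flat lists of floats. It rests on assumptions: `ademamix_step` is a hypothetical helper, the second moment accumulates squared gradients as in the AdEMAMix paper, and weight decay is omitted because the update rule above does not include it.

```python
import math


def ademamix_step(theta, grad, m_fast, m_slow, s, step, lr=1e-3,
                  betas=(0.9, 0.999, 0.9999), alpha=5.0, eps=1e-8):
    """One AdEMAMix update for a single parameter tensor (hypothetical helper)."""
    beta1, beta2, beta3 = betas
    # Fast EMA (beta1) reacts to recent gradients, slow EMA (beta3) keeps old ones
    m_fast = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m_fast, grad)]
    m_slow = [beta3 * mi + (1 - beta3) * g for mi, g in zip(m_slow, grad)]
    # Second moment (assumed: squared gradients, as in the paper)
    s = [beta2 * si + (1 - beta2) * g * g for si, g in zip(s, grad)]
    # Bias correction applies to the fast EMA and the second moment only
    m_fast_hat = [mi / (1 - beta1 ** step) for mi in m_fast]
    s_hat = [si / (1 - beta2 ** step) for si in s]
    # The slow EMA enters the numerator scaled by alpha
    new_theta = [
        p - lr * (mf + alpha * ms) / (math.sqrt(sh) + eps)
        for p, mf, ms, sh in zip(theta, m_fast_hat, m_slow, s_hat)
    ]
    return new_theta, m_fast, m_slow, s


new_theta, m_fast, m_slow, s = ademamix_step(
    [1.0], [1.0], [0.0], [0.0], [0.0], step=1
)
```

After the first step the slow average contributes almost nothing (`alpha * m_slow` is about 0.0005 here); its share grows only as gradients accumulate over many steps.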
Optimizer wrappers
holocron.optim also implements optimizer wrappers.
A base optimizer should always be passed to the wrapper; e.g., you should write your code this way:
>>> optimizer = ...
>>> optimizer = wrapper(optimizer)
- class holocron.optim.wrapper.Lookahead(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)
Implements the Lookahead optimizer wrapper from “Lookahead Optimizer: k steps forward, 1 step back”.
>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Lookahead
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Lookahead(opt)
- Parameters:
base_optimizer (torch.optim.optimizer.Optimizer) – base parameter optimizer
sync_rate (float, optional) – rate of weight synchronization
sync_period (int, optional) – number of steps performed on fast weights before weight synchronization
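The roles of `sync_rate` and `sync_period` can be sketched without PyTorch. The loop below follows the scheme described in the Lookahead paper: the base optimizer updates the fast weights, and every `sync_period` steps the slow weights move a fraction `sync_rate` of the way towards them before the fast weights are reset. `lookahead_run`, `grad_fn` and `inner_step` are hypothetical names for illustration and do not mirror the wrapper's internals.

```python
def lookahead_run(theta, grad_fn, inner_step, num_steps,
                  sync_rate=0.5, sync_period=6):
    """Run num_steps of a base update rule with Lookahead synchronization."""
    slow = list(theta)  # slow weights, updated only at synchronization
    fast = list(theta)  # fast weights, updated by the base rule every step
    for step in range(1, num_steps + 1):
        fast = inner_step(fast, grad_fn(fast))
        if step % sync_period == 0:
            # Move the slow weights towards the fast weights ...
            slow = [s + sync_rate * (f - s) for s, f in zip(slow, fast)]
            # ... then restart the fast weights from the slow ones
            fast = list(slow)
    return fast, slow


def sgd(params, grads, lr=0.1):
    """Plain gradient descent, standing in for the base optimizer."""
    return [p - lr * g for p, g in zip(params, grads)]


def constant_grad(params):
    """Toy objective f(x) = sum(x), whose gradient is 1 everywhere."""
    return [1.0 for _ in params]


fast, slow = lookahead_run([1.0], constant_grad, sgd, num_steps=6)
```

In this toy run the fast weight travels from 1.0 to 0.4 over six steps, and the synchronization pulls both copies to the midpoint 0.7 because `sync_rate` is 0.5.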
- class holocron.optim.wrapper.Scout(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)
Implements a new optimizer wrapper based on “Lookahead Optimizer: k steps forward, 1 step back”.
>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Scout
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Scout(opt)
- Parameters:
base_optimizer (torch.optim.optimizer.Optimizer) – base parameter optimizer
sync_rate (float, optional) – rate of weight synchronization
sync_period (int, optional) – number of steps performed on fast weights before weight synchronization