holocron.optim

To use holocron.optim, you have to construct an optimizer object that holds the current state and updates the parameters based on the computed gradients.

Optimizers

Implementations of recent parameter optimizers for PyTorch modules.

class holocron.optim.LARS(params: Iterable[Parameter], lr: float = 0.001, momentum: float = 0.0, dampening: float = 0.0, weight_decay: float = 0.0, nesterov: bool = False, scale_clip: Tuple[float, float] | None = None)

Implements the LARS optimizer from “Large batch training of convolutional networks”.

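The estimation of global and local learning rates is described as follows, $\forall t \geq 1$:

$$
\begin{aligned}
\alpha_t &\leftarrow \alpha (1 - t / T)^2 \\
\gamma_t &\leftarrow \frac{\lVert \theta_t \rVert}{\lVert g_t \rVert + \lambda \lVert \theta_t \rVert}
\end{aligned}
$$

where $\theta_t$ is the parameter value at step $t$ ($\theta_0$ being the initialization value), $g_t$ is the gradient of $\theta_t$, $T$ is the total number of steps, $\alpha$ is the learning rate, and $\lambda \geq 0$ is the weight decay.

Then we estimate the momentum using:

$$
v_t \leftarrow m v_{t-1} + \alpha_t \gamma_t (g_t + \lambda \theta_t)
$$

where $m$ is the momentum and $v_0 = 0$.

And finally the update step is performed using the following rule:

$$
\theta_t \leftarrow \theta_{t-1} - v_t
$$
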
Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • momentum (float, optional) – momentum factor (default: 0)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dampening (float, optional) – dampening for momentum (default: 0)

  • nesterov (bool, optional) – enables Nesterov momentum (default: False)

  • scale_clip (tuple, optional) – the lower and upper bounds for the weight norm in local LR of LARS
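The LARS rules above can be sketched in plain Python on a single parameter tensor stored as a list of floats. This is only an illustration of the formulas, not the library's implementation: the function name `lars_step` is hypothetical, dampening, Nesterov momentum and `scale_clip` are left out, and the fallback to a local rate of 1 when the denominator is zero is an assumption.

```python
import math


def lars_step(theta, grad, velocity, step, total_steps,
              lr=0.1, momentum=0.9, weight_decay=1e-4):
    """One LARS update following the documented formulas (illustrative only)."""
    # Global learning rate: alpha_t = alpha * (1 - t / T)^2
    alpha_t = lr * (1.0 - step / total_steps) ** 2
    # Local learning rate: gamma_t = ||theta|| / (||g|| + lambda * ||theta||)
    theta_norm = math.sqrt(sum(x * x for x in theta))
    grad_norm = math.sqrt(sum(g * g for g in grad))
    denom = grad_norm + weight_decay * theta_norm
    gamma_t = theta_norm / denom if denom > 0 else 1.0  # assumed fallback
    # Momentum buffer: v_t = m * v_{t-1} + alpha_t * gamma_t * (g_t + lambda * theta_t)
    new_velocity = [
        momentum * v + alpha_t * gamma_t * (g + weight_decay * x)
        for v, g, x in zip(velocity, grad, theta)
    ]
    # Parameter update: theta_t = theta_{t-1} - v_t
    new_theta = [x - v for x, v in zip(theta, new_velocity)]
    return new_theta, new_velocity


# A weight of norm 5 with a gradient of norm 1 gets a local rate of about 5.
theta, velocity = lars_step([3.0, 4.0], [0.6, 0.8], [0.0, 0.0],
                            step=1, total_steps=10,
                            lr=0.1, momentum=0.0, weight_decay=0.0)
```

The local rate rescales the step so that its size is proportional to the weight norm of the layer, independently of the raw gradient magnitude.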

class holocron.optim.LAMB(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, scale_clip: Tuple[float, float] | None = None)

Implements the Lamb optimizer from “Large batch optimization for deep learning: training BERT in 76 minutes”.

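The estimation of momentums is described as follows, $\forall t \geq 1$:

$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\end{aligned}
$$

where $g_t$ is the gradient of $\theta_t$, $(\beta_1, \beta_2) \in [0, 1]^2$ are the exponential average smoothing coefficients, $m_0 = 0$, $v_0 = 0$.

Then we correct their biases using:

$$
\begin{aligned}
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t}
\end{aligned}
$$

And finally the update step is performed using the following rule:

$$
\begin{aligned}
r_t &\leftarrow \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \, \phi(\lVert \theta_t \rVert) \frac{r_t + \lambda \theta_t}{\lVert r_t + \lambda \theta_t \rVert}
\end{aligned}
$$

where $\theta_t$ is the parameter value at step $t$ ($\theta_0$ being the initialization value), $\phi$ is a clipping function, $\alpha$ is the learning rate, $\lambda \geq 0$ is the weight decay, $\epsilon > 0$.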

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – beta coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • scale_clip (tuple, optional) – the lower and upper bounds for the weight norm used in the layer-wise scaling (the clipping function $\phi$)
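As a rough illustration of the LAMB rules above, the sketch below applies one step to a single parameter tensor stored as a list of floats. The name `lamb_step` is hypothetical and this is not the library's code. Two assumptions are made: the update is normalized by the norm of the full direction $r_t + \lambda \theta_t$, as in the LAMB paper, and the trust ratio falls back to 1 when that norm is zero.

```python
import math


def _norm(values):
    return math.sqrt(sum(x * x for x in values))


def lamb_step(theta, grad, m, v, step, lr=1e-3, betas=(0.9, 0.999),
              eps=1e-8, weight_decay=0.0, scale_clip=None):
    """One LAMB update following the documented formulas (illustrative only)."""
    beta1, beta2 = betas
    # Exponential moving averages of the gradient and of its square
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    # Bias correction
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    v_hat = [vi / (1 - beta2 ** step) for vi in v]
    # Adam-style direction plus weight decay: r_t + lambda * theta_t
    r = [mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    update = [ri + weight_decay * x for ri, x in zip(r, theta)]
    # phi clips the weight norm to [low, high] when scale_clip is given
    phi = _norm(theta)
    if scale_clip is not None:
        phi = min(max(phi, scale_clip[0]), scale_clip[1])
    update_norm = _norm(update)
    trust = phi / update_norm if update_norm > 0 else 1.0  # assumed fallback
    theta = [x - lr * trust * u for x, u in zip(theta, update)]
    return theta, m, v


start = [3.0, 4.0]
zeros = [0.0, 0.0]
free, _, _ = lamb_step(start, [0.1, -0.2], zeros, zeros, step=1, lr=0.01)
clipped, _, _ = lamb_step(start, [0.1, -0.2], zeros, zeros, step=1, lr=0.01,
                          scale_clip=(0.0, 1.0))
```

Because the direction is normalized, the length of the step is `lr * phi`: about 0.05 without clipping here (weight norm 5) and 0.01 once the norm is clipped to 1.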

class holocron.optim.RaLars(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, force_adaptive_momentum: bool = False, scale_clip: Tuple[float, float] | None = None)

Implements the RAdam optimizer from “On the Variance of the Adaptive Learning Rate and Beyond” with optional layer-wise adaptive scaling from “Large Batch Training of Convolutional Networks”.

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • force_adaptive_momentum (bool, optional) – use adaptive momentum if variance is not tractable (default: False)

  • scale_clip (tuple, optional) – the lower and upper bounds for the scale factor of LARS

class holocron.optim.TAdam(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, dof: float | None = None)

Implements the TAdam optimizer from “TAdam: A Robust Stochastic Gradient Optimizer”.

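The estimation of momentums is described as follows, $\forall t \geq 1$:

$$
\begin{aligned}
w_t &\leftarrow (\nu + d) \Big(\nu + \sum_{j} \frac{(g_t^j - m_{t-1}^j)^2}{v_{t-1} + \epsilon} \Big)^{-1} \\
m_t &\leftarrow \frac{W_{t-1}}{W_{t-1} + w_t} m_{t-1} + \frac{w_t}{W_{t-1} + w_t} g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1})
\end{aligned}
$$

where $g_t$ is the gradient of $\theta_t$, $(\beta_1, \beta_2) \in [0, 1]^2$ are the exponential average smoothing coefficients, $m_0 = 0$, $v_0 = 0$, $W_0 = \frac{\beta_1}{1 - \beta_1}$; $\nu$ is the degrees of freedom and $d$ is the number of dimensions of the parameter gradient.

Then we correct their biases using:

$$
\begin{aligned}
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t}
\end{aligned}
$$

And finally the update step is performed using the following rule:

$$
\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
$$

where $\theta_t$ is the parameter value at step $t$ ($\theta_0$ being the initialization value), $\alpha$ is the learning rate, $\epsilon > 0$.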

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)

  • dof (float, optional) – degrees of freedom
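The distinctive part of TAdam is the weight $w_t$, which shrinks when a gradient lies far from the running mean. The sketch below covers only the $w_t$ and $m_t$ rules from the formulas above, on lists of floats. The helper name `tadam_first_moment` is hypothetical, and treating $v_{t-1}$ coordinate-wise inside the sum is an assumption; this is not the library's implementation.

```python
def tadam_first_moment(m_prev, v_prev, W_prev, grad, dof, eps=1e-8):
    """Robust first-moment update from the documented w_t and m_t rules."""
    d = len(grad)
    # Squared deviation of g_t from m_{t-1}, scaled by the variance estimate
    deviation = sum(
        (g - m) ** 2 / (v + eps) for g, m, v in zip(grad, m_prev, v_prev)
    )
    # w_t = (nu + d) / (nu + deviation): outlier gradients get a small weight
    w_t = (dof + d) / (dof + deviation)
    # m_t is the weighted mean of m_{t-1} (weight W_{t-1}) and g_t (weight w_t)
    total = W_prev + w_t
    m_t = [(W_prev * m + w_t * g) / total for m, g in zip(m_prev, grad)]
    return w_t, m_t


W0 = 0.9 / (1 - 0.9)  # W_0 = beta1 / (1 - beta1) with beta1 = 0.9
# A gradient equal to the running mean keeps a large weight ...
w_typical, m_typical = tadam_first_moment(
    [1.0, 1.0], [1.0, 1.0], W0, [1.0, 1.0], dof=2.0)
# ... while a gradient far from the mean barely moves it.
w_outlier, m_outlier = tadam_first_moment(
    [0.0, 0.0], [1.0, 1.0], W0, [100.0, 100.0], dof=2.0)
```

In this toy case the outlier gradient of 100 gets a weight far below that of the typical gradient, so the running mean stays close to its previous value.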

class holocron.optim.AdaBelief(params: Iterable[Tensor] | Iterable[Dict[str, Any]], lr: float | Tensor = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, *, foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, fused: bool | None = None)

Implements the AdaBelief optimizer from “AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients”.

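The estimation of momentums is described as follows, $\forall t \geq 1$:

$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
s_t &\leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon
\end{aligned}
$$

where $g_t$ is the gradient of $\theta_t$, $(\beta_1, \beta_2) \in [0, 1]^2$ are the exponential average smoothing coefficients, $m_0 = 0$, $s_0 = 0$, $\epsilon > 0$.

Then we correct their biases using:

$$
\begin{aligned}
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{s}_t &\leftarrow \frac{s_t}{1 - \beta_2^t}
\end{aligned}
$$

And finally the update step is performed using the following rule:

$$
\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{s}_t} + \epsilon}
$$

where $\theta_t$ is the parameter value at step $t$ ($\theta_0$ being the initialization value), $\alpha$ is the learning rate, $\epsilon > 0$.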

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
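The AdaBelief rules above differ from Adam only in the second statistic, which tracks the squared deviation of the gradient from its running mean. A minimal plain-Python sketch on lists of floats follows; `adabelief_step` is a hypothetical name, weight decay and the AMSGrad variant are omitted, and this is not the library's implementation.

```python
import math


def adabelief_step(theta, grad, m, s, step, lr=1e-3,
                   betas=(0.9, 0.999), eps=1e-8):
    """One AdaBelief update following the documented formulas (illustrative)."""
    beta1, beta2 = betas
    # First moment: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    # "Belief" term: s_t = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2 + eps
    s = [
        beta2 * si + (1 - beta2) * (g - mi) ** 2 + eps
        for si, g, mi in zip(s, grad, m)
    ]
    # Bias correction
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    s_hat = [si / (1 - beta2 ** step) for si in s]
    # Parameter update: theta_t = theta_{t-1} - lr * m_hat / (sqrt(s_hat) + eps)
    theta = [
        x - lr * mh / (math.sqrt(sh) + eps)
        for x, mh, sh in zip(theta, m_hat, s_hat)
    ]
    return theta, m, s


theta, m, s = adabelief_step([1.0], [0.5], [0.0], [0.0], step=1)
```

When successive gradients agree with the running mean, $s_t$ stays small and the effective step grows; when they disagree, the step shrinks.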

class holocron.optim.AdamP(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, delta: float = 0.1)

Implements the AdamP optimizer from “AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights”.

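The estimation of momentums is described as follows, $\forall t \geq 1$:

$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\end{aligned}
$$

where $g_t$ is the gradient of $\theta_t$, $(\beta_1, \beta_2) \in [0, 1]^2$ are the exponential average smoothing coefficients, $m_0 = g_0$, $v_0 = 0$.

Then we correct their biases using:

$$
\begin{aligned}
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t}
\end{aligned}
$$

And finally the update step is performed using the following rule:

$$
\begin{aligned}
p_t &\leftarrow \frac{\hat{m}_t}{\sqrt{\hat{v}_t + \epsilon}} \\
q_t &\leftarrow \begin{cases}
\Pi_{\theta_t}(p_t) & \text{if } \cos(\theta_t, g_t) < \delta / \sqrt{\dim(\theta)} \\
p_t & \text{otherwise}
\end{cases} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha q_t
\end{aligned}
$$

where $\theta_t$ is the parameter value at step $t$ ($\theta_0$ being the initialization value), $\Pi_{\theta_t}(p_t)$ is the projection of $p_t$ onto the tangent space of $\theta_t$, $\cos(\theta_t, g_t)$ is the cosine similarity between $\theta_t$ and $g_t$, $\alpha$ is the learning rate, $\delta > 0$, $\epsilon > 0$.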

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float], optional) – coefficients used for running averages (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)

  • delta (float, optional) – delta threshold for projection (default: 0.1)
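The projection step is what sets AdamP apart from Adam, so here is a small sketch of how $q_t$ could be chosen from $p_t$. It works on lists of floats and assumes non-zero $\theta_t$ and $g_t$. The helper names are hypothetical, and taking the absolute value of the cosine (as the AdamP paper does) is an assumption rather than something stated above; this is not the library's implementation.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project_onto_tangent(theta, p):
    """Remove from p its component along theta."""
    scale = _dot(theta, p) / _dot(theta, theta)
    return [pi - scale * ti for pi, ti in zip(p, theta)]


def adamp_direction(theta, grad, p, delta=0.1):
    """Return q_t: the projected p_t if theta and grad are nearly orthogonal."""
    cosine = abs(_dot(theta, grad)) / (
        math.sqrt(_dot(theta, theta)) * math.sqrt(_dot(grad, grad))
    )
    if cosine < delta / math.sqrt(len(theta)):
        return project_onto_tangent(theta, p)
    return list(p)


# Gradient orthogonal to the weights: the radial part of p is removed.
q_projected = adamp_direction([1.0, 0.0], [0.0, 1.0], [0.5, 0.5])
# Gradient aligned with the weights: p is used unchanged.
q_plain = adamp_direction([1.0, 0.0], [1.0, 0.0], [0.5, 0.5])
```

A gradient that is almost orthogonal to the weights is the signature of a scale-invariant parameter, which is why the projection is applied only in that case: it removes the radial component that would merely inflate the weight norm.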

class holocron.optim.Adan(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float, float] = (0.98, 0.92, 0.99), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False)

Implements the Adan optimizer from “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models”.

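The estimation of momentums is described as follows, $\forall t \geq 1$:

$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1}) \\
n_t &\leftarrow \beta_3 n_{t-1} + (1 - \beta_3) [g_t + \beta_2 (g_t - g_{t-1})]^2
\end{aligned}
$$

where $g_t$ is the gradient of $\theta_t$, $(\beta_1, \beta_2, \beta_3) \in [0, 1]^3$ are the exponential average smoothing coefficients, $m_0 = g_0$, $v_0 = 0$, $n_0 = g_0^2$.

Then we correct their biases using:

$$
\begin{aligned}
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\hat{n}_t &\leftarrow \frac{n_t}{1 - \beta_3^t}
\end{aligned}
$$

And finally the update step is performed using the following rule:

$$
\begin{aligned}
p_t &\leftarrow \frac{\hat{m}_t + (1 - \beta_2) \hat{v}_t}{\sqrt{\hat{n}_t + \epsilon}} \\
\theta_t &\leftarrow \frac{\theta_{t-1} - \alpha p_t}{1 + \lambda \alpha}
\end{aligned}
$$

where $\theta_t$ is the parameter value at step $t$ ($\theta_0$ being the initialization value), $\alpha$ is the learning rate, $\lambda \geq 0$ is the weight decay, $\epsilon > 0$.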

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float, float], optional) – coefficients used for running averages (default: (0.98, 0.92, 0.99))

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (bool, optional) – whether to use the AMSGrad variant (default: False)
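To make the three Adan statistics concrete, the sketch below applies one step of the rules above to a list of floats. `adan_step` is a hypothetical name, the caller is expected to provide the initial state ($m_0 = g_0$, $v_0 = 0$, $n_0 = g_0^2$), placing $\epsilon$ inside the square root is an assumption, and the AMSGrad variant is omitted. It illustrates the formulas and is not the library's implementation.

```python
import math


def adan_step(theta, grad, prev_grad, m, v, n, step, lr=1e-3,
              betas=(0.98, 0.92, 0.99), eps=1e-8, weight_decay=0.0):
    """One Adan update following the documented formulas (illustrative only)."""
    beta1, beta2, beta3 = betas
    diff = [g - pg for g, pg in zip(grad, prev_grad)]  # g_t - g_{t-1}
    # m_t: average of gradients, v_t: average of gradient differences
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * d for vi, d in zip(v, diff)]
    # n_t: average of the squared "look-ahead" gradient g_t + beta2 * diff
    n = [
        beta3 * ni + (1 - beta3) * (g + beta2 * d) ** 2
        for ni, g, d in zip(n, grad, diff)
    ]
    # Bias correction
    m_hat = [mi / (1 - beta1 ** step) for mi in m]
    v_hat = [vi / (1 - beta2 ** step) for vi in v]
    n_hat = [ni / (1 - beta3 ** step) for ni in n]
    # Direction p_t, then update with decoupled weight decay
    p = [
        (mh + (1 - beta2) * vh) / math.sqrt(nh + eps)
        for mh, vh, nh in zip(m_hat, v_hat, n_hat)
    ]
    theta = [(x - lr * pi) / (1 + weight_decay * lr) for x, pi in zip(theta, p)]
    return theta, m, v, n


g0 = [0.5]
state = dict(m=list(g0), v=[0.0], n=[g * g for g in g0])
theta, m, v, n = adan_step([1.0], [0.5], g0, step=1, **state)
decayed, _, _, _ = adan_step([1.0], [0.5], g0, step=1, weight_decay=0.1, **state)
```

The division by $1 + \lambda \alpha$ applies weight decay as a shrinkage of the updated parameter rather than as an extra gradient term.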

class holocron.optim.AdEMAMix(params: Iterable[Parameter], lr: float = 0.001, betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999), alpha: float = 5.0, eps: float = 1e-08, weight_decay: float = 0.0)

Implements the AdEMAMix optimizer from “The AdEMAMix Optimizer: Better, Faster, Older”.

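The estimation of momentums is described as follows, $\forall t \geq 1$:

$$
\begin{aligned}
m_{1,t} &\leftarrow \beta_1 m_{1,t-1} + (1 - \beta_1) g_t \\
m_{2,t} &\leftarrow \beta_3 m_{2,t-1} + (1 - \beta_3) g_t \\
s_t &\leftarrow \beta_2 s_{t-1} + (1 - \beta_2) g_t^2
\end{aligned}
$$

where $g_t$ is the gradient of $\theta_t$, $(\beta_1, \beta_2, \beta_3) \in [0, 1]^3$ are the exponential average smoothing coefficients, $m_{1,0} = 0$, $m_{2,0} = 0$, $s_0 = 0$.

Then we correct their biases using:

$$
\begin{aligned}
\hat{m}_{1,t} &\leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\
\hat{s}_t &\leftarrow \frac{s_t}{1 - \beta_2^t}
\end{aligned}
$$

And finally the update step is performed using the following rule:

$$
\theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m}_{1,t} + \alpha m_{2,t}}{\sqrt{\hat{s}_t} + \epsilon}
$$

where $\theta_t$ is the parameter value at step $t$ ($\theta_0$ being the initialization value), $\eta$ is the learning rate, $\alpha > 0$, $\epsilon > 0$.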

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate

  • betas (Tuple[float, float, float], optional) – coefficients used for running averages (default: (0.9, 0.999, 0.9999))

  • alpha (float, optional) – weight of the slow exponential moving average of the gradient in the update (default: 5.0)

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

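AdEMAMix mixes a fast and a slow average of the gradient in the numerator of the update. A plain-Python sketch of one step on lists of floats is given below. `ademamix_step` is a hypothetical name, weight decay is omitted, and the second-moment estimate is assumed to average the squared gradient, as in the AdEMAMix paper. This is not the library's implementation.

```python
import math


def ademamix_step(theta, grad, m1, m2, s, step, lr=1e-3,
                  betas=(0.9, 0.999, 0.9999), alpha=5.0, eps=1e-8):
    """One AdEMAMix update (illustrative only, weight decay omitted)."""
    beta1, beta2, beta3 = betas
    # Fast average (beta1) and slow average (beta3) of the gradient
    m1 = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m1, grad)]
    m2 = [beta3 * mi + (1 - beta3) * g for mi, g in zip(m2, grad)]
    # Second-moment estimate (assumed: average of squared gradients)
    s = [beta2 * si + (1 - beta2) * g * g for si, g in zip(s, grad)]
    # Only the fast average and the second moment are bias-corrected
    m1_hat = [mi / (1 - beta1 ** step) for mi in m1]
    s_hat = [si / (1 - beta2 ** step) for si in s]
    # Update: theta - lr * (m1_hat + alpha * m2) / (sqrt(s_hat) + eps)
    theta = [
        x - lr * (mh + alpha * ms) / (math.sqrt(sh) + eps)
        for x, mh, ms, sh in zip(theta, m1_hat, m2, s_hat)
    ]
    return theta, m1, m2, s


theta, m1, m2, s = ademamix_step([1.0], [0.5], [0.0], [0.0], [0.0], step=1)
```

With the default `betas`, the slow average contributes very little in the first steps and gains influence as training proceeds, scaled by `alpha`.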

Optimizer wrappers

holocron.optim also implements optimizer wrappers.

A base optimizer should always be passed to the wrapper; e.g., you should write your code this way:

>>> optimizer = ...
>>> optimizer = wrapper(optimizer)
class holocron.optim.wrapper.Lookahead(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)

Implements the Lookahead optimizer wrapper from “Lookahead Optimizer: k steps forward, 1 step back”.

>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Lookahead
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Lookahead(opt)
Parameters:
  • base_optimizer (torch.optim.optimizer.Optimizer) – base parameter optimizer

  • sync_rate (float, optional) – rate of weight synchronization

  • sync_period (int, optional) – number of steps performed on fast weights before weight synchronization
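To show how `sync_rate` and `sync_period` interact, here is a toy sketch of the Lookahead scheme on a list of floats. The base optimizer is replaced by a stand-in that subtracts a constant, and `lookahead_sync` is a hypothetical helper, so this only mirrors the idea from the paper and not the wrapper's actual code.

```python
def lookahead_sync(slow, fast, sync_rate=0.5):
    """Move the slow weights toward the fast ones, then restart from them."""
    new_slow = [s + sync_rate * (f - s) for s, f in zip(slow, fast)]
    # Fast weights are reset to the freshly updated slow weights
    return new_slow, list(new_slow)


slow, fast = [0.0], [0.0]
sync_period = 6
for step in range(1, 13):
    # Stand-in for one step of the base optimizer on the fast weights
    fast = [w - 0.1 for w in fast]
    if step % sync_period == 0:
        slow, fast = lookahead_sync(slow, fast, sync_rate=0.5)
```

After each block of six fast steps the slow weights cover half of the distance travelled by the fast weights, which is the "k steps forward, 1 step back" behaviour named in the paper title.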

class holocron.optim.wrapper.Scout(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)

Implements a new optimizer wrapper based on “Lookahead Optimizer: k steps forward, 1 step back”.

Example:
>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Scout
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Scout(opt)
Parameters:
  • base_optimizer (torch.optim.optimizer.Optimizer) – base parameter optimizer

  • sync_rate (float, optional) – rate of weight synchronization

  • sync_period (int, optional) – number of steps performed on fast weights before weight synchronization