holocron.optim

To use holocron.optim, you first construct an optimizer object that holds the current state and updates the parameters based on the computed gradients.
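
For instance, a minimal training-step sketch (the module, loss and tensors below are placeholders, and any optimizer from this page can be swapped in):

import torch
from holocron.optim import LAMB

model = torch.nn.Linear(16, 4)  # placeholder module
criterion = torch.nn.MSELoss()
optimizer = LAMB(model.parameters(), lr=1e-3, weight_decay=1e-2)

inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
optimizer.zero_grad()                     # clear gradients from the previous step
loss = criterion(model(inputs), targets)
loss.backward()                           # compute gradients
optimizer.step()                          # update parameters from the computed gradients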

Optimizers

Implementations of recent parameter optimizers for PyTorch modules.

LARS

LARS(params: Iterable[Parameter], lr: float = 0.001, momentum: float = 0.0, dampening: float = 0.0, weight_decay: float = 0.0, nesterov: bool = False, scale_clip: tuple[float, float] | None = None)

Bases: Optimizer

Implements the LARS optimizer from "Large batch training of convolutional networks".

The estimation of global and local learning rates is described as follows, \(\forall t \geq 1\):

\[ \alpha_t \leftarrow \alpha (1 - t / T)^2 \\ \gamma_t \leftarrow \frac{\lVert \theta_t \rVert}{\lVert g_t \rVert + \lambda \lVert \theta_t \rVert} \]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(g_t\) is the gradient of \(\theta_t\), \(T\) is the total number of steps, \(\alpha\) is the learning rate and \(\lambda \geq 0\) is the weight decay.

Then we estimate the momentum using:

\[ v_t \leftarrow m v_{t-1} + \alpha_t \gamma_t (g_t + \lambda \theta_t) \]

where \(m\) is the momentum and \(v_0 = 0\).

And finally the update step is performed using the following rule:

\[ \theta_t \leftarrow \theta_{t-1} - v_t \]
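
As an illustration (the linear module is just a placeholder and the hyper-parameter values are examples, not recommendations), LARS is constructed like any torch.optim optimizer:

import torch
from holocron.optim import LARS

model = torch.nn.Linear(32, 10)  # placeholder module
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, scale_clip=(0.0, 10.0))
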
PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

TYPE: Iterable[Parameter]

lr

learning rate

TYPE: float DEFAULT: 0.001

momentum

momentum factor

TYPE: float DEFAULT: 0.0

weight_decay

weight decay (L2 penalty)

TYPE: float DEFAULT: 0.0

dampening

dampening for momentum

TYPE: float DEFAULT: 0.0

nesterov

enables Nesterov momentum

TYPE: bool DEFAULT: False

scale_clip

the lower and upper bounds for the weight norm in local LR of LARS

TYPE: tuple[float, float] | None DEFAULT: None

Source code in holocron/optim/lars.py
def __init__(
    self,
    params: Iterable[torch.nn.Parameter],
    lr: float = 1e-3,
    momentum: float = 0.0,
    dampening: float = 0.0,
    weight_decay: float = 0.0,
    nesterov: bool = False,
    scale_clip: tuple[float, float] | None = None,
) -> None:
    if not isinstance(lr, float) or lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if momentum < 0.0:
        raise ValueError(f"Invalid momentum value: {momentum}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")

    defaults = {
        "lr": lr,
        "momentum": momentum,
        "dampening": dampening,
        "weight_decay": weight_decay,
        "nesterov": nesterov,
    }
    if nesterov and (momentum <= 0 or dampening != 0):
        raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super().__init__(params, defaults)
    # LARS arguments
    self.scale_clip = scale_clip
    if self.scale_clip is None:
        self.scale_clip = (0.0, 10.0)

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value
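
The closure follows the usual torch.optim contract; a minimal sketch, assuming model, criterion, optimizer, inputs and targets have been created beforehand (for instance as in the snippet at the top of this page):

def closure() -> float:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss.item()  # the value returned here is what step() passes back

loss_value = optimizer.step(closure)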

Source code in holocron/optim/lars.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        weight_decay = group["weight_decay"]
        momentum = group["momentum"]
        dampening = group["dampening"]
        nesterov = group["nesterov"]

        for p in group["params"]:
            if p.grad is None:
                continue
            d_p = p.grad.data

            # LARS
            p_norm = torch.norm(p.data)
            denom = torch.norm(d_p)
            if weight_decay != 0:
                d_p.add_(p.data, alpha=weight_decay)
                denom.add_(p_norm, alpha=weight_decay)
            # Compute the local LR
            local_lr = 1 if p_norm == 0 or denom == 0 else p_norm / denom

            if momentum == 0:
                p.data.add_(d_p, alpha=-group["lr"] * local_lr)
            else:
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    momentum_buffer = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                else:
                    momentum_buffer = param_state["momentum_buffer"]
                    momentum_buffer.mul_(momentum).add_(d_p, alpha=1 - dampening)
                d_p = d_p.add(momentum_buffer, alpha=momentum) if nesterov else momentum_buffer
                p.data.add_(d_p, alpha=-group["lr"] * local_lr)
                self.state[p]["momentum_buffer"] = momentum_buffer

    return loss

LAMB

LAMB(params: Iterable[Parameter], lr: float = 0.001, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, scale_clip: tuple[float, float] | None = None)

Bases: Optimizer

Implements the Lamb optimizer from "Large batch optimization for deep learning: training BERT in 76 minutes".

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[ m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = 0,\ v_0 = 0\).

Then we correct their biases using:

\[ \hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t} \]

And finally the update step is performed using the following rule:

\[ r_t \leftarrow \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon} \\ \theta_t \leftarrow \theta_{t-1} - \alpha \phi(\lVert \theta_t \rVert) \frac{r_t + \lambda \theta_t}{\lVert r_t + \lambda \theta_t \rVert} \]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\phi\) is a clipping function, \(\alpha\) is the learning rate, \(\lambda \geq 0\) is the weight decay, \(\epsilon > 0\).
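
A construction sketch (placeholder module; the hyper-parameter values are illustrative only):

import torch
from holocron.optim import LAMB

model = torch.nn.Linear(32, 10)  # placeholder module
optimizer = LAMB(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)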

PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

TYPE: Iterable[Parameter]

lr

learning rate

TYPE: float DEFAULT: 0.001

betas

beta coefficients used for running averages

TYPE: tuple[float, float] DEFAULT: (0.9, 0.999)

eps

term added to the denominator to improve numerical stability

TYPE: float DEFAULT: 1e-08

weight_decay

weight decay (L2 penalty)

TYPE: float DEFAULT: 0.0

scale_clip

the lower and upper bounds for the weight norm in local LR of LARS

TYPE: tuple[float, float] | None DEFAULT: None

Source code in holocron/optim/lamb.py
def __init__(
    self,
    params: Iterable[torch.nn.Parameter],
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    scale_clip: tuple[float, float] | None = None,
) -> None:
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps < 0.0:
        raise ValueError(f"Invalid epsilon value: {eps}")
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
    defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
    super().__init__(params, defaults)
    # LARS arguments
    self.scale_clip = scale_clip
    if self.scale_clip is None:
        self.scale_clip = (0.0, 10.0)

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

RAISES DESCRIPTION
RuntimeError

if the optimizer does not support sparse gradients

Source code in holocron/optim/lamb.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value

    Raises:
        RuntimeError: if the optimizer does not support sparse gradients
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p.data)

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            beta1, beta2 = group["betas"]

            state["step"] += 1

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            # Gradient term correction
            update = torch.zeros_like(p.data)
            denom = exp_avg_sq.sqrt().add_(group["eps"])
            update.addcdiv_(exp_avg, denom)

            # Weight decay
            if group["weight_decay"] != 0:
                update.add_(p.data, alpha=group["weight_decay"])

            # LARS
            p_norm = p.data.pow(2).sum().sqrt()
            update_norm = update.pow(2).sum().sqrt()
            phi_p = p_norm.clamp(*self.scale_clip)
            # Compute the local LR
            local_lr = 1 if phi_p == 0 or update_norm == 0 else phi_p / update_norm

            state["local_lr"] = local_lr

            p.data.add_(update, alpha=-group["lr"] * local_lr)

    return loss

RaLars

RaLars(params: Iterable[Parameter], lr: float = 0.001, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, force_adaptive_momentum: bool = False, scale_clip: tuple[float, float] | None = None)

Bases: Optimizer

Implements the RAdam optimizer from "On the Variance of the Adaptive Learning Rate and Beyond", with optional layer-wise adaptive scaling from "Large Batch Training of Convolutional Networks".
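
A construction sketch (placeholder module; values are illustrative only):

import torch
from holocron.optim import RaLars

model = torch.nn.Linear(32, 10)  # placeholder module
# force_adaptive_momentum: use adaptive momentum even when the variance is not tractable (see parameter description below)
optimizer = RaLars(model.parameters(), lr=1e-3, force_adaptive_momentum=False)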

PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

TYPE: Iterable[Parameter]

lr

learning rate

TYPE: float DEFAULT: 0.001

betas

coefficients used for running averages

TYPE: tuple[float, float] DEFAULT: (0.9, 0.999)

eps

term added to the denominator to improve numerical stability

TYPE: float DEFAULT: 1e-08

weight_decay

weight decay (L2 penalty)

TYPE: float DEFAULT: 0.0

force_adaptive_momentum

use adaptive momentum if variance is not tractable

TYPE: bool DEFAULT: False

scale_clip

the lower and upper bounds for the weight norm in the local LR of LARS

TYPE: tuple[float, float] | None DEFAULT: None

Source code in holocron/optim/ralars.py
def __init__(
    self,
    params: Iterable[torch.nn.Parameter],
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    force_adaptive_momentum: bool = False,
    scale_clip: tuple[float, float] | None = None,
) -> None:
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps < 0.0:
        raise ValueError(f"Invalid epsilon value: {eps}")
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
    defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
    super().__init__(params, defaults)
    # RAdam tweaks
    self.force_adaptive_momentum = force_adaptive_momentum
    # LARS arguments
    self.scale_clip = scale_clip
    if self.scale_clip is None:
        self.scale_clip = (0, 10)

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

RAISES DESCRIPTION
RuntimeError

if the optimizer does not support sparse gradients

Source code in holocron/optim/ralars.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value

    Raises:
        RuntimeError: if the optimizer does not support sparse gradients
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        # Get group-shared variables
        beta1, beta2 = group["betas"]
        # Compute max length of SMA on first step
        if not isinstance(group.get("sma_inf"), float):
            group["sma_inf"] = 2 / (1 - beta2) - 1
        sma_inf = group["sma_inf"]

        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p.data)

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

            state["step"] += 1

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            # Bias correction
            bias_correction1 = 1 - beta1 ** state["step"]
            bias_correction2 = 1 - beta2 ** state["step"]

            # Compute length of SMA
            sma_t = sma_inf - 2 * state["step"] * (1 - bias_correction2) / bias_correction2

            update = torch.zeros_like(p.data)
            if sma_t > 4:
                # Variance rectification term
                r_t = math.sqrt((sma_t - 4) * (sma_t - 2) * sma_inf / ((sma_inf - 4) * (sma_inf - 2) * sma_t))
                # Adaptive momentum
                update.addcdiv_(
                    exp_avg / bias_correction1, (exp_avg_sq / bias_correction2).sqrt().add_(group["eps"]), value=r_t
                )
            elif self.force_adaptive_momentum:
                # Adaptive momentum without variance rectification (Adam)
                update.addcdiv_(
                    exp_avg / bias_correction1, (exp_avg_sq / bias_correction2).sqrt().add_(group["eps"])
                )
            else:
                # Unadapted momentum
                update.add_(exp_avg / bias_correction1)

            # Weight decay
            if group["weight_decay"] != 0:
                update.add_(p.data, alpha=group["weight_decay"])

            # LARS
            p_norm = p.data.pow(2).sum().sqrt()
            update_norm = update.pow(2).sum().sqrt()
            phi_p = p_norm.clamp(*self.scale_clip)
            # Compute the local LR
            local_lr = 1 if phi_p == 0 or update_norm == 0 else phi_p / update_norm

            state["local_lr"] = local_lr

            p.data.add_(update, alpha=-group["lr"] * local_lr)

    return loss

TAdam

TAdam(params: Iterable[Parameter], lr: float = 0.001, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, dof: float | None = None)

Bases: Optimizer

Implements the TAdam optimizer from "TAdam: A Robust Stochastic Gradient Optimizer".

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[ w_t \leftarrow (\nu + d) \Big(\nu + \sum\limits_{j} \frac{(g_t^j - m_{t-1}^j)^2}{v_{t-1}^j + \epsilon} \Big)^{-1} \\ m_t \leftarrow \frac{W_{t-1}}{W_{t-1} + w_t} m_{t-1} + \frac{w_t}{W_{t-1} + w_t} g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = 0,\ v_0 = 0,\ W_0 = \frac{\beta_1}{1 - \beta_1}\); \(\nu\) is the degrees of freedom and \(d\) is the number of dimensions of the parameter gradient.

Then we correct their biases using:

\[ \hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t} \]

And finally the update step is performed using the following rule:

\[ \theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon} \]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\epsilon > 0\).
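
A construction sketch (placeholder module; values are illustrative only):

import torch
from holocron.optim import TAdam

model = torch.nn.Linear(32, 10)  # placeholder module
optimizer = TAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), dof=None)  # dof=None keeps the library default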

PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

TYPE: Iterable[Parameter]

lr

learning rate

TYPE: float DEFAULT: 0.001

betas

coefficients used for running averages

TYPE: tuple[float, float] DEFAULT: (0.9, 0.999)

eps

term added to the denominator to improve numerical stability

TYPE: float DEFAULT: 1e-08

weight_decay

weight decay (L2 penalty)

TYPE: float DEFAULT: 0.0

amsgrad

whether to use the AMSGrad variant

TYPE: bool DEFAULT: False

dof

degrees of freedom

TYPE: float | None DEFAULT: None

Source code in holocron/optim/tadam.py
def __init__(
    self,
    params: Iterable[torch.nn.Parameter],
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    amsgrad: bool = False,
    dof: float | None = None,
) -> None:
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps < 0.0:
        raise ValueError(f"Invalid epsilon value: {eps}")
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
    if not weight_decay >= 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
    defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "amsgrad": amsgrad, "dof": dof}
    super().__init__(params, defaults)

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

RAISES DESCRIPTION
RuntimeError

if the optimizer does not support sparse gradients

Source code in holocron/optim/tadam.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value

    Raises:
        RuntimeError: if the optimizer does not support sparse gradients
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        W_ts = []  # noqa: N806
        max_exp_avg_sqs = []
        state_steps = []

        beta1, beta2 = group["betas"]

        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Tadam specific
                    state["W_t"] = beta1 / (1 - beta1) * torch.ones(1, dtype=p.data.dtype, device=p.data.device)

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                W_ts.append(state["W_t"])

                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

        tadam(
            params_with_grad,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            W_ts,
            state_steps,
            group["amsgrad"],
            beta1,
            beta2,
            group["lr"],
            group["weight_decay"],
            group["eps"],
            group["dof"],
        )

    return loss

AdaBelief

Bases: Adam

Implements the AdaBelief optimizer from "AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients".

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[ m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon \]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = 0,\ s_0 = 0\), \(\epsilon > 0\).

Then we correct their biases using:

\[ \hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t} \]

And finally the update step is performed using the following rule:

\[ \theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m_t}}{\sqrt{\hat{s_t}} + \epsilon} \]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\epsilon > 0\).
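
Since AdaBelief subclasses Adam, it takes the same constructor arguments; a minimal sketch (placeholder module):

import torch
from holocron.optim import AdaBelief

model = torch.nn.Linear(32, 10)  # placeholder module
optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, amsgrad=False)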

PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

lr

learning rate

betas

coefficients used for running averages

eps

term added to the denominator to improve numerical stability

weight_decay

weight decay (L2 penalty)

amsgrad

whether to use the AMSGrad variant

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

RAISES DESCRIPTION
RuntimeError

if the optimizer does not support sparse gradients

Source code in holocron/optim/adabelief.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model
            and returns the loss.

    Returns:
        loss value

    Raises:
        RuntimeError: if the optimizer does not support sparse gradients
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps = []

        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

        beta1, beta2 = group["betas"]
        adabelief(
            params_with_grad,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            group["amsgrad"],
            beta1,
            beta2,
            group["lr"],
            group["weight_decay"],
            group["eps"],
        )
    return loss

AdamP

AdamP(params: Iterable[Parameter], lr: float = 0.001, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False, delta: float = 0.1)

Bases: Adam

Implements the AdamP optimizer from "AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights".

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[ m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2 \in [0, 1]^2\) are the exponential average smoothing coefficients, \(m_0 = g_0,\ v_0 = 0\).

Then we correct their biases using:

\[ \hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t} \]

And finally the update step is performed using the following rule:

\[ p_t \leftarrow \frac{\hat{m_t}}{\sqrt{\hat{v_t}} + \epsilon} \\ q_t \leftarrow \begin{cases} \prod_{\theta_t}(p_t) & \text{if}\ \cos(\theta_t, g_t) < \delta / \sqrt{\dim(\theta)}\\ p_t & \text{otherwise}\\ \end{cases} \\ \theta_t \leftarrow \theta_{t-1} - \alpha q_t \]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\prod_{\theta_t}(p_t)\) is the projection of \(p_t\) onto the tangent space of \(\theta_t\), \(cos(\theta_t, g_t)\) is the cosine similarity between \(\theta_t\) and \(g_t\), \(\alpha\) is the learning rate, \(\delta > 0\), \(\epsilon > 0\).
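
A construction sketch (placeholder module; values are illustrative only):

import torch
from holocron.optim import AdamP

model = torch.nn.Linear(32, 10)  # placeholder module
# delta is the cosine-similarity threshold that triggers the projection
optimizer = AdamP(model.parameters(), lr=1e-3, delta=0.1)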

PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

TYPE: Iterable[Parameter]

lr

learning rate

TYPE: float DEFAULT: 0.001

betas

coefficients used for running averages

TYPE: tuple[float, float] DEFAULT: (0.9, 0.999)

eps

term added to the denominator to improve numerical stability

TYPE: float DEFAULT: 1e-08

weight_decay

weight decay (L2 penalty)

TYPE: float DEFAULT: 0.0

amsgrad

whether to use the AMSGrad variant

TYPE: bool DEFAULT: False

delta

delta threshold for projection

TYPE: float DEFAULT: 0.1

Source code in holocron/optim/adamp.py
def __init__(
    self,
    params: Iterable[torch.nn.Parameter],
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    amsgrad: bool = False,
    delta: float = 0.1,
) -> None:
    super().__init__(params, lr, betas, eps, weight_decay, amsgrad)
    self.delta = delta

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

RAISES DESCRIPTION
RuntimeError

if the optimizer does not support sparse gradients

Source code in holocron/optim/adamp.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value

    Raises:
        RuntimeError: if the optimizer does not support sparse gradients
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps = []

        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                if group["amsgrad"]:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

        beta1, beta2 = group["betas"]
        adamp(
            params_with_grad,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            group["amsgrad"],
            beta1,
            beta2,
            group["lr"],
            group["weight_decay"],
            group["eps"],
            self.delta,
        )

    return loss

Adan

Adan(params: Iterable[Parameter], lr: float = 0.001, betas: tuple[float, float, float] = (0.98, 0.92, 0.99), eps: float = 1e-08, weight_decay: float = 0.0, amsgrad: bool = False)

Bases: Adam

Implements the Adan optimizer from "Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models".

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[ m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1}) \\ n_t \leftarrow \beta_3 n_{t-1} + (1 - \beta_3) [g_t + \beta_2 (g_t - g_{t - 1})]^2 \]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2, \beta_3 \in [0, 1]^3\) are the exponential average smoothing coefficients, \(m_0 = g_0,\ v_0 = 0,\ n_0 = g_0^2\).

Then we correct their biases using:

\[ \hat{m_t} \leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} \leftarrow \frac{v_t}{1 - \beta_2^t} \\ \hat{n_t} \leftarrow \frac{n_t}{1 - \beta_3^t} \]

And finally the update step is performed using the following rule:

\[ p_t \leftarrow \frac{\hat{m_t} + (1 - \beta_2) \hat{v_t}}{\sqrt{\hat{n_t} + \epsilon}} \\ \theta_t \leftarrow \frac{\theta_{t-1} - \alpha p_t}{1 + \lambda \alpha} \]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\alpha\) is the learning rate, \(\lambda \geq 0\) is the weight decay, \(\epsilon > 0\).
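
A construction sketch (placeholder module; note the three-element betas tuple):

import torch
from holocron.optim import Adan

model = torch.nn.Linear(32, 10)  # placeholder module
optimizer = Adan(model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99), weight_decay=1e-2)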

PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

TYPE: Iterable[Parameter]

lr

learning rate

TYPE: float DEFAULT: 0.001

betas

coefficients used for running averages

TYPE: tuple[float, float, float] DEFAULT: (0.98, 0.92, 0.99)

eps

term added to the denominator to improve numerical stability

TYPE: float DEFAULT: 1e-08

weight_decay

weight decay (L2 penalty)

TYPE: float DEFAULT: 0.0

amsgrad

whether to use the AMSGrad variant

TYPE: bool DEFAULT: False

Source code in holocron/optim/adan.py
def __init__(
    self,
    params: Iterable[torch.nn.Parameter],
    lr: float = 1e-3,
    betas: tuple[float, float, float] = (0.98, 0.92, 0.99),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    amsgrad: bool = False,
) -> None:
    super().__init__(params, lr, betas, eps, weight_decay, amsgrad)  # type: ignore[arg-type]

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

RAISES DESCRIPTION
RuntimeError

if the optimizer does not support sparse gradients

Source code in holocron/optim/adan.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value

    Raises:
        RuntimeError: if the optimizer does not support sparse gradients
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        prev_grads = []
        exp_avgs = []
        exp_avg_sqs = []
        exp_avg_deltas = []
        max_exp_avg_deltas = []
        state_steps = []

        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of gradient delta values
                    state["exp_avg_delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["amsgrad"]:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["prev_grad"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                prev_grads.append(state["prev_grad"])
                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                exp_avg_deltas.append(state["exp_avg_delta"])
                if group["amsgrad"]:
                    max_exp_avg_deltas.append(state["max_exp_avg_delta"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

        beta1, beta2, beta3 = group["betas"]
        adan(
            params_with_grad,
            grads,
            prev_grads,
            exp_avgs,
            exp_avg_sqs,
            exp_avg_deltas,
            max_exp_avg_deltas,
            state_steps,
            group["amsgrad"],
            beta1,
            beta2,
            beta3,
            group["lr"],
            group["weight_decay"],
            group["eps"],
        )

    return loss

AdEMAMix

AdEMAMix(params: Iterable[Parameter], lr: float = 0.001, betas: tuple[float, float, float] = (0.9, 0.999, 0.9999), alpha: float = 5.0, eps: float = 1e-08, weight_decay: float = 0.0)

Bases: Optimizer

Implements the AdEMAMix optimizer from "The AdEMAMix Optimizer: Better, Faster, Older".

The estimation of momentums is described as follows, \(\forall t \geq 1\):

\[ m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\ m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\ s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) g_t^2 \]

where \(g_t\) is the gradient of \(\theta_t\), \(\beta_1, \beta_2, \beta_3 \in [0, 1]^3\) are the exponential average smoothing coefficients, \(m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0\), \(\epsilon > 0\).

Then we correct their biases using:

\[ \hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\ \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t} \]

And finally the update step is performed using the following rule:

\[ \theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon} \]

where \(\theta_t\) is the parameter value at step \(t\) (\(\theta_0\) being the initialization value), \(\eta\) is the learning rate, \(\alpha > 0\) and \(\epsilon > 0\).
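
A construction sketch (placeholder module; values are illustrative only):

import torch
from holocron.optim import AdEMAMix

model = torch.nn.Linear(32, 10)  # placeholder module
# beta3 drives the slow EMA of the gradients, alpha weights it in the update
optimizer = AdEMAMix(model.parameters(), lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0)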

PARAMETER DESCRIPTION
params

iterable of parameters to optimize or dicts defining parameter groups

TYPE: Iterable[Parameter]

lr

learning rate

TYPE: float DEFAULT: 0.001

betas

coefficients used for running averages

TYPE: tuple[float, float, float] DEFAULT: (0.9, 0.999, 0.9999)

alpha

the coefficient weighting the slow EMA of the gradients in the update

TYPE: float DEFAULT: 5.0

eps

term added to the denominator to improve numerical stability

TYPE: float DEFAULT: 1e-08

weight_decay

weight decay (L2 penalty)

TYPE: float DEFAULT: 0.0

Source code in holocron/optim/ademamix.py
def __init__(
    self,
    params: Iterable[torch.nn.Parameter],
    lr: float = 1e-3,
    betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
    alpha: float = 5.0,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
) -> None:
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps < 0.0:
        raise ValueError(f"Invalid epsilon value: {eps}")
    for idx, beta in enumerate(betas):
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta parameter at index {idx}: {beta}")
    defaults = {"lr": lr, "betas": betas, "alpha": alpha, "eps": eps, "weight_decay": weight_decay}
    super().__init__(params, defaults)

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

RAISES DESCRIPTION
RuntimeError

if the optimizer does not support sparse gradients

Source code in holocron/optim/ademamix.py
@torch.no_grad()
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        float | None: loss value

    Raises:
        RuntimeError: if the optimizer does not support sparse gradients
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avgs_slow = []
        exp_avg_sqs = []
        state_steps = []

        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["exp_avg_slow"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state["exp_avg"])
                exp_avgs_slow.append(state["exp_avg_slow"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

        beta1, beta2, beta3 = group["betas"]
        ademamix(
            params_with_grad,
            grads,
            exp_avgs,
            exp_avgs_slow,
            exp_avg_sqs,
            state_steps,
            beta1,
            beta2,
            beta3,
            group["alpha"],
            group["lr"],
            group["weight_decay"],
            group["eps"],
        )
    return loss

Optimizer wrappers

holocron.optim also implements optimizer wrappers.

A base optimizer should always be passed to the wrapper; e.g., you should write your code this way:

>>> optimizer = ...
>>> optimizer = wrapper(optimizer)
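
A fuller sketch of one iteration with a wrapped optimizer (module, loss and tensors are placeholders):

import torch
from torch.optim import AdamW
from holocron.optim.wrapper import Lookahead

model = torch.nn.Linear(16, 4)  # placeholder module
criterion = torch.nn.MSELoss()
base = AdamW(model.parameters(), lr=3e-4)
optimizer = Lookahead(base, sync_rate=0.5, sync_period=6)

inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
base.zero_grad()                          # gradients live on the base (fast) optimizer
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()                          # fast update, plus a periodic sync of the slow weights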

Lookahead

Lookahead(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)

Bases: Optimizer

Implements the Lookahead optimizer wrapper from "Lookahead Optimizer: k steps forward, 1 step back" (https://arxiv.org/pdf/1907.08610.pdf).

>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Lookahead
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Lookahead(opt)

PARAMETER DESCRIPTION
base_optimizer

base parameter optimizer

TYPE: Optimizer

sync_rate

rate of weight synchronization

TYPE: float DEFAULT: 0.5

sync_period

number of steps performed on the fast weights before weight synchronization

TYPE: int DEFAULT: 6

Source code in holocron/optim/wrapper.py
def __init__(
    self,
    base_optimizer: torch.optim.Optimizer,
    sync_rate: float = 0.5,
    sync_period: int = 6,
) -> None:
    if sync_rate < 0 or sync_rate > 1:
        raise ValueError(f"expected positive float lower than 1 as sync_rate, received: {sync_rate}")
    if not isinstance(sync_period, int) or sync_period < 1:
        raise ValueError(f"expected positive integer as sync_period, received: {sync_period}")
    # Optimizer attributes
    self.defaults = {"sync_rate": sync_rate, "sync_period": sync_period}
    self.state = defaultdict(dict)
    # Base optimizer attributes
    self.base_optimizer = base_optimizer
    # Wrapper attributes
    self.fast_steps = 0
    self.param_groups = []
    for group in self.base_optimizer.param_groups:
        self._add_param_group(group)

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

Source code in holocron/optim/wrapper.py
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value
    """
    # Update fast params
    loss = self.base_optimizer.step(closure)
    self.fast_steps += 1
    # Synchronization every sync_period steps on fast params
    if self.fast_steps % self.defaults["sync_period"] == 0:
        self.sync_params(self.defaults["sync_rate"])

    return loss

add_param_group

add_param_group(param_group: dict[str, Any]) -> None

Adds a parameter group to base optimizer (fast weights) and its corresponding slow version

PARAMETER DESCRIPTION
param_group

parameter group

TYPE: dict[str, Any]

Source code in holocron/optim/wrapper.py
def add_param_group(self, param_group: dict[str, Any]) -> None:
    """Adds a parameter group to base optimizer (fast weights) and its corresponding slow version

    Args:
        param_group: parameter group
    """
    # Add param group to base optimizer
    self.base_optimizer.add_param_group(param_group)

    # Add the corresponding slow param group
    self._add_param_group(self.base_optimizer.param_groups[-1])

sync_params

sync_params(sync_rate: float = 0.0) -> None

Synchronize parameters as follows: slow_param <- slow_param + sync_rate * (fast_param - slow_param)

PARAMETER DESCRIPTION
sync_rate

synchronization rate of parameters

TYPE: float DEFAULT: 0.0

Source code in holocron/optim/wrapper.py
def sync_params(self, sync_rate: float = 0.0) -> None:
    """Synchronize parameters as follows:
    slow_param <- slow_param + sync_rate * (fast_param - slow_param)

    Args:
        sync_rate: synchronization rate of parameters
    """
    for fast_group, slow_group in zip(self.base_optimizer.param_groups, self.param_groups, strict=True):
        for fast_p, slow_p in zip(fast_group["params"], slow_group["params"], strict=True):
            # Outer update
            if sync_rate > 0:
                slow_p.data.add_(fast_p.data - slow_p.data, alpha=sync_rate)
            # Synchronize fast and slow params
            fast_p.data.copy_(slow_p.data)

Scout

Scout(base_optimizer: Optimizer, sync_rate: float = 0.5, sync_period: int = 6)

Bases: Optimizer

Implements a new optimizer wrapper based on "Lookahead Optimizer: k steps forward, 1 step back".

Example

>>> from torch.optim import AdamW
>>> from holocron.optim.wrapper import Scout
>>> model = ...
>>> opt = AdamW(model.parameters(), lr=3e-4)
>>> opt_wrapper = Scout(opt)

PARAMETER DESCRIPTION
base_optimizer

base parameter optimizer

TYPE: Optimizer

sync_rate

rate of weight synchronization

TYPE: float DEFAULT: 0.5

sync_period

number of steps performed on the fast weights before weight synchronization

TYPE: int DEFAULT: 6

Source code in holocron/optim/wrapper.py
def __init__(
    self,
    base_optimizer: torch.optim.Optimizer,
    sync_rate: float = 0.5,
    sync_period: int = 6,
) -> None:
    if sync_rate < 0 or sync_rate > 1:
        raise ValueError(f"expected positive float lower than 1 as sync_rate, received: {sync_rate}")
    if not isinstance(sync_period, int) or sync_period < 1:
        raise ValueError(f"expected positive integer as sync_period, received: {sync_period}")
    # Optimizer attributes
    self.defaults = {"sync_rate": sync_rate, "sync_period": sync_period}
    self.state = defaultdict(dict)
    # Base optimizer attributes
    self.base_optimizer = base_optimizer
    # Wrapper attributes
    self.fast_steps = 0
    self.param_groups = []
    for group in self.base_optimizer.param_groups:
        self._add_param_group(group)
    # Buffer for scouting
    self.buffer = [p.data.unsqueeze(0) for group in self.param_groups for p in group["params"]]

step

step(closure: Callable[[], float] | None = None) -> float | None

Performs a single optimization step.

PARAMETER DESCRIPTION
closure

A closure that reevaluates the model and returns the loss.

TYPE: Callable[[], float] | None DEFAULT: None

RETURNS DESCRIPTION
float | None

loss value

Source code in holocron/optim/wrapper.py
def step(self, closure: Callable[[], float] | None = None) -> float | None:  # type: ignore[override]
    """Performs a single optimization step.

    Arguments:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        loss value
    """
    # Update fast params
    loss = self.base_optimizer.step(closure)
    self.fast_steps += 1
    # Add it to buffer
    idx = 0
    for group in self.base_optimizer.param_groups:
        for p in group["params"]:
            self.buffer[idx] = torch.cat((self.buffer[idx], p.data.clone().detach().unsqueeze(0)))
            idx += 1
    # Synchronization every sync_period steps on fast params
    if self.fast_steps % self.defaults["sync_period"] == 0:
        # Compute STD of updates
        update_similarity = []
        for _ in range(len(self.buffer)):
            p = self.buffer.pop()
            update = p[1:] - p[:-1]
            max_dev = (update - torch.mean(update, dim=0)).abs().max(dim=0).values
            update_similarity.append((torch.std(update, dim=0) / max_dev).mean().item())
        update_coherence = sum(update_similarity) / len(update_similarity)

        sync_rate = max(1 - update_coherence, self.defaults["sync_rate"])
        # sync_rate = self.defaults['sync_rate']
        self.sync_params(sync_rate)
        # Reset buffer
        self.buffer = []
        for group in self.param_groups:
            for p in group["params"]:
                self.buffer.append(p.data.unsqueeze(0))

    return loss

add_param_group

add_param_group(param_group: dict[str, Any]) -> None

Adds a parameter group to base optimizer (fast weights) and its corresponding slow version

PARAMETER DESCRIPTION
param_group

parameter group

TYPE: dict[str, Any]

Source code in holocron/optim/wrapper.py
def add_param_group(self, param_group: dict[str, Any]) -> None:
    """Adds a parameter group to base optimizer (fast weights) and its corresponding slow version

    Args:
        param_group: parameter group
    """
    # Add param group to base optimizer
    self.base_optimizer.add_param_group(param_group)

    # Add the corresponding slow param group
    self._add_param_group(self.base_optimizer.param_groups[-1])

sync_params

sync_params(sync_rate: float = 0.0) -> None

Synchronize parameters as follows: slow_param <- slow_param + sync_rate * (fast_param - slow_param)

PARAMETER DESCRIPTION
sync_rate

synchronization rate of parameters

TYPE: float DEFAULT: 0.0

Source code in holocron/optim/wrapper.py
def sync_params(self, sync_rate: float = 0.0) -> None:
    """Synchronize parameters as follows:
    slow_param <- slow_param + sync_rate * (fast_param - slow_param)

    Args:
        sync_rate: synchronization rate of parameters
    """
    for fast_group, slow_group in zip(self.base_optimizer.param_groups, self.param_groups, strict=True):
        for fast_p, slow_p in zip(fast_group["params"], slow_group["params"], strict=True):
            # Outer update
            if sync_rate > 0:
                slow_p.data.add_(fast_p.data - slow_p.data, alpha=sync_rate)
            # Synchronize fast and slow params
            fast_p.data.copy_(slow_p.data)