Source code for holocron.optim.lr_scheduler

# -*- coding: utf-8 -*-


"""
Optimization schedulers
"""

import math
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer


# Public names exported on star-import of this module
__all__ = ['OneCycleScheduler']


class OneCycleScheduler(_LRScheduler):
    """Implements the One Cycle scheduler from https://arxiv.org/pdf/1803.09820.pdf

    The schedule is made of two consecutive phases:

    * warmup: the learning rate rises from ``base_ratio * max_lr`` up to ``max_lr`` while the momentum
      (if cycled) drops from ``max_momentum`` down to ``base_momentum``;
    * annealing: the learning rate decays from ``max_lr`` down to ``final_ratio * max_lr`` while the
      momentum climbs back from ``base_momentum`` to ``max_momentum``.

    Each phase can follow either a linear or a half-cosine profile. Once ``total_size`` iterations have
    been performed, the final values are held.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_size (int): Number of training iterations to be performed
        max_lr (float or list): Peak learning rate for each parameter group. A single number is used
            for every group; a list/tuple must hold one value per group. If not specified, the
            current learning rate of each parameter group is taken as its peak value.
        warmup_ratio (float): ratio of iterations used to reach max_lr (between 0 and 1)
        phases (tuple): specify the scaling mode of both phases, as (warmup mode, annealing mode)
            (possible values: 'linear', 'cosine'). Default: ('linear', 'cosine')
        base_ratio (float): ratio between the starting lr of the warmup phase and max_lr
        final_ratio (float): ratio between the ending lr of the last phase and max_lr.
            Default: ``base_ratio * 1e-4``
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_ratio * max_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1

    Raises:
        TypeError: if `optimizer` is not an Optimizer
        ValueError: if `total_size` is not positive, `warmup_ratio` is outside [0, 1], a per-group
            sequence has the wrong length, `phases` is malformed, or `cycle_momentum` is enabled
            with an optimizer that has no 'momentum' default.
    """

    _MODES = ('linear', 'cosine')

    def __init__(self, optimizer, total_size, max_lr=None, warmup_ratio=0.3, phases=None, base_ratio=0.2,
                 final_ratio=None, cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Both values are used as divisors in `get_lr`, reject degenerate settings early
        if total_size <= 0:
            raise ValueError(f"total_size is expected to be positive, got {total_size}")
        if not 0 <= warmup_ratio <= 1:
            raise ValueError(f"warmup_ratio is expected to be in [0, 1], got {warmup_ratio}")

        # Specify max lr
        num_groups = len(optimizer.param_groups)
        if isinstance(max_lr, (list, tuple)):
            if len(max_lr) != num_groups:
                raise ValueError(f"expected {len(optimizer.param_groups)} values for max_lr, got {len(max_lr)}")
            # Copy so that later changes to the caller's sequence do not alter the schedule
            self.max_lrs = list(max_lr)
        elif isinstance(max_lr, (int, float)) and not isinstance(max_lr, bool):
            # Integers are valid learning rates as well (e.g. max_lr=1)
            self.max_lrs = [max_lr] * num_groups
        else:
            # Take current value as max_lr
            self.max_lrs = [group['lr'] for group in optimizer.param_groups]

        # Take the division factor for each phase
        self.base_ratio = base_ratio
        self.final_ratio = base_ratio * 1e-4 if final_ratio is None else final_ratio
        self.total_size = total_size
        self.warmup_ratio = warmup_ratio

        # Phases
        self.phases = tuple(phases) if isinstance(phases, (list, tuple)) else ('linear', 'cosine')
        modes = list(self._MODES)
        if any(phase not in modes for phase in self.phases):
            raise ValueError(f"Phases can only take values from {modes}")
        if len(self.phases) != 2:
            raise ValueError(f"expected 2 values for phases, got {len(self.phases)}")

        # Handle momentum for specific optimizer
        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            # Keep the requested boundaries: when resuming, the optimizer holds an already
            # cycled momentum which must not be mistaken for the lower boundary
            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            if last_epoch == -1:
                for momentum, group in zip(self.base_momentums, optimizer.param_groups):
                    group['momentum'] = momentum

        # The parent constructor performs the initial step, which calls `get_lr`
        super(OneCycleScheduler, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group.

        Args:
            name (str): name of the parameter, used in the error message
            optimizer (Optimizer): optimizer whose parameter groups are matched
            param (float or list): single value or one value per parameter group

        Returns:
            list: one value per parameter group

        Raises:
            ValueError: if a sequence of the wrong length is given
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return list(param)
        return [param] * len(optimizer.param_groups)

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.

        Returns:
            list: learning rate of each parameter group
        """
        step_ratio = self.last_epoch / self.total_size

        # Get phase progress and LR divider for current phase
        if self.warmup_ratio > 0 and step_ratio <= self.warmup_ratio:
            phase_idx = 0
            x = step_ratio / self.warmup_ratio
            base_ratio = self.base_ratio
        else:
            phase_idx = 1
            annealing_ratio = 1 - self.warmup_ratio
            # Without an annealing phase, anything past the warmup is the end of the schedule
            x = (step_ratio - self.warmup_ratio) / annealing_ratio if annealing_ratio > 0 else 1.
            base_ratio = self.final_ratio
        # Hold the boundary values outside of the schedule instead of extrapolating
        x = min(max(x, 0.), 1.)

        # Monotonic progress (0 -> 1) of the current phase according to its mode
        mode = self.phases[phase_idx]
        if mode == 'linear':
            progress = x
        elif mode == 'cosine':
            progress = 0.5 * (1 - math.cos(x * math.pi))
        else:
            raise ValueError(f"Phases can only take values from {list(self._MODES)}")
        # Share of the amplitude added to the phase floor: rises during warmup, decays afterwards
        scale_factor = progress if phase_idx == 0 else 1 - progress

        # Populate LR for each group
        lrs = []
        for max_lr in self.max_lrs:
            base_lr = base_ratio * max_lr
            lrs.append(base_lr + (max_lr - base_lr) * scale_factor)

        # Populate momentum for each group (cycled inversely to the learning rate)
        if self.cycle_momentum:
            for param_group, base_momentum, max_momentum in zip(self.optimizer.param_groups,
                                                                self.base_momentums, self.max_momentums):
                param_group['momentum'] = max_momentum - (max_momentum - base_momentum) * scale_factor

        return lrs

    def __repr__(self):
        return (f"{self.__class__.__name__}(max_lr={max(self.max_lrs)}, warmup_ratio={self.warmup_ratio}, "
                f"base_ratio={self.base_ratio}, final_ratio={self.final_ratio}, phases={self.phases})")