Source code for holocron.optim.wrapper

# Copyright (C) 2019-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from collections import OrderedDict, defaultdict
from typing import Any, Callable, Dict, Optional

import torch
from torch.optim.optimizer import Optimizer

# Public names exported by a star-import of this module
__all__ = [
    "Lookahead",
    "Scout",
]


class Lookahead(Optimizer):
    """Implements the Lookahead optimizer wrapper from `"Lookahead Optimizer: k steps forward, 1 step back"
    <https://arxiv.org/pdf/1907.08610.pdf>`_.

    The wrapped ``base_optimizer`` updates the model parameters (the "fast" weights). This wrapper keeps detached
    copies of them (the "slow" weights) in its own ``param_groups``. Every ``sync_period`` steps, the slow weights
    move towards the fast ones by ``sync_rate``, and the fast weights are then reset to the slow ones.

    Example::
        >>> from torch.optim import AdamW
        >>> from holocron.optim.wrapper import Lookahead
        >>> model = ...
        >>> opt = AdamW(model.parameters(), lr=3e-4)
        >>> opt_wrapper = Lookahead(opt)

    Args:
        base_optimizer (torch.optim.optimizer.Optimizer): base parameter optimizer
        sync_rate (float, optional): rate of weight synchronization, in [0, 1]
        sync_period (int, optional): number of step performed on fast weights before weight synchronization

    Raises:
        ValueError: if ``sync_rate`` is outside [0, 1] or ``sync_period`` is not a positive integer
    """

    def __init__(
        self,
        base_optimizer: torch.optim.Optimizer,
        sync_rate: float = 0.5,
        sync_period: int = 6,
    ) -> None:
        if sync_rate < 0 or sync_rate > 1:
            raise ValueError(f"expected positive float lower than 1 as sync_rate, received: {sync_rate}")
        if not isinstance(sync_period, int) or sync_period < 1:
            raise ValueError(f"expected positive integer as sync_period, received: {sync_period}")
        # ``Optimizer.__init__`` is deliberately not called: it would route the groups through the overridden
        # ``add_param_group`` (which registers them on the base optimizer). Recent torch versions nevertheless
        # read these hook registries in ``state_dict``/``load_state_dict`` and in the patched ``step``, so they
        # are created here (older torch versions simply ignore them).
        for hook_attr in (
            "_optimizer_step_pre_hooks",
            "_optimizer_step_post_hooks",
            "_optimizer_state_dict_pre_hooks",
            "_optimizer_state_dict_post_hooks",
            "_optimizer_load_state_dict_pre_hooks",
            "_optimizer_load_state_dict_post_hooks",
        ):
            setattr(self, hook_attr, OrderedDict())
        # Optimizer attributes
        self.defaults = {"sync_rate": sync_rate, "sync_period": sync_period}
        self.state = defaultdict(dict)
        # Base optimizer attributes
        self.base_optimizer = base_optimizer
        # Wrapper attributes
        self.fast_steps = 0
        self.param_groups = []
        for group in self.base_optimizer.param_groups:
            self._add_param_group(group)

    def __getstate__(self) -> Dict[str, Any]:
        # ``Optimizer.__setstate__`` rebuilds the instance from this dict only, so ``base_optimizer`` must be part
        # of it: otherwise an unpickled / deep-copied wrapper has no optimizer to step.
        return {
            "defaults": self.defaults,
            "state": self.state,
            "base_optimizer": self.base_optimizer,
            "base_state": self.base_optimizer.__getstate__(),
            "fast_steps": self.fast_steps,
            "param_groups": self.param_groups,
        }

    def state_dict(self) -> Dict[str, Any]:
        """Returns the wrapper state (slow weights) with the base optimizer state under ``base_state_dict``."""
        return dict(**super(Lookahead, self).state_dict(), base_state_dict=self.base_optimizer.state_dict())

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads a state produced by ``state_dict``: first the base optimizer, then the wrapper itself.

        Args:
            state_dict (dict): state returned by ``state_dict``
        """
        self.base_optimizer.load_state_dict(state_dict["base_state_dict"])
        super(Lookahead, self).load_state_dict(state_dict)
        # Update last key of class dict (``__setstate__`` merges it into the instance ``__dict__``)
        self.__setstate__({"base_state_dict": self.base_optimizer.state_dict()})

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Resets the gradients of the fast weights through the base optimizer."""
        self.base_optimizer.zero_grad(set_to_none)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            the value returned by the base optimizer's ``step``
        """
        # Update fast params
        loss = self.base_optimizer.step(closure)
        self.fast_steps += 1
        # Synchronization every sync_period steps on fast params
        if self.fast_steps % self.defaults["sync_period"] == 0:
            self.sync_params(self.defaults["sync_rate"])

        return loss

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + " ("
        optimizer_repr = self.base_optimizer.__repr__().replace("\n", "\n\t")
        format_string += f"\nbase_optimizer={optimizer_repr},"
        for arg, val in self.defaults.items():
            format_string += f"\n{arg}={val},"
        format_string += "\n)"
        return format_string

    def _add_param_group(self, param_group: Dict[str, Any]) -> None:
        """Adds a new slow parameter group

        Args:
            param_group (dict): parameter group of base_optimizer
        """
        # Clone & detach params from base optimizer
        group = {"params": [p.clone().detach() for p in param_group["params"]], "lr": param_group["lr"]}
        # Slow weights are only modified manually in ``sync_params``: they never need gradients
        for p in group["params"]:
            p.requires_grad_(False)
        self.param_groups.append(group)

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        """Adds a parameter group to base optimizer (fast weights) and its corresponding slow version

        Args:
            param_group (dict): parameter group
        """
        # Add param group to base optimizer
        self.base_optimizer.add_param_group(param_group)

        # Add the corresponding slow param group (built from the base group, which now has its defaults filled)
        self._add_param_group(self.base_optimizer.param_groups[-1])

    def sync_params(self, sync_rate: float = 0.0) -> None:
        """Synchronize parameters as follows:
        slow_param <- slow_param + sync_rate * (fast_param - slow_param)
        then fast_param <- slow_param

        Args:
            sync_rate (float): synchronization rate of parameters
        """
        for fast_group, slow_group in zip(self.base_optimizer.param_groups, self.param_groups):
            for fast_p, slow_p in zip(fast_group["params"], slow_group["params"]):
                # Outer update
                if sync_rate > 0:
                    slow_p.data.add_(fast_p.data - slow_p.data, alpha=sync_rate)
                # Synchronize fast and slow params
                fast_p.data.copy_(slow_p.data)
class Scout(Optimizer):
    """Implements a new optimizer wrapper based on `"Lookahead Optimizer: k steps forward, 1 step back"
    <https://arxiv.org/pdf/1907.08610.pdf>`_.

    As in Lookahead, the wrapped ``base_optimizer`` updates the "fast" weights while this wrapper keeps detached
    "slow" copies in its own ``param_groups``. In addition, the fast weights of every step are recorded in
    ``buffer``, so up to ``sync_period + 1`` copies of every parameter are kept. Every ``sync_period`` steps, the
    synchronization rate is derived from the coherence of the recorded updates as
    ``max(1 - update_coherence, sync_rate)``, so ``sync_rate`` acts as a lower bound.

    Example::
        >>> from torch.optim import AdamW
        >>> from holocron.optim.wrapper import Scout
        >>> model = ...
        >>> opt = AdamW(model.parameters(), lr=3e-4)
        >>> opt_wrapper = Scout(opt)

    Args:
        base_optimizer (torch.optim.optimizer.Optimizer): base parameter optimizer
        sync_rate (float, optional): minimal rate of weight synchronization, in [0, 1]
        sync_period (int, optional): number of step performed on fast weights before weight synchronization

    Raises:
        ValueError: if ``sync_rate`` is outside [0, 1] or ``sync_period`` is not a positive integer
    """

    def __init__(
        self,
        base_optimizer: torch.optim.Optimizer,
        sync_rate: float = 0.5,
        sync_period: int = 6,
    ) -> None:
        if sync_rate < 0 or sync_rate > 1:
            raise ValueError(f"expected positive float lower than 1 as sync_rate, received: {sync_rate}")
        if not isinstance(sync_period, int) or sync_period < 1:
            raise ValueError(f"expected positive integer as sync_period, received: {sync_period}")
        # ``Optimizer.__init__`` is deliberately not called: it would route the groups through the overridden
        # ``add_param_group`` (which registers them on the base optimizer). Recent torch versions nevertheless
        # read these hook registries in ``state_dict``/``load_state_dict`` and in the patched ``step``, so they
        # are created here (older torch versions simply ignore them).
        for hook_attr in (
            "_optimizer_step_pre_hooks",
            "_optimizer_step_post_hooks",
            "_optimizer_state_dict_pre_hooks",
            "_optimizer_state_dict_post_hooks",
            "_optimizer_load_state_dict_pre_hooks",
            "_optimizer_load_state_dict_post_hooks",
        ):
            setattr(self, hook_attr, OrderedDict())
        # Optimizer attributes
        self.defaults = {"sync_rate": sync_rate, "sync_period": sync_period}
        self.state = defaultdict(dict)
        # Base optimizer attributes
        self.base_optimizer = base_optimizer
        # Wrapper attributes
        self.fast_steps = 0
        self.param_groups = []
        for group in self.base_optimizer.param_groups:
            self._add_param_group(group)
        # Buffer for scouting: one history tensor per parameter, in the flat order of ``param_groups``
        self._reset_buffer()

    def __getstate__(self) -> Dict[str, Any]:
        # ``Optimizer.__setstate__`` rebuilds the instance from this dict only, so ``base_optimizer`` and the
        # scouting ``buffer`` must be part of it for an unpickled / deep-copied wrapper to keep stepping.
        return {
            "defaults": self.defaults,
            "state": self.state,
            "base_optimizer": self.base_optimizer,
            "base_state": self.base_optimizer.__getstate__(),
            "fast_steps": self.fast_steps,
            "param_groups": self.param_groups,
            "buffer": self.buffer,
        }

    def state_dict(self) -> Dict[str, Any]:
        """Returns the wrapper state (slow weights) with the base optimizer state under ``base_state_dict``."""
        return dict(**super(Scout, self).state_dict(), base_state_dict=self.base_optimizer.state_dict())

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads a state produced by ``state_dict``: first the base optimizer, then the wrapper itself.

        Args:
            state_dict (dict): state returned by ``state_dict``
        """
        self.base_optimizer.load_state_dict(state_dict["base_state_dict"])
        super(Scout, self).load_state_dict(state_dict)
        # Update last key of class dict (``__setstate__`` merges it into the instance ``__dict__``)
        self.__setstate__({"base_state_dict": self.base_optimizer.state_dict()})

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Resets the gradients of the fast weights through the base optimizer."""
        self.base_optimizer.zero_grad(set_to_none)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            the value returned by the base optimizer's ``step``
        """
        # Update fast params
        loss = self.base_optimizer.step(closure)
        self.fast_steps += 1
        # Add the new fast weights to the buffer (same flat order as ``param_groups``)
        fast_params = (p for group in self.base_optimizer.param_groups for p in group["params"])
        for idx, p in enumerate(fast_params):
            self.buffer[idx] = torch.cat((self.buffer[idx], p.data.clone().detach().unsqueeze(0)))

        # Synchronization every sync_period steps on fast params
        if self.fast_steps % self.defaults["sync_period"] == 0:
            self.sync_params(self._adaptive_sync_rate())
            self._reset_buffer()

        return loss

    def _adaptive_sync_rate(self) -> float:
        """Derives the synchronization rate from the coherence of the buffered fast-weight updates.

        For each parameter, the per-coordinate ratio between the std of its successive updates and their maximal
        absolute deviation from the mean update is averaged. Undefined ratios (0/0 when a coordinate's updates
        never deviate) are ignored, as are parameters with fewer than two recorded updates. Without them, a single
        frozen coordinate would turn the rate into NaN and silently disable the outer update.

        Returns:
            ``max(1 - update_coherence, sync_rate)``, or the default ``sync_rate`` if no ratio is defined
        """
        update_similarity = []
        for history in self.buffer:
            # Two updates (three recorded weights) are needed for an unbiased std
            if history.shape[0] < 3:
                continue
            update = history[1:] - history[:-1]
            max_dev = (update - torch.mean(update, dim=0)).abs().max(dim=0).values
            ratio = (torch.std(update, dim=0) / max_dev).flatten()
            ratio = ratio[torch.isfinite(ratio)]
            if ratio.numel() > 0:
                update_similarity.append(ratio.mean().item())

        if len(update_similarity) == 0:
            return self.defaults["sync_rate"]
        update_coherence = sum(update_similarity) / len(update_similarity)
        return max(1 - update_coherence, self.defaults["sync_rate"])

    def _reset_buffer(self) -> None:
        """Restarts the history of every parameter from its current slow weights."""
        self.buffer = [p.data.unsqueeze(0) for group in self.param_groups for p in group["params"]]

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + " ("
        optimizer_repr = self.base_optimizer.__repr__().replace("\n", "\n\t")
        format_string += f"\nbase_optimizer={optimizer_repr},"
        for arg, val in self.defaults.items():
            format_string += f"\n{arg}={val},"
        format_string += "\n)"
        return format_string

    def _add_param_group(self, param_group: Dict[str, Any]) -> None:
        """Adds a new slow parameter group

        Args:
            param_group (dict): parameter group of base_optimizer
        """
        # Clone & detach params from base optimizer
        group = {"params": [p.clone().detach() for p in param_group["params"]], "lr": param_group["lr"]}
        # Slow weights are only modified manually in ``sync_params``: they never need gradients
        for p in group["params"]:
            p.requires_grad_(False)
        self.param_groups.append(group)

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        """Adds a parameter group to base optimizer (fast weights) and its corresponding slow version

        Args:
            param_group (dict): parameter group
        """
        # Add param group to base optimizer
        self.base_optimizer.add_param_group(param_group)

        # Add the corresponding slow param group (built from the base group, which now has its defaults filled)
        self._add_param_group(self.base_optimizer.param_groups[-1])
        # Track the new params in the buffer, otherwise ``step`` would index past its end
        self.buffer.extend(p.data.unsqueeze(0) for p in self.param_groups[-1]["params"])

    def sync_params(self, sync_rate: float = 0.0) -> None:
        """Synchronize parameters as follows:
        slow_param <- slow_param + sync_rate * (fast_param - slow_param)
        then fast_param <- slow_param

        Args:
            sync_rate (float): synchronization rate of parameters
        """
        for fast_group, slow_group in zip(self.base_optimizer.param_groups, self.param_groups):
            for fast_p, slow_p in zip(fast_group["params"], slow_group["params"]):
                # Outer update
                if sync_rate > 0:
                    slow_p.data.add_(fast_p.data - slow_p.data, alpha=sync_rate)
                # Synchronize fast and slow params
                fast_p.data.copy_(slow_p.data)