Source code for holocron.optim.wrapper

# -*- coding: utf-8 -*-

'''
Lookahead optimizer wrapper
'''

import torch
from collections import defaultdict
from torch.optim.optimizer import Optimizer


__all__ = ['Lookahead', 'Scout']


[docs] class Lookahead(Optimizer): """Implements the Lookahead optimizer wrapper from `"Lookahead Optimizer: k steps forward, 1 step back" <https://arxiv.org/pdf/1907.08610.pdf>`_. Args: base_optimizer (torch.optim.optimizer.Optimizer): base parameter optimizer sync_rate (int, optional): rate of weight synchronization sync_period (int, optional): number of step performed on fast weights before weight synchronization """ def __init__(self, base_optimizer, sync_rate=0.5, sync_period=6): if sync_rate < 0 or sync_rate > 1: raise ValueError(f'expected positive float lower than 1 as sync_rate, received: {sync_rate}') if not isinstance(sync_period, int) or sync_period < 1: raise ValueError(f'expected positive integer as sync_period, received: {sync_period}') # Optimizer attributes self.defaults = dict(sync_rate=sync_rate, sync_period=sync_period) self.state = defaultdict(dict) # Base optimizer attributes self.base_optimizer = base_optimizer # Wrapper attributes self.fast_steps = 0 self.param_groups = [] for group in self.base_optimizer.param_groups: self._add_param_group(group) def __getstate__(self): return { 'defaults': self.defaults, 'state': self.state, 'base_state': self.base_optimizer.__getstate__(), 'fast_steps': self.fast_steps, 'param_groups': self.param_groups } def state_dict(self): return dict(**super(Lookahead, self).state_dict(), base_state_dict=self.base_optimizer.state_dict()) def load_state_dict(self, state_dict): self.base_optimizer.load_state_dict(state_dict['base_state_dict']) super(Lookahead, self).load_state_dict(state_dict) # Update last key of class dict self.__setstate__({'base_state_dict': self.base_optimizer.state_dict()}) def zero_grad(self): self.base_optimizer.zero_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ # Update fast params loss = self.base_optimizer.step(closure) self.fast_steps += 1 # Synchronization every sync_period steps on fast params if self.fast_steps % self.defaults['sync_period'] == 0: self.sync_params(self.defaults['sync_rate']) return loss def __repr__(self): format_string = self.__class__.__name__ + ' (' optimizer_repr = self.base_optimizer.__repr__().replace('\n', '\n\t') format_string += f"\nbase_optimizer={optimizer_repr}," for arg, val in self.defaults.items(): format_string += f"\n{arg}={val}," format_string += '\n)' return format_string def _add_param_group(self, param_group): """Adds a new slow parameter group Args: param_group (dict): parameter group of base_optimizer """ # Clone & detach params from base optimizer group = dict(params=[p.clone().detach() for p in param_group['params']], lr=param_group['lr']) # Uneeded grads for p in group['params']: p.reguires_grad = False self.param_groups.append(group) def add_param_group(self, param_group): """Adds a parameter group to base optimizer (fast weights) and its corresponding slow version Args: param_group (dict): parameter group """ # Add param group to base optimizer self.base_optimizer.add_param_group(param_group) # Add the corresponding slow param group self._add_param_group(self.base_optimizer.param_groups[-1]) def sync_params(self, sync_rate=0): """Synchronize parameters as follows: slow_param <- slow_param + sync_rate * (fast_param - slow_param) Args: sync_rate (float): synchronization rate of parameters """ for fast_group, slow_group in zip(self.base_optimizer.param_groups, self.param_groups): for fast_p, slow_p in zip(fast_group['params'], slow_group['params']): # Outer update if sync_rate > 0: slow_p.data.add_(sync_rate, fast_p.data - slow_p.data) # Synchronize fast and slow params fast_p.data.copy_(slow_p.data)
[docs] class Scout(Optimizer): """Implements a new optimizer wrapper based on `"Lookahead Optimizer: k steps forward, 1 step back" <https://arxiv.org/pdf/1907.08610.pdf>`_. Args: base_optimizer (torch.optim.optimizer.Optimizer): base parameter optimizer sync_rate (int, optional): rate of weight synchronization sync_period (int, optional): number of step performed on fast weights before weight synchronization """ def __init__(self, base_optimizer, sync_rate=0.5, sync_period=6): if sync_rate < 0 or sync_rate > 1: raise ValueError(f'expected positive float lower than 1 as sync_rate, received: {sync_rate}') if not isinstance(sync_period, int) or sync_period < 1: raise ValueError(f'expected positive integer as sync_period, received: {sync_period}') # Optimizer attributes self.defaults = dict(sync_rate=sync_rate, sync_period=sync_period) self.state = defaultdict(dict) # Base optimizer attributes self.base_optimizer = base_optimizer # Wrapper attributes self.fast_steps = 0 self.param_groups = [] for group in self.base_optimizer.param_groups: self._add_param_group(group) # Buffer for scouting self.buffer = [] for group in self.param_groups: for p in group['params']: self.buffer.append(p.data.unsqueeze(0)) def __getstate__(self): return { 'defaults': self.defaults, 'state': self.state, 'base_state': self.base_optimizer.__getstate__(), 'fast_steps': self.fast_steps, 'param_groups': self.param_groups } def state_dict(self): return dict(**super(Scout, self).state_dict(), base_state_dict=self.base_optimizer.state_dict()) def load_state_dict(self, state_dict): self.base_optimizer.load_state_dict(state_dict['base_state_dict']) super(Scout, self).load_state_dict(state_dict) # Update last key of class dict self.__setstate__({'base_state_dict': self.base_optimizer.state_dict()}) def zero_grad(self): self.base_optimizer.zero_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ # Update fast params loss = self.base_optimizer.step(closure) self.fast_steps += 1 # Add it to buffer idx = 0 for group in self.base_optimizer.param_groups: for p in group['params']: self.buffer[idx] = torch.cat((self.buffer[idx], p.data.clone().detach().unsqueeze(0))) idx += 1 # Synchronization every sync_period steps on fast params if self.fast_steps % self.defaults['sync_period'] == 0: # Compute STD of updates update_similarity = [] for _ in range(len(self.buffer)): p = self.buffer.pop() update = p[1:] - p[:-1] max_dev = (update - torch.mean(update, dim=0)).abs().max(dim=0).values update_similarity.append((torch.std(update, dim=0) / max_dev).mean().item()) update_coherence = sum(std_list) / len(std_list) sync_rate = max(1 - update_coherence, self.defaults['sync_rate']) # sync_rate = self.defaults['sync_rate'] self.sync_params(sync_rate) # Reset buffer self.buffer = [] for group in self.param_groups: for p in group['params']: self.buffer.append(p.data.unsqueeze(0)) return loss def __repr__(self): format_string = self.__class__.__name__ + ' (' optimizer_repr = self.base_optimizer.__repr__().replace('\n', '\n\t') format_string += f"\nbase_optimizer={optimizer_repr}," for arg, val in self.defaults.items(): format_string += f"\n{arg}={val}," format_string += '\n)' return format_string def _add_param_group(self, param_group): """Adds a new slow parameter group Args: param_group (dict): parameter group of base_optimizer """ # Clone & detach params from base optimizer group = dict(params=[p.clone().detach() for p in param_group['params']], lr=param_group['lr']) # Uneeded grads for p in group['params']: p.reguires_grad = False self.param_groups.append(group) def add_param_group(self, param_group): """Adds a parameter group to base optimizer (fast weights) and its corresponding slow version Args: param_group (dict): parameter group """ # Add param group to base optimizer self.base_optimizer.add_param_group(param_group) # Add the corresponding slow param group self._add_param_group(self.base_optimizer.param_groups[-1]) def sync_params(self, sync_rate=0): """Synchronize parameters as follows: slow_param <- slow_param + sync_rate * (fast_param - slow_param) Args: sync_rate (float): synchronization rate of parameters """ for fast_group, slow_group in zip(self.base_optimizer.param_groups, self.param_groups): for fast_p, slow_p in zip(fast_group['params'], slow_group['params']): # Outer update if sync_rate > 0: slow_p.data.add_(sync_rate, fast_p.data - slow_p.data) # Synchronize fast and slow params fast_p.data.copy_(slow_p.data)