# Source code for holocron.trainer.core

# Copyright (C) 2019-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import math
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import ConsoleMasterBar
from torch import Tensor, nn
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler, MultiplicativeLR, OneCycleLR
from torch.utils.data import DataLoader

from .utils import freeze_bn, freeze_model, split_normalization_params

# Alias for an ordered collection of model parameters (e.g. one optimizer param group)
ParamSeq = Sequence[nn.Parameter]

# Public API of this module
__all__ = ["Trainer"]


class Trainer:
    """Baseline trainer class.

    Args:
        model: model to train
        train_loader: training loader
        val_loader: validation loader
        criterion: loss criterion
        optimizer: parameter optimizer
        gpu: index of the GPU to use
        output_file: path where checkpoints will be saved
        amp: whether to use automatic mixed precision
        skip_nan_loss: whether the optimizer step should be skipped when the loss is NaN or infinite
        nan_tolerance: number of consecutive batches with a skipped (NaN/inf) loss tolerated before
            stopping the training (only used when `skip_nan_loss` is set)
        gradient_acc: number of batches to accumulate the gradient of before performing the update step
        gradient_clip: maximum total norm of the gradients (via `clip_grad_norm_`), no clipping if None
        on_epoch_end: callback triggered at the end of an epoch with the evaluation metrics
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,  # type: ignore[name-defined]
        gpu: Optional[int] = None,
        output_file: str = "./checkpoint.pth",
        amp: bool = False,
        skip_nan_loss: bool = False,
        nan_tolerance: int = 5,
        gradient_acc: int = 1,
        gradient_clip: Optional[float] = None,
        on_epoch_end: Optional[Callable[[Dict[str, float]], Any]] = None,
    ) -> None:
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.amp = amp
        # Only instantiated by the training methods when AMP is enabled
        self.scaler: torch.cuda.amp.grad_scaler.GradScaler
        self.on_epoch_end = on_epoch_end
        self.skip_nan_loss = skip_nan_loss
        self.nan_tolerance = nan_tolerance
        self.gradient_acc = gradient_acc
        self.grad_clip = gradient_clip

        # Output file
        self.output_file = output_file

        # Initialize
        self.step = 0
        self.start_epoch = 0
        self.epoch = 0
        # Number of backward passes accumulated since the last optimizer step
        self._grad_count = 0
        self.min_loss = math.inf
        self.gpu = gpu
        self._params: Tuple[ParamSeq, ParamSeq] = ([], [])
        self.lr_recorder: List[float] = []
        self.loss_recorder: List[float] = []
        self.set_device(gpu)
        self._reset_opt(self.optimizer.defaults["lr"])

    def set_device(self, gpu: Optional[int] = None) -> None:
        """Move the model (and the criterion if it is an `nn.Module`) to the target GPU.

        Args:
            gpu: index of the target GPU device, nothing is moved if it is not an int

        Raises:
            AssertionError: if CUDA is not available
            ValueError: if the device index is out of range
        """
        if isinstance(gpu, int):
            if not torch.cuda.is_available():
                raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
            if gpu >= torch.cuda.device_count():
                raise ValueError("Invalid device index")
            torch.cuda.set_device(gpu)
            self.model = self.model.cuda()
            if isinstance(self.criterion, torch.nn.Module):
                self.criterion = self.criterion.cuda()

    def save(self, output_file: str) -> None:
        """Save a trainer checkpoint (epoch, step, min_loss and the model state dict).

        Args:
            output_file: destination file path
        """
        torch.save(
            {
                "epoch": self.epoch,
                "step": self.step,
                "min_loss": self.min_loss,
                "model": self.model.state_dict(),
            },
            output_file,
            _use_new_zipfile_serialization=False,
        )

    def load(self, state: Dict[str, Any]) -> None:
        """Resume from a trainer state.

        Args:
            state (dict): checkpoint dictionary, as written by `save`
        """
        self.start_epoch = state["epoch"]
        self.epoch = self.start_epoch
        self.step = state["step"]
        self.min_loss = state["min_loss"]
        self.model.load_state_dict(state["model"])

    def _fit_epoch(self, mb: "ConsoleMasterBar") -> None:
        """Fit a single epoch.

        Args:
            mb (fastprogress.master_bar): primary progress bar

        Raises:
            ValueError: if more than `nan_tolerance` consecutive losses were skipped for being NaN/inf
        """
        freeze_bn(self.model.train())

        nan_cnt = 0
        pb = progress_bar(self.train_loader, parent=mb)
        for x, target in pb:
            x, target = self.to_cuda(x, target)

            # Forward
            batch_loss: Tensor = self._get_loss(x, target)  # type: ignore[assignment]

            # Backprop (non-finite losses are only skipped when `skip_nan_loss` is set)
            if not self.skip_nan_loss or torch.isfinite(batch_loss):
                nan_cnt = 0
                self._backprop_step(batch_loss)
            else:
                nan_cnt += 1
                if nan_cnt > self.nan_tolerance:
                    raise ValueError(f"loss value has been NaN or inf for more than {self.nan_tolerance} steps.")

            # Update LR (the scheduler is stepped once per batch)
            self.scheduler.step()
            pb.comment = f"Training loss: {batch_loss.item():.4}"

            self.step += 1
        self.epoch += 1

    def to_cuda(
        self, x: Tensor, target: Union[Tensor, List[Dict[str, Tensor]]]
    ) -> Tuple[Tensor, Union[Tensor, List[Dict[str, Tensor]]]]:
        """Move input and target to GPU (no-op when no GPU index was given)."""
        if isinstance(self.gpu, int):
            if self.gpu >= torch.cuda.device_count():
                raise ValueError("Invalid device index")
            return self._to_cuda(x, target)  # type: ignore[arg-type]
        return x, target

    @staticmethod
    def _to_cuda(x: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Move input and target tensors to the current CUDA device (non-blocking copies)."""
        x = x.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        return x, target

    def _backprop_step(self, loss: Tensor) -> None:
        """Backpropagate the loss and update the parameters once enough gradients are accumulated.

        Args:
            loss: loss of the current batch
        """
        self._grad_count += 1
        # Backprop
        if self.amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        # Only update the parameters every `gradient_acc` batches
        if self._grad_count < self.gradient_acc:
            return

        # Safeguard for gradient explosion (`is not None` so that integer clip values are honoured too)
        if self.grad_clip is not None:
            if self.amp:
                # Gradients must be unscaled before their norm is compared to the clip value
                self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)

        if self.amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()
        self._grad_count = 0

    def _get_loss(self, x: Tensor, target: Tensor, return_logits: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Run the forward pass and compute the loss.

        Args:
            x: model input
            target: expected target, passed as is to the criterion
            return_logits: whether the model output should be returned alongside the loss

        Returns:
            the loss, or a `(loss, model_output)` tuple if `return_logits` is set
        """
        if self.amp:
            # Forward & loss computation under autocast
            with torch.cuda.amp.autocast():
                out = self.model(x)
                loss = cast(Tensor, self.criterion(out, target))
        else:
            out = self.model(x)
            loss = cast(Tensor, self.criterion(out, target))
        if return_logits:
            return loss, out
        return loss

    def _set_params(self, norm_weight_decay: Optional[float] = None) -> None:
        """Select the parameters to optimize.

        Args:
            norm_weight_decay: if not None, parameters are split with `split_normalization_params`

        Raises:
            AssertionError: if no parameter requires a gradient
        """
        if not any(p.requires_grad for p in self.model.parameters()):
            raise AssertionError("All parameters are frozen")

        if norm_weight_decay is None:
            self._params = [p for p in self.model.parameters() if p.requires_grad], []
        else:
            self._params = split_normalization_params(self.model)

    def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
        """Reset the optimizer state and rebuild its param groups with the given learning rate.

        Args:
            lr: learning rate used as the optimizer default for the new param groups
            norm_weight_decay: weight decay of the first group returned by `split_normalization_params`
        """
        self.optimizer.defaults["lr"] = lr
        self.optimizer.state = defaultdict(dict)
        self.optimizer.param_groups = []
        self._set_params(norm_weight_decay)
        # Split the params if normalization layers need a custom weight decay
        if norm_weight_decay is None:
            self.optimizer.add_param_group({"params": self._params[0]})
        else:
            # NOTE(review): assumes split_normalization_params returns (norm params, other params)
            wd_groups = [norm_weight_decay, self.optimizer.defaults.get("weight_decay", 0)]
            for _params, _wd in zip(self._params, wd_groups):
                if len(_params) > 0:
                    self.optimizer.add_param_group({"params": _params, "weight_decay": _wd})

        self.optimizer.zero_grad()
        # Accumulated gradients were just discarded, so the accumulation counter restarts as well
        self._grad_count = 0

    @torch.inference_mode()
    def evaluate(self):  # type: ignore[no-untyped-def]  # noqa: ANN201
        """Evaluate the model on the validation set (to be implemented by subclasses)."""
        raise NotImplementedError

    @staticmethod
    def _eval_metrics_str(eval_metrics) -> str:  # type: ignore[no-untyped-def]  # noqa: ANN001
        """Format evaluation metrics for display (to be implemented by subclasses)."""
        raise NotImplementedError

    def _reset_scheduler(self, lr: float, num_epochs: int, sched_type: str = "onecycle", **kwargs: Any) -> None:
        """Create a per-batch LR scheduler spanning `num_epochs` epochs.

        Args:
            lr: maximum learning rate (only used by the "onecycle" scheduler)
            num_epochs: number of epochs covered by the schedule
            sched_type: either "onecycle" or "cosine"
            **kwargs: keyword args passed to the scheduler

        Raises:
            ValueError: if the scheduler type is not supported
        """
        self.scheduler: LRScheduler
        if sched_type == "onecycle":
            self.scheduler = OneCycleLR(self.optimizer, lr, num_epochs * len(self.train_loader), **kwargs)
        elif sched_type == "cosine":
            self.scheduler = CosineAnnealingLR(self.optimizer, num_epochs * len(self.train_loader), **kwargs)
        else:
            raise ValueError(f"The following scheduler type is not supported: {sched_type}")

    def fit_n_epochs(
        self,
        num_epochs: int,
        lr: float,
        freeze_until: Optional[str] = None,
        sched_type: str = "onecycle",
        norm_weight_decay: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Train the model for a given number of epochs.

        A checkpoint is saved to `output_file` every time the validation loss improves.

        Args:
            num_epochs (int): number of epochs to train
            lr (float): learning rate to be used by the scheduler
            freeze_until (str, optional): last layer to freeze
            sched_type (str, optional): type of scheduler to use
            norm_weight_decay (float, optional): weight decay to apply to normalization parameters
            **kwargs: keyword args passed to the schedulers
        """
        freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(lr, norm_weight_decay)
        # Scheduler
        self._reset_scheduler(lr, num_epochs, sched_type, **kwargs)

        if self.amp:
            self.scaler = torch.cuda.amp.GradScaler()

        mb = master_bar(range(num_epochs))
        for _ in mb:
            self._fit_epoch(mb)
            eval_metrics = self.evaluate()

            # master bar
            mb.main_bar.comment = f"Epoch {self.epoch}/{self.start_epoch + num_epochs}"
            mb.write(f"Epoch {self.epoch}/{self.start_epoch + num_epochs} - {self._eval_metrics_str(eval_metrics)}")

            if eval_metrics["val_loss"] < self.min_loss:
                print(  # noqa: T201
                    f"Validation loss decreased {self.min_loss:.4} --> "
                    f"{eval_metrics['val_loss']:.4}: saving state..."
                )
                self.min_loss = eval_metrics["val_loss"]
                self.save(self.output_file)

            if self.on_epoch_end is not None:
                self.on_epoch_end(eval_metrics)

    def find_lr(
        self,
        freeze_until: Optional[str] = None,
        start_lr: float = 1e-7,
        end_lr: float = 1,
        norm_weight_decay: Optional[float] = None,
        num_it: int = 100,
    ) -> None:
        """Gridsearch the optimal learning rate for the training as described in
        `"Cyclical Learning Rates for Training Neural Networks" <https://arxiv.org/pdf/1506.01186.pdf>`_.

        The learning rate grows geometrically from `start_lr` to `end_lr` over `num_it` batches, and the
        search stops early on a NaN/inf loss. Results are stored in `lr_recorder` and `loss_recorder`.
        Note that the model weights are updated during the search.

        Args:
            freeze_until (str, optional): last layer to freeze
            start_lr (float, optional): initial learning rate
            end_lr (float, optional): final learning rate
            norm_weight_decay (float, optional): weight decay to apply to normalization parameters
            num_it (int, optional): number of iterations to perform (at least 2, at most the number of batches)

        Raises:
            ValueError: if `num_it` is out of range, or if the loss of the first batch is NaN/inf
        """
        # The geometric ratio below is undefined for fewer than 2 iterations
        if num_it < 2:
            raise ValueError("the value of `num_it` needs to be at least 2")
        if num_it > len(self.train_loader):
            raise ValueError("the value of `num_it` needs to be lower than the number of available batches")

        freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(start_lr, norm_weight_decay)
        gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
        scheduler = MultiplicativeLR(self.optimizer, lambda step: gamma)

        self.lr_recorder = [start_lr * gamma**idx for idx in range(num_it)]
        self.loss_recorder = []

        if self.amp:
            self.scaler = torch.cuda.amp.GradScaler()

        for batch_idx, (x, target) in enumerate(self.train_loader):
            x, target = self.to_cuda(x, target)

            # Forward
            batch_loss: Tensor = self._get_loss(x, target)  # type: ignore[assignment]
            # Check the loss before backprop so that a diverged loss doesn't corrupt the weights
            if torch.isnan(batch_loss) or torch.isinf(batch_loss):
                if batch_idx == 0:
                    raise ValueError("loss value is NaN or inf.")
                break
            self._backprop_step(batch_loss)
            # Update LR
            scheduler.step()

            # Record
            self.loss_recorder.append(batch_loss.item())
            # Stop after the number of iterations
            if batch_idx + 1 == num_it:
                break

        # Only keep the learning rates for which a loss was recorded
        self.lr_recorder = self.lr_recorder[: len(self.loss_recorder)]

    def plot_recorder(self, beta: float = 0.95, **kwargs: Any) -> None:
        """Display the results of the LR grid search.

        Args:
            beta (float, optional): smoothing factor of the exponential moving average of the loss
            kwargs: keyword args of matplotlib.pyplot.show

        Raises:
            AssertionError: if `find_lr` has not been run beforehand
        """
        if len(self.lr_recorder) != len(self.loss_recorder) or len(self.lr_recorder) == 0:
            raise AssertionError("Please run the `find_lr` method first")

        # Exp moving average of loss, with bias correction for the first steps
        smoothed_losses = []
        avg_loss = 0.0
        for idx, loss in enumerate(self.loss_recorder):
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1)))

        # Properly rescale Y-axis by dropping the noisy head & tail of the curve
        num_points = len(self.loss_recorder)
        data_slice = slice(
            min(num_points // 10, 10),
            -min(num_points // 20, 5) if num_points >= 20 else num_points,
        )
        vals: np.ndarray = np.array(smoothed_losses[data_slice])
        min_idx = int(vals.argmin())
        # Upper bound: highest smoothed loss observed up to the minimum
        max_val = vals[: min_idx + 1].max()
        delta = max_val - vals[min_idx]

        plt.plot(self.lr_recorder[data_slice], smoothed_losses[data_slice])
        plt.xscale("log")
        plt.xlabel("Learning Rate")
        plt.ylabel("Training loss")
        # Identical limits would collapse the Y-axis, keep autoscaling in that case
        if delta > 0:
            plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta)
        plt.grid(True, linestyle="--", axis="x")
        plt.show(**kwargs)

    def check_setup(
        self,
        freeze_until: Optional[str] = None,
        lr: float = 3e-4,
        norm_weight_decay: Optional[float] = None,
        num_it: int = 100,
        **kwargs: Any,
    ) -> None:
        """Check whether you can overfit one batch.

        Args:
            freeze_until (str, optional): last layer to freeze
            lr (float, optional): learning rate to be used for training
            norm_weight_decay (float, optional): weight decay to apply to normalization parameters
            num_it (int, optional): number of iterations to perform
            kwargs: keyword args of matplotlib.pyplot.show

        Raises:
            ValueError: if the loss becomes NaN or inf
        """
        freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(lr, norm_weight_decay)

        # The same (first) batch is reused for every iteration
        x, target = next(iter(self.train_loader))
        x, target = self.to_cuda(x, target)

        _losses = []

        if self.amp:
            self.scaler = torch.cuda.amp.GradScaler()

        for _ in range(num_it):
            # Forward
            batch_loss: Tensor = self._get_loss(x, target)  # type: ignore[assignment]
            # Check the loss before backprop so that a diverged loss doesn't corrupt the weights
            if torch.isnan(batch_loss) or torch.isinf(batch_loss):
                raise ValueError("loss value is NaN or inf.")
            # Backprop
            self._backprop_step(batch_loss)

            _losses.append(batch_loss.item())

        plt.plot(np.arange(len(_losses)), _losses)
        plt.xlabel("Optimization steps")
        plt.ylabel("Training loss")
        plt.grid(True, linestyle="--", axis="x")
        plt.show(**kwargs)