# Copyright (C) 2019-2022, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
import math
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import ConsoleMasterBar
from torch import Tensor, nn
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR # type: ignore[attr-defined]
from torch.utils.data import DataLoader
from .utils import freeze_bn, freeze_model, split_normalization_params
ParamSeq = Sequence[torch.nn.Parameter] # type: ignore[name-defined]
__all__ = ["Trainer"]
class Trainer:
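"""Generic training wrapper around a model, its dataloaders, a loss criterion and an optimizer.

Args:
    model: model to train
    train_loader: training dataloader
    val_loader: validation dataloader
    criterion: loss criterion
    optimizer: parameter optimizer
    gpu: index of the GPU to use (None to leave the model on its current device)
    output_file: path where the best checkpoint is saved
    amp: whether to use automatic mixed precision
    skip_nan_loss: whether to skip the backpropagation step when the loss is NaN or inf
    nan_tolerance: number of consecutive NaN/inf losses tolerated before an error is raised
    on_epoch_end: callback invoked with the evaluation metrics at the end of each epoch

Example:
    A minimal sketch; ``model``, ``train_loader`` and ``val_loader`` are assumed to be
    built elsewhere, and a task-specific subclass implementing ``evaluate`` is required
    for a full ``fit_n_epochs`` run::

        >>> import torch
        >>> from torch import nn
        >>> criterion = nn.CrossEntropyLoss()
        >>> optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
        >>> trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, gpu=0, amp=True)
        >>> trainer.check_setup(lr=3e-4)   # overfit a single batch as a sanity check
        >>> trainer.find_lr(num_it=100)    # LR range test
        >>> trainer.plot_recorder()
"""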
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
criterion: nn.Module,
optimizer: torch.optim.Optimizer,
gpu: Optional[int] = None,
output_file: str = "./checkpoint.pth",
amp: bool = False,
skip_nan_loss: bool = False,
nan_tolerance: int = 5,
on_epoch_end: Optional[Callable[[Dict[str, float]], Any]] = None,
) -> None:
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.criterion = criterion
self.optimizer = optimizer
self.amp = amp
self.scaler: torch.cuda.amp.grad_scaler.GradScaler
self.on_epoch_end = on_epoch_end
self.skip_nan_loss = skip_nan_loss
self.nan_tolerance = nan_tolerance
# Output file
self.output_file = output_file
# Initialize
self.step = 0
self.start_epoch = 0
self.epoch = 0
self.min_loss = math.inf
self.gpu = gpu
self._params: Tuple[ParamSeq, ParamSeq] = ([], [])
self.lr_recorder: List[float] = []
self.loss_recorder: List[float] = []
self.set_device(gpu)
self._reset_opt(self.optimizer.defaults["lr"])
def set_device(self, gpu: Optional[int] = None) -> None:
"""Move tensor objects to the target GPU
Args:
gpu: index of the target GPU device
"""
if isinstance(gpu, int):
if not torch.cuda.is_available():
raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
if gpu >= torch.cuda.device_count():
raise ValueError("Invalid device index")
torch.cuda.set_device(gpu)
self.model = self.model.cuda()
if isinstance(self.criterion, torch.nn.Module):
self.criterion = self.criterion.cuda()
def save(self, output_file: str) -> None:
"""Save a trainer checkpoint
Args:
output_file: destination file path
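
Example:
    A minimal sketch (the file path is only illustrative)::

        >>> trainer.save("./checkpoint.pth")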
"""
torch.save(
dict(
epoch=self.epoch,
step=self.step,
min_loss=self.min_loss,
optimizer=self.optimizer.state_dict(),
model=self.model.state_dict(),
),
output_file,
_use_new_zipfile_serialization=False,
)
def load(self, state: Dict[str, Any]) -> None:
"""Resume from a trainer state
Args:
state (dict): checkpoint dictionary
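
Example:
    A minimal sketch for resuming from a checkpoint produced by ``save``
    (the file path is only illustrative)::

        >>> state = torch.load("./checkpoint.pth", map_location="cpu")
        >>> trainer.load(state)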
"""
self.start_epoch = state["epoch"]
self.epoch = self.start_epoch
self.step = state["step"]
self.min_loss = state["min_loss"]
self.optimizer.load_state_dict(state["optimizer"])
self.model.load_state_dict(state["model"])
def _fit_epoch(self, mb: ConsoleMasterBar) -> None:
"""Fit a single epoch
Args:
mb (fastprogress.master_bar): primary progress bar
"""
self.model = freeze_bn(self.model.train())
nan_cnt = 0
pb = progress_bar(self.train_loader, parent=mb)
for x, target in pb:
x, target = self.to_cuda(x, target)
# Forward
batch_loss: Tensor = self._get_loss(x, target) # type: ignore[assignment]
# Backprop
if not self.skip_nan_loss or torch.isfinite(batch_loss):
nan_cnt = 0
self._backprop_step(batch_loss)
else:
nan_cnt += 1
if nan_cnt > self.nan_tolerance:
raise ValueError(f"loss value has been NaN or inf for more than {self.nan_tolerance} steps.")
# Update LR
self.scheduler.step()
pb.comment = f"Training loss: {batch_loss.item():.4}"
self.step += 1
self.epoch += 1
def to_cuda(
self, x: Tensor, target: Union[Tensor, List[Dict[str, Tensor]]]
) -> Tuple[Tensor, Union[Tensor, List[Dict[str, Tensor]]]]:
"""Move input and target to GPU"""
if isinstance(self.gpu, int):
if self.gpu >= torch.cuda.device_count():
raise ValueError("Invalid device index")
return self._to_cuda(x, target) # type: ignore[arg-type]
return x, target
@staticmethod
def _to_cuda(x: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Move input and target to GPU"""
x = x.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
return x, target
def _backprop_step(self, loss: Tensor) -> None:
# Clean gradients
self.optimizer.zero_grad()
# Backpropagate the loss
if self.amp:
self.scaler.scale(loss).backward()
# Update the params
self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()
# Update the params
self.optimizer.step()
def _get_loss(self, x: Tensor, target: Tensor, return_logits: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
# AMP
if self.amp:
with torch.cuda.amp.autocast(): # type: ignore[attr-defined]
# Forward
out = self.model(x)
# Loss computation
loss = self.criterion(out, target)
if return_logits:
return loss, out
return loss
# Forward
out = self.model(x)
loss = self.criterion(out, target)
if return_logits:
return loss, out
return loss
def _set_params(self, norm_weight_decay: Optional[float] = None) -> None:
if not any(p.requires_grad for p in self.model.parameters()):
raise AssertionError("All parameters are frozen")
if norm_weight_decay is None:
self._params = [p for p in self.model.parameters() if p.requires_grad], []
else:
self._params = split_normalization_params(self.model)
def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
"""Reset the target params of the optimizer"""
self.optimizer.defaults["lr"] = lr
self.optimizer.state = defaultdict(dict)
self.optimizer.param_groups = []
self._set_params(norm_weight_decay)
# Split the param groups if norm layers need a custom weight decay
if norm_weight_decay is None:
self.optimizer.add_param_group(dict(params=self._params[0]))
else:
wd_groups = [norm_weight_decay, self.optimizer.defaults.get("weight_decay", 0)]
for _params, _wd in zip(self._params, wd_groups):
if len(_params) > 0:
self.optimizer.add_param_group(dict(params=_params, weight_decay=_wd))
@torch.inference_mode()
def evaluate(self):
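"""Evaluate the model on the validation set.

Expected to be overridden by task-specific trainers: judging from its usage in
``fit_n_epochs``, it should return a dict of metrics containing at least ``"val_loss"``.
"""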
raise NotImplementedError
@staticmethod
def _eval_metrics_str(eval_metrics):
raise NotImplementedError
def _reset_scheduler(self, lr: float, num_epochs: int, sched_type: str = "onecycle") -> None:
if sched_type == "onecycle":
self.scheduler = OneCycleLR(self.optimizer, lr, num_epochs * len(self.train_loader))
elif sched_type == "cosine":
self.scheduler = CosineAnnealingLR(self.optimizer, num_epochs * len(self.train_loader), eta_min=lr / 25e4)
else:
raise ValueError(f"The following scheduler type is not supported: {sched_type}")
def fit_n_epochs(
self,
num_epochs: int,
lr: float,
freeze_until: Optional[str] = None,
sched_type: str = "onecycle",
norm_weight_decay: Optional[float] = None,
) -> None:
"""Train the model for a given number of epochs
Args:
num_epochs (int): number of epochs to train
lr (float): learning rate to be used by the scheduler
freeze_until (str, optional): last layer to freeze
sched_type (str, optional): type of scheduler to use
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
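
Example:
    A minimal sketch; ``evaluate`` must be implemented by the trainer subclass in use,
    and the layer name passed to ``freeze_until`` is only a placeholder::

        >>> trainer.fit_n_epochs(num_epochs=5, lr=3e-4, sched_type="onecycle")
        >>> # hypothetical run with a partially frozen backbone
        >>> trainer.fit_n_epochs(num_epochs=5, lr=3e-4, freeze_until="layer4")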
"""
self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(lr, norm_weight_decay)
# Scheduler
self._reset_scheduler(lr, num_epochs, sched_type)
if self.amp:
self.scaler = torch.cuda.amp.GradScaler() # type: ignore[attr-defined]
mb = master_bar(range(num_epochs))
for _ in mb:
self._fit_epoch(mb)
eval_metrics = self.evaluate()
# master bar
mb.main_bar.comment = f"Epoch {self.epoch}/{self.start_epoch + num_epochs}"
mb.write(f"Epoch {self.epoch}/{self.start_epoch + num_epochs} - " f"{self._eval_metrics_str(eval_metrics)}")
if eval_metrics["val_loss"] < self.min_loss:
print(
f"Validation loss decreased {self.min_loss:.4} --> "
f"{eval_metrics['val_loss']:.4}: saving state..."
)
self.min_loss = eval_metrics["val_loss"]
self.save(self.output_file)
if self.on_epoch_end is not None:
self.on_epoch_end(eval_metrics)
def find_lr(
self,
freeze_until: Optional[str] = None,
start_lr: float = 1e-7,
end_lr: float = 1,
norm_weight_decay: Optional[float] = None,
num_it: int = 100,
) -> None:
"""Gridsearch the optimal learning rate for the training
Args:
freeze_until (str, optional): last layer to freeze
start_lr (float, optional): initial learning rate
end_lr (float, optional): final learning rate
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
num_it (int, optional): number of iterations to perform
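
Example:
    A minimal sketch of an LR range test; ``num_it`` cannot exceed the number of
    batches in the training dataloader::

        >>> trainer.find_lr(start_lr=1e-7, end_lr=1, num_it=100)
        >>> trainer.plot_recorder()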
"""
if num_it > len(self.train_loader):
raise ValueError("the value of `num_it` needs to be lower than the number of available batches")
self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(start_lr, norm_weight_decay)
gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
scheduler = MultiplicativeLR(self.optimizer, lambda step: gamma)
self.lr_recorder = [start_lr * gamma**idx for idx in range(num_it)]
self.loss_recorder = []
if self.amp:
self.scaler = torch.cuda.amp.GradScaler() # type: ignore[attr-defined]
for batch_idx, (x, target) in enumerate(self.train_loader):
x, target = self.to_cuda(x, target)
# Forward
batch_loss: Tensor = self._get_loss(x, target) # type: ignore[assignment]
self._backprop_step(batch_loss)
# Update LR
scheduler.step()
# Record
if torch.isnan(batch_loss) or torch.isinf(batch_loss):
if batch_idx == 0:
raise ValueError("loss value is NaN or inf.")
else:
break
self.loss_recorder.append(batch_loss.item())
# Stop after the number of iterations
if batch_idx + 1 == num_it:
break
self.lr_recorder = self.lr_recorder[: len(self.loss_recorder)]
def plot_recorder(self, beta: float = 0.95, **kwargs: Any) -> None:
"""Display the results of the LR grid search
Args:
beta (float, optional): smoothing factor
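
Example:
    A minimal sketch; ``find_lr`` must have been run beforehand so that learning
    rates and losses have been recorded::

        >>> trainer.find_lr(num_it=100)
        >>> trainer.plot_recorder(beta=0.95)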
"""
if len(self.lr_recorder) != len(self.loss_recorder) or len(self.lr_recorder) == 0:
raise AssertionError("Please run the `lr_find` method first")
# Exp moving average of loss
smoothed_losses = []
avg_loss = 0.0
for idx, loss in enumerate(self.loss_recorder):
avg_loss = beta * avg_loss + (1 - beta) * loss
smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1)))
# Properly rescale Y-axis
data_slice = slice(
min(len(self.loss_recorder) // 10, 10),
-min(len(self.loss_recorder) // 20, 5) if len(self.loss_recorder) >= 20 else len(self.loss_recorder),
)
vals: np.ndarray = np.array(smoothed_losses[data_slice])
min_idx = vals.argmin()
max_val = vals[: min_idx + 1].max()
delta = max_val - vals[min_idx]
plt.plot(self.lr_recorder[data_slice], smoothed_losses[data_slice])
plt.xscale("log")
plt.xlabel("Learning Rate")
plt.ylabel("Training loss")
plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta)
plt.grid(True, linestyle="--", axis="x")
plt.show(**kwargs)
def check_setup(
self,
freeze_until: Optional[str] = None,
lr: float = 3e-4,
norm_weight_decay: Optional[float] = None,
num_it: int = 100,
) -> bool:
"""Check whether you can overfit one batch
Args:
freeze_until (str, optional): last layer to freeze
lr (float, optional): learning rate to be used for training
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
num_it (int, optional): number of iterations to perform
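
Returns:
    bool: whether the loss decreased by the end of the run (i.e. the setup managed
    to start overfitting the batch)

Example:
    A minimal sketch; a ``True`` result suggests the model, criterion and optimizer
    are wired correctly::

        >>> assert trainer.check_setup(lr=3e-4, num_it=100)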
"""
self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(lr, norm_weight_decay)
x, target = next(iter(self.train_loader))
x, target = self.to_cuda(x, target)
_losses = []
if self.amp:
self.scaler = torch.cuda.amp.GradScaler() # type: ignore[attr-defined]
for _ in range(num_it):
# Forward
batch_loss: Tensor = self._get_loss(x, target) # type: ignore[assignment]
# Backprop
self._backprop_step(batch_loss)
if torch.isnan(batch_loss) or torch.isinf(batch_loss):
raise ValueError("loss value is NaN or inf.")
_losses.append(batch_loss.item())
return _losses[-1] < _losses[0]