# Source code for holocron.trainer.classification

# Copyright (C) 2019-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import math
from typing import Any, Dict, Sequence, Tuple, Union, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import Tensor
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm

from .core import Trainer

__all__ = ["BinaryClassificationTrainer", "ClassificationTrainer"]


class ClassificationTrainer(Trainer):
    """Image classification trainer class.

    Args:
        model: model to train
        train_loader: training loader
        val_loader: validation loader
        criterion: loss criterion
        optimizer: parameter optimizer
        gpu: index of the GPU to use
        output_file: path where checkpoints will be saved
        amp: whether to use automatic mixed precision
        skip_nan_loss: whether the optimizer step should be skipped when the loss is NaN
        nan_tolerance: number of consecutive batches with NaN loss before stopping the training
        gradient_acc: number of batches to accumulate the gradient of before performing the update step
        gradient_clip: the gradient clip value
        on_epoch_end: callback triggered at the end of an epoch
    """

    # Read by `plot_top_losses` to select the sigmoid (binary) or softmax (multi-class) code path
    is_binary: bool = False

    @torch.inference_mode()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set.

        Batches whose loss is NaN or infinite are excluded from the loss average but still
        count towards the accuracy metrics. Top-5 accuracy is only accumulated when the model
        outputs at least 5 classes; otherwise ``acc5`` stays at 0.

        Returns:
            dict: evaluation metrics with keys ``val_loss``, ``acc1`` and ``acc5``.
            ``val_loss`` is NaN when no batch produced a finite loss, and the accuracies are 0
            when the validation loader yielded no sample.
        """
        self.model.eval()

        val_loss, top1, top5, num_samples, num_valid_batches = 0.0, 0, 0, 0, 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            loss, out = self._get_loss(x, target, return_logits=True)

            # Safeguard for NaN loss
            if not torch.isnan(loss) and not torch.isinf(loss):
                val_loss += loss.item()
                num_valid_batches += 1

            has_top5 = out.shape[1] >= 5
            pred = out.topk(5, dim=1)[1] if has_top5 else out.argmax(dim=1, keepdim=True)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            top1 += cast(int, correct[:, 0].sum().item())
            if has_top5:
                top5 += cast(int, correct.any(dim=1).sum().item())

            num_samples += x.shape[0]

        # Avoid a ZeroDivisionError when every batch was skipped or the loader was empty
        val_loss = val_loss / num_valid_batches if num_valid_batches > 0 else float("nan")
        denom = max(num_samples, 1)

        return {"val_loss": val_loss, "acc1": top1 / denom, "acc5": top5 / denom}

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        """Format the metrics returned by `evaluate` as a one-line summary."""
        return (
            f"Validation loss: {eval_metrics['val_loss']:.4} "
            f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})"
        )

    @torch.inference_mode()
    def plot_top_losses(
        self,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        classes: Union[Sequence[str], None] = None,
        num_samples: int = 12,
        **kwargs: Any,
    ) -> None:
        """Plot the training samples with the highest loss.

        Args:
            mean: per-channel mean, added back to the images after multiplying by `std`
            std: per-channel standard deviation, multiplied with the images before adding `mean`
            classes: class names, indexed by the predicted and target class indices
                (required when the trainer is not binary)
            num_samples: number of samples to display
            **kwargs: keyword arguments forwarded to `matplotlib.pyplot.show`

        Raises:
            AssertionError: if `classes` is None for a multi-class trainer
        """
        # Fail fast, before iterating over the whole training set
        if not self.is_binary and classes is None:
            raise AssertionError("arg 'classes' must be specified for multi-class classification")

        # Record loss, prob, target, image
        losses = np.zeros(num_samples, dtype=np.float32)
        preds = np.zeros(num_samples, dtype=int)
        probs = np.zeros(num_samples, dtype=np.float32)
        targets = np.zeros(num_samples, dtype=np.float32 if self.is_binary else int)
        images = [None] * num_samples

        # Switch to unreduced loss
        reduction = self.criterion.reduction
        self.criterion.reduction = "none"  # type: ignore[assignment]
        self.model.eval()

        try:
            train_iter = iter(self.train_loader)

            for x, target in tqdm(train_iter):
                x, target = self.to_cuda(x, target)

                # Forward
                batch_loss, logits = self._get_loss(x, target, return_logits=True)

                # Binary
                if self.is_binary:
                    batch_loss = batch_loss.squeeze(1)
                    probs_ = torch.sigmoid(logits.squeeze(1))
                else:
                    probs_ = torch.softmax(logits, 1).max(dim=1).values

                if torch.any(batch_loss > float(losses.min())):
                    batch_losses = batch_loss.cpu().numpy()
                    # Indices below num_samples point to already recorded entries,
                    # the others point to samples of the current batch
                    top_idcs = np.concatenate((losses, batch_losses)).argsort()[-num_samples:]
                    kept_idcs = top_idcs[top_idcs < num_samples]
                    added_idcs = top_idcs[top_idcs >= num_samples] - num_samples
                    added_list = added_idcs.tolist()
                    # Update
                    losses = np.concatenate((losses[kept_idcs], batch_losses[added_idcs]))
                    # Only keep the probabilities of the selected samples so that all records stay aligned
                    probs = np.concatenate((probs[kept_idcs], probs_.cpu().numpy()[added_idcs]))
                    if not self.is_binary:
                        preds = np.concatenate((preds[kept_idcs], logits[added_list].argmax(dim=1).cpu().numpy()))
                    targets = np.concatenate((targets[kept_idcs], target[added_list].cpu().numpy()))
                    # Undo the normalization before converting to PIL images
                    imgs = x[added_list].cpu() * torch.tensor(std).view(-1, 1, 1)
                    imgs += torch.tensor(mean).view(-1, 1, 1)
                    images = [images[idx] for idx in kept_idcs] + [to_pil_image(img) for img in imgs]
        finally:
            # Restore the criterion even if the loop was interrupted
            self.criterion.reduction = reduction

        # Final sort
        idcs_ = losses.argsort()[::-1]
        losses, preds, probs, targets = losses[idcs_], preds[idcs_], probs[idcs_], targets[idcs_]
        images = [images[idx] for idx in idcs_]

        # Plot it
        num_cols = 4
        num_rows = math.ceil(num_samples / num_cols)
        # squeeze=False keeps `axes` 2D even when there is a single row
        _, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5), squeeze=False)
        for ax in axes.flat:
            ax.axis("off")
        for idx, (img, pred, prob, target, loss) in enumerate(zip(images, preds, probs, targets, losses)):
            # Slots that were never filled have no image to display
            if img is None:
                continue
            row, col = divmod(idx, num_cols)
            ax = axes[row][col]
            ax.imshow(img)
            # Loss, prob, target
            if self.is_binary:
                ax.title.set_text(f"{loss:.3} / {prob:.2} / {target:.2}")
            # Loss, pred (prob), target
            else:
                ax.title.set_text(
                    f"{loss:.3} / {classes[pred]} ({prob:.1%}) / {classes[target]}"  # type: ignore[index]
                )

        plt.show(**kwargs)
class BinaryClassificationTrainer(ClassificationTrainer):
    """Image binary classification trainer class.

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
        amp (bool, optional): whether to use automatic mixed precision
    """

    # Selects the sigmoid-based code path in `plot_top_losses`
    is_binary: bool = True

    def _get_loss(
        self, x: torch.Tensor, target: torch.Tensor, return_logits: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Run the forward pass and compute the loss.

        Args:
            x: input batch
            target: targets, cast to the dtype of `x` and reshaped like the model output
            return_logits: whether the model output should be returned alongside the loss

        Returns:
            the loss, or a ``(loss, output)`` tuple when `return_logits` is True
        """
        # In case target are stored as long
        target = target.to(dtype=x.dtype)

        if self.amp:
            # Forward & loss computation under autocast
            with torch.cuda.amp.autocast():
                out = self.model(x)
                loss = self.criterion(out, target.view_as(out))
        else:
            out = self.model(x)
            loss = self.criterion(out, target.view_as(out))

        return (loss, out) if return_logits else loss

    @torch.inference_mode()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set.

        Batches whose loss is NaN or infinite are excluded from the loss average but still
        count towards the accuracy. Predictions and targets are both thresholded at 0.5, and
        the number of matches is divided by the number of outputs per sample.

        Returns:
            dict: evaluation metrics with keys ``val_loss`` and ``acc``.
            ``val_loss`` is NaN when no batch produced a finite loss, and ``acc`` is 0
            when the validation loader yielded no sample.
        """
        self.model.eval()

        val_loss, top1, num_samples, num_valid_batches = 0.0, 0.0, 0, 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            loss, out = self._get_loss(x, target, return_logits=True)

            # Safeguard for NaN loss
            if not torch.isnan(loss) and not torch.isinf(loss):
                val_loss += loss.item()
                num_valid_batches += 1

            matches = (target.view_as(out) >= 0.5) == (torch.sigmoid(out) >= 0.5)
            top1 += torch.sum(matches).item() / out[0].numel()

            num_samples += x.shape[0]

        # Avoid a ZeroDivisionError when every batch was skipped or the loader was empty
        val_loss = val_loss / num_valid_batches if num_valid_batches > 0 else float("nan")

        return {"val_loss": val_loss, "acc": top1 / max(num_samples, 1)}

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        """Format the metrics returned by `evaluate` as a one-line summary."""
        return f"Validation loss: {eval_metrics['val_loss']:.4} (Acc: {eval_metrics['acc']:.2%})"