Source code for holocron.trainer.classification

# Copyright (C) 2019-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Dict, Tuple, Union

import torch
from torch import Tensor

from .core import Trainer

__all__ = ["ClassificationTrainer", "BinaryClassificationTrainer"]


[docs] class ClassificationTrainer(Trainer): """Image classification trainer class Args: model (torch.nn.Module): model to train train_loader (torch.utils.data.DataLoader): training loader val_loader (torch.utils.data.DataLoader): validation loader criterion (torch.nn.Module): loss criterion optimizer (torch.optim.Optimizer): parameter optimizer gpu (int, optional): index of the GPU to use output_file (str, optional): path where checkpoints will be saved amp (bool, optional): whether to use automatic mixed precision skip_nan_loss (bool, optional): whether the optimizer step should be skipped when the loss is NaN on_epoch_end (Callable[[Dict[str, float]], Any]): callback triggered at the end of an epoch """
[docs] @torch.inference_mode() def evaluate(self) -> Dict[str, float]: """Evaluate the model on the validation set Returns: dict: evaluation metrics """ self.model.eval() val_loss, top1, top5, num_samples, num_valid_batches = 0.0, 0, 0, 0, 0 for x, target in self.val_loader: x, target = self.to_cuda(x, target) _loss, out = self._get_loss(x, target, return_logits=True) # Safeguard for NaN loss if not torch.isnan(_loss) and not torch.isinf(_loss): val_loss += _loss.item() num_valid_batches += 1 pred = out.topk(5, dim=1)[1] if out.shape[1] >= 5 else out.argmax(dim=1, keepdim=True) correct = pred.eq(target.view(-1, 1).expand_as(pred)) top1 += correct[:, 0].sum().item() if out.shape[1] >= 5: top5 += correct.any(dim=1).sum().item() num_samples += x.shape[0] val_loss /= num_valid_batches return dict(val_loss=val_loss, acc1=top1 / num_samples, acc5=top5 / num_samples)
@staticmethod def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str: return ( f"Validation loss: {eval_metrics['val_loss']:.4} " f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})" )
[docs] class BinaryClassificationTrainer(Trainer): """Image binary classification trainer class Args: model (torch.nn.Module): model to train train_loader (torch.utils.data.DataLoader): training loader val_loader (torch.utils.data.DataLoader): validation loader criterion (torch.nn.Module): loss criterion optimizer (torch.optim.Optimizer): parameter optimizer gpu (int, optional): index of the GPU to use output_file (str, optional): path where checkpoints will be saved amp (bool, optional): whether to use automatic mixed precision """ def _get_loss( self, x: torch.Tensor, target: torch.Tensor, return_logits: bool = False ) -> Union[Tensor, Tuple[Tensor, Tensor]]: # In case target are stored as long target = target.to(dtype=x.dtype) # AMP if self.amp: with torch.cuda.amp.autocast(): # type: ignore[attr-defined] # Forward out = self.model(x) # Loss computation loss = self.criterion(out, target.view_as(out)) if return_logits: return loss, out return loss # Forward out = self.model(x) loss = self.criterion(out, target.view_as(out)) if return_logits: return loss, out return loss
[docs] @torch.inference_mode() def evaluate(self) -> Dict[str, float]: """Evaluate the model on the validation set Returns: dict: evaluation metrics """ self.model.eval() val_loss, top1, num_samples, num_valid_batches = 0.0, 0.0, 0, 0 for x, target in self.val_loader: x, target = self.to_cuda(x, target) _loss, out = self._get_loss(x, target, return_logits=True) # Safeguard for NaN loss if not torch.isnan(_loss) and not torch.isinf(_loss): val_loss += _loss.item() num_valid_batches += 1 top1 += torch.sum((target.view_as(out) >= 0.5) == (torch.sigmoid(out) >= 0.5)).item() / out[0].numel() num_samples += x.shape[0] val_loss /= num_valid_batches return dict(val_loss=val_loss, acc=top1 / num_samples)
@staticmethod def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str: return f"Validation loss: {eval_metrics['val_loss']:.4} " f"(Acc: {eval_metrics['acc']:.2%})"