# Copyright (C) 2019-2022, François-Guillaume Fernandez.
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
from typing import Dict

import torch

from .core import Trainer

__all__ = ['ClassificationTrainer', 'BinaryClassificationTrainer']


class ClassificationTrainer(Trainer):
    """Image classification trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
        amp (bool, optional): whether to use automatic mixed precision
        skip_nan_loss (bool, optional): whether the optimizer step should be skipped when the loss is NaN
        on_epoch_end (Callable[[Dict[str, float]], Any]): callback triggered at the end of an epoch
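
    Example::

        A minimal usage sketch on toy CPU data; it assumes the constructor takes the
        arguments in the order listed above and that no GPU index is required.

        >>> import torch
        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> model = torch.nn.Linear(16, 10)
        >>> ds = TensorDataset(torch.rand(32, 16), torch.randint(0, 10, (32,)))
        >>> loader = DataLoader(ds, batch_size=8)
        >>> criterion = torch.nn.CrossEntropyLoss()
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        >>> trainer = ClassificationTrainer(model, loader, loader, criterion, optimizer)
        >>> metrics = trainer.evaluate()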
"""
    @torch.inference_mode()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set

        Returns:
            dict: evaluation metrics
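
        Example::

            A sketch of the top-1 / top-5 bookkeeping performed below, shown on a toy
            batch with plain tensor ops rather than the trainer itself:

            >>> import torch
            >>> out = torch.rand(4, 10)
            >>> target = torch.randint(0, 10, (4,))
            >>> pred = out.topk(5, dim=1)[1]
            >>> correct = pred.eq(target.view(-1, 1).expand_as(pred))
            >>> top1 = correct[:, 0].sum().item()
            >>> top5 = correct.any(dim=1).sum().item()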
"""
        self.model.eval()
        val_loss, top1, top5, num_samples, num_valid_batches = 0., 0, 0, 0, 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)
            if self.amp:
                with torch.cuda.amp.autocast():
                    # Forward
                    out = self.model(x)
                    # Loss computation
                    _loss = self.criterion(out, target)
            else:
                # Forward
                out = self.model(x)
                # Loss computation
                _loss = self.criterion(out, target)
            # Safeguard for NaN loss
            if not torch.isnan(_loss) and not torch.isinf(_loss):
                val_loss += _loss.item()
                num_valid_batches += 1
            # Top-5 predictions when at least 5 classes are available, otherwise argmax only
            pred = out.topk(5, dim=1)[1] if out.shape[1] >= 5 else out.argmax(dim=1, keepdim=True)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            top1 += correct[:, 0].sum().item()
            if out.shape[1] >= 5:
                top5 += correct.any(dim=1).sum().item()
            num_samples += x.shape[0]

        val_loss /= num_valid_batches

        return dict(val_loss=val_loss, acc1=top1 / num_samples, acc5=top5 / num_samples)
    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        return (f"Validation loss: {eval_metrics['val_loss']:.4} "
                f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})")


class BinaryClassificationTrainer(Trainer):
    """Image binary classification trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
        amp (bool, optional): whether to use automatic mixed precision
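
    Example::

        A minimal usage sketch on toy CPU data; it assumes the constructor takes the
        arguments in the order listed above and uses ``nn.BCELoss``, since ``evaluate``
        applies a sigmoid before calling the criterion.

        >>> import torch
        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> model = torch.nn.Linear(16, 1)
        >>> ds = TensorDataset(torch.rand(32, 16), torch.randint(0, 2, (32, 1)).float())
        >>> loader = DataLoader(ds, batch_size=8)
        >>> criterion = torch.nn.BCELoss()
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        >>> trainer = BinaryClassificationTrainer(model, loader, loader, criterion, optimizer)
        >>> metrics = trainer.evaluate()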
"""
    @torch.inference_mode()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set

        Returns:
            dict: evaluation metrics
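
        Example::

            A sketch of the accuracy bookkeeping performed below (both target and
            prediction are thresholded at 0.5), shown on a toy batch of sigmoid outputs:

            >>> import torch
            >>> out = torch.sigmoid(torch.randn(4, 1))
            >>> target = torch.randint(0, 2, (4, 1)).float()
            >>> num_correct = int(torch.sum((target >= 0.5) == (out >= 0.5)).item())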
"""
        self.model.eval()
        val_loss, top1, num_samples, num_valid_batches = 0., 0, 0, 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)
            if self.amp:
                with torch.cuda.amp.autocast():
                    # Forward
                    out = self.model(x)
                    # Apply sigmoid
                    out = torch.sigmoid(out)
                    # Loss computation
                    _loss = self.criterion(out, target)
            else:
                # Forward
                out = self.model(x)
                # Apply sigmoid
                out = torch.sigmoid(out)
                # Loss computation
                _loss = self.criterion(out, target)
            # Safeguard for NaN loss
            if not torch.isnan(_loss) and not torch.isinf(_loss):
                val_loss += _loss.item()
                num_valid_batches += 1
            # Count samples whose thresholded prediction matches the thresholded target
            top1 += int(torch.sum((target >= 0.5) == (out >= 0.5)).item())
            num_samples += x.shape[0]

        val_loss /= num_valid_batches

        return dict(val_loss=val_loss, acc=top1 / num_samples)
    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        return (f"Validation loss: {eval_metrics['val_loss']:.4} "
                f"(Acc: {eval_metrics['acc']:.2%})")