# Copyright (C) 2019-2024, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
from torchvision.ops.boxes import box_iou
from .core import Trainer
# Public API of this module: star-imports only expose the trainer (``assign_iou`` is not exported)
__all__ = ["DetectionTrainer"]
def assign_iou(gt_boxes: Tensor, pred_boxes: Tensor, iou_threshold: float = 0.5) -> Tuple[List[int], List[int]]:
    """Assign ground-truth boxes to predicted boxes by IoU.

    Each ground-truth box is paired with the predicted box it overlaps the most, and pairs whose IoU is
    below ``iou_threshold`` are dropped. When several ground-truth boxes pick the same prediction, only the
    one with the highest IoU keeps it, so every prediction is used at most once.

    Args:
        gt_boxes: ground-truth boxes of shape (N, 4), in the format expected by ``torchvision.ops.box_iou``
        pred_boxes: predicted boxes of shape (M, 4), in the same format
        iou_threshold: minimum IoU for a pair to be assigned

    Returns:
        the indices of the assigned ground-truth boxes and the indices of their matching predictions
        (two lists of the same length)
    """
    # Reducing an empty (N, 0) IoU matrix with ``max`` raises, and nothing can be matched anyway
    if gt_boxes.shape[0] == 0 or pred_boxes.shape[0] == 0:
        return [], []

    best_iou, best_pred = box_iou(gt_boxes, pred_boxes).max(dim=1)
    # ``nonzero`` keeps the indices on the boxes' device (masking a CPU ``arange`` with a CUDA mask fails)
    kept_gt = torch.nonzero(best_iou >= iou_threshold, as_tuple=True)[0]
    kept_iou, kept_pred = best_iou[kept_gt], best_pred[kept_gt]

    unique_pred = torch.unique(kept_pred)
    # Fast path: no prediction is claimed twice, so every kept pair is valid (ordered by ground-truth index)
    if unique_pred.shape[0] == kept_pred.shape[0]:
        return kept_gt.tolist(), kept_pred.tolist()

    gt_indices: List[int] = []
    pred_indices: List[int] = []
    for pred_idx in unique_pred:
        # Positions (within the kept pairs) of the ground-truth boxes competing for this prediction
        candidates = torch.nonzero(kept_pred == pred_idx, as_tuple=True)[0]
        # Map the best candidate back through ``candidates`` and ``kept_gt`` to a global ground-truth index
        winner = candidates[kept_iou[candidates].argmax()]
        gt_indices.append(int(kept_gt[winner].item()))
        pred_indices.append(int(pred_idx.item()))
    return gt_indices, pred_indices
class DetectionTrainer(Trainer):
    """Object detection trainer class.

    In this trainer, the training loss is the sum of the loss dictionary returned by the model itself
    (see ``_get_loss``).

    Args:
        model: model to train
        train_loader: training loader
        val_loader: validation loader
        criterion: loss criterion
        optimizer: parameter optimizer
        gpu: index of the GPU to use
        output_file: path where checkpoints will be saved
        amp: whether to use automatic mixed precision
        skip_nan_loss: whether the optimizer step should be skipped when the loss is NaN
        nan_tolerance: number of consecutive batches with NaN loss before stopping the training
        gradient_acc: number of batches to accumulate the gradient of before performing the update step
        gradient_clip: the gradient clip value
        on_epoch_end: callback triggered at the end of an epoch
    """

    @staticmethod
    def _to_cuda(  # type: ignore[override]
        x: List[Tensor], target: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Dict[str, Tensor]]]:
        """Move input and target to GPU.

        Args:
            x: list of input tensors
            target: list of target dictionaries, every value of which is moved as well

        Returns:
            the moved inputs and targets
        """
        x = [_x.cuda(non_blocking=True) for _x in x]
        target = [{k: v.cuda(non_blocking=True) for k, v in t.items()} for t in target]
        return x, target

    def _get_loss(self, x: List[Tensor], target: List[Dict[str, Tensor]]) -> Tensor:  # type: ignore[override]
        """Forward the batch with its targets and return the sum of the model's losses.

        Args:
            x: list of input tensors
            target: list of target dictionaries

        Returns:
            the sum of all the values of the loss dictionary returned by the model
        """

        def _forward() -> Tensor:
            loss_dict = self.model(x, target)
            return sum(loss_dict.values())  # type: ignore[return-value]

        if self.amp:
            with torch.cuda.amp.autocast():
                return _forward()
        return _forward()

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, Optional[float]]) -> str:
        """Format the localization, classification and detection errors for logging.

        Args:
            eval_metrics: metrics as returned by ``evaluate``; non-float values are displayed as "N/A"

        Returns:
            a one-line summary of the errors
        """

        def _fmt(key: str) -> str:
            value = eval_metrics[key]
            return f"{value:.2%}" if isinstance(value, float) else "N/A"

        return f"Loc error: {_fmt('loc_err')} | Clf error: {_fmt('clf_err')} | Det error: {_fmt('det_err')}"

    @torch.inference_mode()
    def evaluate(self, iou_threshold: float = 0.5) -> Dict[str, Optional[float]]:
        """Evaluate the model on the validation set.

        Args:
            iou_threshold: IoU threshold for pair assignment

        Returns:
            dict: evaluation metrics ``loc_err``, ``clf_err`` and ``det_err`` (``None`` when undefined, e.g.
            nothing to compare), plus ``val_loss`` which mirrors ``loc_err``
        """
        # ``_get_loss`` calls the model with targets to get losses, here it is called without targets to get
        # detections: switch to eval mode for that, and restore the caller's mode afterwards
        was_training = self.model.training
        self.model.eval()

        loc_assigns, correct, loc_fn, loc_fp, num_samples = 0, 0, 0, 0, 0
        try:
            for x, target in self.val_loader:
                x, target = self.to_cuda(x, target)

                if self.amp:
                    with torch.cuda.amp.autocast():
                        detections = self.model(x)
                else:
                    detections = self.model(x)

                for dets, t in zip(detections, target):
                    num_gts, num_dets = t["boxes"].shape[0], dets["boxes"].shape[0]
                    num_matches = 0
                    if num_gts > 0 and num_dets > 0:
                        gt_indices, pred_indices = assign_iou(t["boxes"], dets["boxes"], iou_threshold)
                        num_matches = len(gt_indices)
                        # A matched pair is correctly classified when both labels agree
                        correct += int((t["labels"][gt_indices] == dets["labels"][pred_indices]).sum().item())
                    loc_assigns += num_matches
                    loc_fn += num_gts - num_matches
                    loc_fp += num_dets - num_matches
                    num_samples += num_gts
        finally:
            self.model.train(was_training)

        # Total number of predicted boxes (matched ones + unmatched ones)
        nb_preds = num_samples - loc_fn + loc_fp
        # Localization
        loc_err = 1 - 2 * loc_assigns / (nb_preds + num_samples) if nb_preds + num_samples > 0 else None
        # Classification
        clf_err = 1 - correct / loc_assigns if loc_assigns > 0 else None
        # End-to-end
        det_err = 1 - 2 * correct / (nb_preds + num_samples) if nb_preds + num_samples > 0 else None
        return {"loc_err": loc_err, "clf_err": clf_err, "det_err": det_err, "val_loss": loc_err}