import math
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR, MultiplicativeLR # type: ignore[attr-defined]
from torchvision.ops.boxes import box_iou
from fastprogress.fastprogress import ConsoleMasterBar
from fastprogress import master_bar, progress_bar
from typing import Optional, Dict, Any, Union, List, Tuple
from contiguous_params import ContiguousParams
from .utils import freeze_bn, freeze_model
# Public API of this module
__all__ = [
    'Trainer',
    'ClassificationTrainer',
    'SegmentationTrainer',
    'DetectionTrainer',
]
class Trainer:
    """Base trainer handling the training loop, LR range search and checkpointing.

    Subclasses must implement `evaluate` (returning a dict with at least a `val_loss` key)
    and `_eval_metrics_str`.

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        gpu: Optional[int] = None,
        output_file: str = './checkpoint.pth'
    ) -> None:
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer

        # Output file
        self.output_file = output_file

        # Initialize
        self.step = 0
        self.start_epoch = 0
        self.epoch = 0
        self.min_loss = math.inf
        self.gpu = gpu
        self._params: Optional[ContiguousParams] = None
        self.lr_recorder: List[float] = []
        self.loss_recorder: List[float] = []
        self.set_device(gpu)
        # Replace the optimizer's param groups with a single group built on a contiguous
        # buffer of the trainable params, using the optimizer's default LR
        self._reset_opt(self.optimizer.defaults['lr'])

    def set_device(self, gpu: Optional[int] = None) -> None:
        """Move the model (and the criterion, if it is a module) to the target GPU

        Args:
            gpu: index of the target GPU device; `None` leaves everything in place
        Raises:
            AssertionError: if CUDA is not available
            ValueError: if the device index is out of range
        """
        if isinstance(gpu, int):
            if not torch.cuda.is_available():
                raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
            if gpu >= torch.cuda.device_count():
                raise ValueError("Invalid device index")
            torch.cuda.set_device(gpu)
            self.model = self.model.cuda()
            if isinstance(self.criterion, torch.nn.Module):
                self.criterion = self.criterion.cuda()

    def save(self, output_file: str) -> None:
        """Save a trainer checkpoint

        The checkpoint stores the epoch/step counters, the best validation loss,
        and the optimizer and model state dicts.

        Args:
            output_file: destination file path
        """
        torch.save(dict(epoch=self.epoch, step=self.step, min_loss=self.min_loss,
                        optimizer=self.optimizer.state_dict(),
                        model=self.model.state_dict()),
                   output_file,
                   _use_new_zipfile_serialization=False)

    def load(self, state: Dict[str, Any]) -> None:
        """Resume from a trainer state

        Args:
            state (dict): checkpoint dictionary, as produced by `save`
        """
        self.start_epoch = state['epoch']
        self.epoch = self.start_epoch
        self.step = state['step']
        self.min_loss = state['min_loss']
        self.optimizer.load_state_dict(state['optimizer'])
        self.model.load_state_dict(state['model'])

    def _fit_epoch(self, mb: ConsoleMasterBar) -> None:
        """Fit a single epoch

        The LR scheduler is stepped after every batch.

        Args:
            mb (fastprogress.master_bar): primary progress bar
        """
        self.model = freeze_bn(self.model.train())

        pb = progress_bar(self.train_loader, parent=mb)
        for x, target in pb:
            x, target = self.to_cuda(x, target)

            # Forward
            batch_loss = self._get_loss(x, target)

            # Backprop
            self._backprop_step(batch_loss)
            # Update LR
            self.scheduler.step()
            pb.comment = f"Training loss: {batch_loss.item():.4}"

            self.step += 1
        self.epoch += 1

    def to_cuda(
        self,
        x: Tensor,
        target: Union[Tensor, List[Dict[str, Tensor]]]
    ) -> Tuple[Tensor, Union[Tensor, List[Dict[str, Tensor]]]]:
        """Move input and target to GPU when a GPU index was provided, otherwise return them as-is"""
        if isinstance(self.gpu, int):
            if self.gpu >= torch.cuda.device_count():
                raise ValueError("Invalid device index")
            return self._to_cuda(x, target)  # type: ignore[arg-type]
        else:
            return x, target

    @staticmethod
    def _to_cuda(x: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Move input and target to GPU"""
        x = x.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        return x, target

    def _backprop_step(self, loss: Tensor) -> None:
        """Perform one optimization step: zero grads, backprop, update params"""
        # Clean gradients
        self.optimizer.zero_grad()
        # Backpropagate the loss
        loss.backward()
        # Update the params
        self.optimizer.step()

    def _get_loss(self, x: Tensor, target: Tensor) -> Tensor:
        """Run the forward pass and compute the loss"""
        # Forward
        out = self.model(x)
        # Loss computation
        return self.criterion(out, target)

    def _set_params(self) -> None:
        """Gather the trainable params into a single contiguous buffer"""
        self._params = ContiguousParams([p for p in self.model.parameters() if p.requires_grad])

    def _reset_opt(self, lr: float) -> None:
        """Reset the target params of the optimizer

        Clears the optimizer state and param groups, then registers the contiguous
        trainable params as one group using `lr` as default learning rate.
        """
        self.optimizer.defaults['lr'] = lr
        self.optimizer.state = defaultdict(dict)
        self.optimizer.param_groups = []
        self._set_params()
        self.optimizer.add_param_group(dict(params=self._params.contiguous()))  # type: ignore[union-attr]

    @torch.no_grad()
    def evaluate(self):
        """Evaluate the model on the validation set (implemented by subclasses)"""
        raise NotImplementedError

    @staticmethod
    def _eval_metrics_str(eval_metrics):
        """Format evaluation metrics for display (implemented by subclasses)"""
        raise NotImplementedError

    def _reset_scheduler(self, lr: float, num_epochs: int, sched_type: str = 'onecycle') -> None:
        """Create a per-batch LR scheduler spanning `num_epochs` epochs

        Raises:
            ValueError: if `sched_type` is neither 'onecycle' nor 'cosine'
        """
        if sched_type == 'onecycle':
            self.scheduler = OneCycleLR(self.optimizer, lr, num_epochs * len(self.train_loader))
        elif sched_type == 'cosine':
            self.scheduler = CosineAnnealingLR(self.optimizer, num_epochs * len(self.train_loader), eta_min=lr / 25e4)
        else:
            raise ValueError(f"The following scheduler type is not supported: {sched_type}")

    def fit_n_epochs(
        self,
        num_epochs: int,
        lr: float,
        freeze_until: Optional[str] = None,
        sched_type: str = 'onecycle'
    ) -> None:
        """Train the model for a given number of epochs

        After each epoch the model is evaluated. A checkpoint is saved to `output_file`
        whenever the validation loss improves.

        Args:
            num_epochs (int): number of epochs to train
            lr (float): learning rate to be used by the scheduler
            freeze_until (str, optional): last layer to freeze
            sched_type (str, optional): type of scheduler to use
        """
        self.model = freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(lr)
        # Scheduler
        self._reset_scheduler(lr, num_epochs, sched_type)

        # `self.epoch` already includes resumed (`load`) and previously trained epochs,
        # so the last epoch of this run is relative to the current counter
        end_epoch = self.epoch + num_epochs

        mb = master_bar(range(num_epochs))
        for _ in mb:

            self._fit_epoch(mb)
            # Check whether ops invalidated the buffer
            self._params.assert_buffer_is_valid()  # type: ignore[union-attr]
            eval_metrics = self.evaluate()

            # master bar
            epoch_str = f"Epoch {self.epoch}/{end_epoch}"
            mb.main_bar.comment = epoch_str
            mb.write(f"{epoch_str} - {self._eval_metrics_str(eval_metrics)}")

            if eval_metrics['val_loss'] < self.min_loss:
                print(f"Validation loss decreased {self.min_loss:.4} --> "
                      f"{eval_metrics['val_loss']:.4}: saving state...")
                self.min_loss = eval_metrics['val_loss']
                self.save(self.output_file)

    def lr_find(
        self,
        freeze_until: Optional[str] = None,
        start_lr: float = 1e-7,
        end_lr: float = 1,
        num_it: int = 100
    ) -> None:
        """Gridsearch the optimal learning rate for the training

        The LR grows geometrically from `start_lr` to `end_lr` over `num_it` batches.
        The batch losses are stored in `loss_recorder` and the matching LRs in `lr_recorder`.

        NOTE: the sweep trains the model, so its weights are modified by this call.

        Args:
            freeze_until (str, optional): last layer to freeze
            start_lr (float, optional): initial learning rate
            end_lr (float, optional): final learning rate
            num_it (int, optional): number of iterations to perform
        Raises:
            ValueError: if `num_it` is lower than 2
        """
        # The geometric ratio below needs at least two points to be defined
        if num_it < 2:
            raise ValueError("`num_it` needs to be at least 2 to span the learning rate range")

        self.model = freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(start_lr)
        gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
        scheduler = MultiplicativeLR(self.optimizer, lambda step: gamma)

        self.lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)]
        self.loss_recorder = []

        for batch_idx, (x, target) in enumerate(self.train_loader):
            x, target = self.to_cuda(x, target)

            # Forward
            batch_loss = self._get_loss(x, target)
            self._backprop_step(batch_loss)
            # Update LR
            scheduler.step()

            # Record
            self.loss_recorder.append(batch_loss.item())
            # Stop after the number of iterations
            if batch_idx + 1 == num_it:
                break

        # The loader may yield fewer than `num_it` batches: keep both recorders aligned
        self.lr_recorder = self.lr_recorder[:len(self.loss_recorder)]

    def plot_recorder(self, beta: float = 0.95, block: bool = True) -> None:
        """Display the results of the LR grid search

        Losses are smoothed with a bias-corrected exponential moving average.

        Args:
            beta (float, optional): smoothing factor
            block (bool, optional): whether the plot should block execution
        Raises:
            AssertionError: if `lr_find` has not been run beforehand
        """
        if len(self.lr_recorder) != len(self.loss_recorder) or len(self.lr_recorder) == 0:
            raise AssertionError("Please run the `lr_find` method first")

        # Exp moving average of loss
        smoothed_losses = []
        avg_loss = 0.
        for idx, loss in enumerate(self.loss_recorder):
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1)))

        # Drop the first 10 and last 5 points only if some points remain afterwards
        if len(smoothed_losses) > 15:
            lrs, losses = self.lr_recorder[10:-5], smoothed_losses[10:-5]
        else:
            lrs, losses = self.lr_recorder, smoothed_losses

        plt.plot(lrs, losses)
        plt.xscale('log')
        plt.xlabel('Learning Rate')
        plt.ylabel('Training loss')
        plt.grid(True, linestyle='--', axis='x')
        plt.show(block=block)

    def check_setup(self, freeze_until: Optional[str] = None, lr: float = 3e-4, num_it: int = 100) -> bool:
        """Check whether you can overfit one batch

        The model is trained repeatedly on the first training batch, so its weights are modified.

        Args:
            freeze_until (str, optional): last layer to freeze
            lr (float, optional): learning rate to be used for training
            num_it (int, optional): number of iterations to perform
        Returns:
            bool: False as soon as the loss increases between two iterations, True otherwise
        """
        self.model = freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(lr)

        prev_loss = math.inf

        x, target = next(iter(self.train_loader))
        x, target = self.to_cuda(x, target)

        for _ in range(num_it):
            # Forward
            batch_loss = self._get_loss(x, target)
            # Backprop
            self._backprop_step(batch_loss)

            # Check that loss decreases
            loss_val = batch_loss.item()
            if loss_val > prev_loss:
                return False
            prev_loss = loss_val

        return True
class ClassificationTrainer(Trainer):
    """Image classification trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
    """

    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set

        The validation loss is averaged over batches. The accuracies are averaged over samples.

        Returns:
            dict: evaluation metrics (`val_loss`, `acc1`, `acc5`)
        """
        self.model.eval()

        val_loss, top1, top5, num_samples = 0., 0, 0, 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            # Forward
            out = self.model(x)
            # Loss computation
            val_loss += self.criterion(out, target).item()

            # topk(5) raises when there are fewer than 5 classes: clamp k
            # (Acc@5 then counts a hit whenever the target is among all the classes)
            k = min(5, out.shape[1])
            pred = out.topk(k, dim=1)[1]
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            top1 += correct[:, 0].sum().item()
            top5 += correct.any(dim=1).sum().item()
            num_samples += x.shape[0]

        val_loss /= len(self.val_loader)

        return dict(val_loss=val_loss, acc1=top1 / num_samples, acc5=top5 / num_samples)

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        return (f"Validation loss: {eval_metrics['val_loss']:.4} "
                f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})")
class SegmentationTrainer(Trainer):
    """Semantic segmentation trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
    """

    @torch.no_grad()
    def evaluate(self, ignore_index: int = 255) -> Dict[str, float]:
        """Evaluate the model on the validation set

        For each batch, the IoU is averaged over the classes present in the target
        (excluding `ignore_index`). `mean_iou` is the average of those per-batch values,
        taken over the batches that contain at least one non-ignored class.

        Args:
            ignore_index (int, optional): index of the class to ignore in evaluation
        Returns:
            dict: evaluation metrics (`val_loss`, `mean_iou`)
        """
        self.model.eval()

        val_loss, mean_iou, num_iou_batches = 0., 0., 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            # Forward
            out = self.model(x)
            # Loss computation
            val_loss += self.criterion(out, target).item()

            pred = out.argmax(dim=1)
            batch_iou, num_seg = 0., 0
            for class_idx in torch.unique(target):
                if class_idx == ignore_index:
                    continue
                inter = (pred[target == class_idx] == class_idx).sum().item()
                # Never zero: `class_idx` is taken from the target itself
                union = ((pred == class_idx) | (target == class_idx)).sum().item()
                batch_iou += inter / union
                num_seg += 1
            # A batch containing only ignored pixels has no class to score:
            # skip it instead of dividing by zero
            if num_seg > 0:
                mean_iou += batch_iou / num_seg
                num_iou_batches += 1

        val_loss /= len(self.val_loader)
        if num_iou_batches > 0:
            mean_iou /= num_iou_batches

        return dict(val_loss=val_loss, mean_iou=mean_iou)

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        return f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
def assign_iou(gt_boxes: Tensor, pred_boxes: Tensor, iou_threshold: float = 0.5) -> Tuple[List[int], List[int]]:
    """Assigns ground truth boxes to predicted boxes by IoU

    Each ground truth box is matched to the prediction it overlaps most, provided the IoU
    is at least `iou_threshold`. When several ground truth boxes pick the same prediction,
    only the ground truth with the highest IoU keeps it.

    Args:
        gt_boxes (torch.Tensor): ground truth boxes, in the format expected by `box_iou`
        pred_boxes (torch.Tensor): predicted boxes, in the format expected by `box_iou`
        iou_threshold (float, optional): minimum IoU for a pair to be assigned
    Returns:
        tuple: matched ground truth indices and their predicted box indices, as index
        tensors when no conflict occurred, otherwise as lists of ints
    """
    # Nothing can be matched (and `max(dim=1)` would fail on an empty prediction axis)
    if gt_boxes.shape[0] == 0 or pred_boxes.shape[0] == 0:
        empty = torch.zeros(0, dtype=torch.long, device=gt_boxes.device)
        return empty, empty.clone()  # type: ignore[return-value]

    iou = box_iou(gt_boxes, pred_boxes)
    best_iou, best_pred = iou.max(dim=1)
    gt_kept = best_iou >= iou_threshold

    # Build indices on the boxes' device so the boolean mask can index them
    kept_gt = torch.arange(gt_boxes.shape[0], device=gt_boxes.device)[gt_kept]
    kept_pred = best_pred[gt_kept]
    kept_iou = best_iou[gt_kept]
    assign_unique = torch.unique(kept_pred)

    # No prediction claimed twice: the assignment is already one-to-one
    if kept_pred.shape[0] == assign_unique.shape[0]:
        return kept_gt, kept_pred  # type: ignore[return-value]

    # Resolve conflicts: each prediction keeps only its best-overlapping ground truth
    gt_indices, pred_indices = [], []
    for pred_idx in assign_unique:
        claims = kept_pred == pred_idx
        # `best` indexes the claiming subset, so select within that same subset
        best = kept_iou[claims].argmax()
        gt_indices.append(kept_gt[claims][best].item())
        pred_indices.append(pred_idx.item())

    return gt_indices, pred_indices  # type: ignore[return-value]
class DetectionTrainer(Trainer):
    """Object detection trainer class

    The model is expected to return a dict of losses when called with `(x, target)`
    during training, and a list of per-image detection dicts when called in eval mode.

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (None): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
    """

    @staticmethod
    def _to_cuda(  # type: ignore[override]
        x: List[Tensor],
        target: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Dict[str, Tensor]]]:
        """Move input images and every tensor of each target dict to GPU"""
        x = [_x.cuda(non_blocking=True) for _x in x]
        target = [{k: v.cuda(non_blocking=True) for k, v in t.items()} for t in target]
        return x, target

    def _backprop_step(self, loss: Tensor, grad_clip: Optional[float] = .1) -> None:
        """Perform one optimization step with gradient norm clipping

        Args:
            loss (torch.Tensor): loss to backpropagate
            grad_clip (float, optional): max gradient norm; `None` disables clipping
        """
        # Clean gradients
        self.optimizer.zero_grad()
        # Backpropagate the loss
        loss.backward()
        # Safeguard for gradient explosion (integer thresholds are honored too)
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
        # Update the params
        self.optimizer.step()

    def _get_loss(self, x: List[Tensor], target: List[Dict[str, Tensor]]) -> Tensor:  # type: ignore[override]
        """Run the forward pass and sum all the losses returned by the model"""
        # Forward & loss computation
        loss_dict = self.model(x, target)
        return sum(loss_dict.values())  # type: ignore[return-value]

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        return (f"Loc error: {eval_metrics['loc_err']:.2%} | Clf error: {eval_metrics['clf_err']:.2%} | "
                f"Det error: {eval_metrics['det_err']:.2%}")

    @torch.no_grad()
    def evaluate(self, iou_threshold: float = 0.5) -> Dict[str, float]:
        """Evaluate the model on the validation set

        Ground truth and predicted boxes are paired with `assign_iou`. Then:
        - loc_err = 1 - 2 * matched pairs / (predicted boxes + ground truth boxes)
        - clf_err = 1 - correctly labeled pairs / matched pairs
        - det_err = 1 - 2 * correctly labeled pairs / (predicted boxes + ground truth boxes)
        Each error defaults to 1 when its denominator is zero.

        Args:
            iou_threshold (float, optional): IoU threshold for pair assignment
        Returns:
            dict: evaluation metrics (`val_loss` mirrors `loc_err` for checkpoint selection)
        """
        self.model.eval()

        loc_assigns = 0
        correct, loc_fn, loc_fp, num_samples = 0, 0, 0, 0

        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)
            detections = self.model(x)

            for dets, t in zip(detections, target):
                if t['boxes'].shape[0] > 0 and dets['boxes'].shape[0] > 0:
                    gt_indices, pred_indices = assign_iou(t['boxes'], dets['boxes'], iou_threshold)
                    loc_assigns += len(gt_indices)
                    _correct = (t['labels'][gt_indices] == dets['labels'][pred_indices]).sum().item()
                else:
                    gt_indices, pred_indices = [], []
                    _correct = 0
                correct += _correct
                # Unmatched ground truths are misses, unmatched predictions are false alarms
                loc_fn += t['boxes'].shape[0] - len(gt_indices)
                loc_fp += dets['boxes'].shape[0] - len(pred_indices)
            num_samples += sum(t['boxes'].shape[0] for t in target)

        # Total predicted boxes = matched ones (num_samples - loc_fn) + unmatched ones (loc_fp)
        nb_preds = num_samples - loc_fn + loc_fp
        # Localization
        loc_err = 1 - 2 * loc_assigns / (nb_preds + num_samples) if nb_preds + num_samples > 0 else 1.
        # Classification
        clf_err = 1 - correct / loc_assigns if loc_assigns > 0 else 1.
        # End-to-end
        det_err = 1 - 2 * correct / (nb_preds + num_samples) if nb_preds + num_samples > 0 else 1.
        return dict(loc_err=loc_err, clf_err=clf_err, det_err=det_err, val_loss=loc_err)