holocron.trainer¶
holocron.trainer provides some basic objects for training purposes.
- class holocron.trainer.Trainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
- check_setup(freeze_until: str | None = None, lr: float = 0.0003, norm_weight_decay: float | None = None, num_it: int = 100) bool [source]¶
Check whether you can overfit one batch
- Returns:
whether the setup was able to overfit one batch
- Return type:
bool
- find_lr(freeze_until: str | None = None, start_lr: float = 1e-07, end_lr: float = 1, norm_weight_decay: float | None = None, num_it: int = 100) None [source]¶
Grid search the optimal learning rate for training
- Parameters:
freeze_until (str, optional) – last layer to freeze
start_lr (float, optional) – initial learning rate
end_lr (float, optional) – final learning rate
norm_weight_decay (float, optional) – weight decay to apply to normalization parameters
num_it (int, optional) – number of iterations to perform
- fit_n_epochs(num_epochs: int, lr: float, freeze_until: str | None = None, sched_type: str = 'onecycle', norm_weight_decay: float | None = None) None [source]¶
Train the model for a given number of epochs
- Parameters:
num_epochs (int) – number of epochs to train
lr (float) – learning rate to be used by the scheduler
freeze_until (str, optional) – last layer to freeze
sched_type (str, optional) – type of learning rate scheduler to use ('onecycle' by default)
norm_weight_decay (float, optional) – weight decay to apply to normalization parameters
- load(state: Dict[str, Any]) None [source]¶
Resume from a trainer state
- Parameters:
state (dict) – checkpoint dictionary
- plot_recorder(beta: float = 0.95, **kwargs: Any) None [source]¶
Display the results of the LR grid search
- Parameters:
beta (float, optional) – smoothing factor
**kwargs – keyword arguments passed to the plot display
- save(output_file: str) None [source]¶
Save a trainer checkpoint
- Parameters:
output_file (str) – destination file path
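The methods above compose into a simple workflow: search for a learning rate, inspect the curve, train, then checkpoint. Below is a minimal, illustrative sketch using the ClassificationTrainer subclass (documented in the next section); the synthetic dataset, tiny model, and hyperparameters are placeholders standing in for a real setup.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from holocron.trainer import ClassificationTrainer

# Synthetic stand-in for a real dataset (illustrative only)
images = torch.rand(320, 3, 32, 32)
labels = torch.randint(0, 10, (320,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(images, labels), batch_size=16)

# Tiny classifier so the example stays self-contained
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

trainer = ClassificationTrainer(model, train_loader, val_loader, criterion, optimizer)

# Grid search a learning rate, then inspect the loss curve
trainer.find_lr(num_it=20)
trainer.plot_recorder()

# Train with the default one-cycle schedule and checkpoint the result
trainer.fit_n_epochs(num_epochs=1, lr=1e-3)
trainer.save("./checkpoint.pth")
```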
Image classification¶
- class holocron.trainer.ClassificationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
Image classification trainer class
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (torch.nn.Module) – loss criterion
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
amp (bool, optional) – whether to use automatic mixed precision
skip_nan_loss (bool, optional) – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance (int, optional) – number of consecutive batches with NaN loss before stopping the training
on_epoch_end (Callable[[Dict[str, float]], Any], optional) – callback triggered at the end of an epoch
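As a hedged sketch, check_setup can serve as a pre-flight test before a long run: a model and optimizer that cannot overfit a single batch usually point to a wiring problem. The dataset and model below are synthetic placeholders.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from holocron.trainer import ClassificationTrainer

x = torch.rand(160, 3, 32, 32)
y = torch.randint(0, 10, (160,))
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
trainer = ClassificationTrainer(
    model, loader, loader, nn.CrossEntropyLoss(),
    torch.optim.Adam(model.parameters()),
)

# Returns True if the training setup manages to overfit one batch
if not trainer.check_setup(lr=1e-2, num_it=100):
    raise RuntimeError("training setup failed the overfitting check")
```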
- class holocron.trainer.BinaryClassificationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
Image binary classification trainer class
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (torch.nn.Module) – loss criterion
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
amp (bool, optional) – whether to use automatic mixed precision
skip_nan_loss (bool, optional) – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance (int, optional) – number of consecutive batches with NaN loss before stopping the training
on_epoch_end (Callable[[Dict[str, float]], Any], optional) – callback triggered at the end of an epoch
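A minimal sketch, assuming the criterion operates on raw logits (here nn.BCEWithLogitsLoss) and the model ends with a single-logit head; the data, architecture, and target convention are placeholders, so adapt shapes to your model.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from holocron.trainer import BinaryClassificationTrainer

x = torch.rand(160, 3, 32, 32)
y = torch.randint(0, 2, (160, 1)).float()  # binary targets in {0., 1.}
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

# Single-logit head for binary classification (illustrative)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1))
trainer = BinaryClassificationTrainer(
    model, loader, loader, nn.BCEWithLogitsLoss(),
    torch.optim.Adam(model.parameters()),
)
trainer.fit_n_epochs(num_epochs=1, lr=1e-3)
```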
- class holocron.trainer.SegmentationTrainer(*args: Any, num_classes: int = 10, **kwargs: Any)[source]¶
Semantic segmentation trainer class
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (torch.nn.Module) – loss criterion
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
num_classes (int) – number of output classes
amp (bool, optional) – whether to use automatic mixed precision
skip_nan_loss (bool, optional) – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance (int, optional) – number of consecutive batches with NaN loss before stopping the training
on_epoch_end (Callable[[Dict[str, float]], Any], optional) – callback triggered at the end of an epoch
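A minimal sketch with a toy fully-convolutional model and random masks as placeholders; shapes follow the usual cross-entropy convention for segmentation (logits of shape (N, num_classes, H, W) against integer masks of shape (N, H, W)).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from holocron.trainer import SegmentationTrainer

num_classes = 10
images = torch.rand(32, 3, 64, 64)
masks = torch.randint(0, num_classes, (32, 64, 64))  # per-pixel class indices
loader = DataLoader(TensorDataset(images, masks), batch_size=8, shuffle=True)

# Toy fully-convolutional model: one logit map per class
model = nn.Conv2d(3, num_classes, 3, padding=1)
trainer = SegmentationTrainer(
    model, loader, loader, nn.CrossEntropyLoss(),
    torch.optim.Adam(model.parameters()),
    num_classes=num_classes,
)
trainer.fit_n_epochs(num_epochs=1, lr=1e-3)
```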
- class holocron.trainer.DetectionTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
Object detection trainer class
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (None) – loss criterion; holocron detection models compute their loss internally, so this is typically left as None
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
amp (bool, optional) – whether to use automatic mixed precision
skip_nan_loss (bool, optional) – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance (int, optional) – number of consecutive batches with NaN loss before stopping the training
on_epoch_end (Callable[[Dict[str, float]], Any], optional) – callback triggered at the end of an epoch
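A construction-only sketch, since building detection loaders is dataset-specific: the model (here holocron.models.detection.yolov2) computes its own loss, so criterion is passed as None. The toy dataset, target format, and collate function below are assumptions made to keep the snippet self-contained; check your model's expected target layout before training.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from holocron.models import detection
from holocron.trainer import DetectionTrainer

class ToyDetectionSet(Dataset):
    """Placeholder dataset: an image plus a dict of boxes/labels per sample.
    The exact target format expected by the model is an assumption here."""
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        image = torch.rand(3, 128, 128)
        target = {"boxes": torch.tensor([[0.2, 0.2, 0.6, 0.6]]),
                  "labels": torch.tensor([0])}
        return image, target

def collate_fn(batch):
    # Detection batches are usually kept as sequences, not stacked tensors
    return tuple(zip(*batch))

loader = DataLoader(ToyDetectionSet(), batch_size=4, collate_fn=collate_fn)

# The model computes its own loss, hence criterion=None
model = detection.yolov2(num_classes=2, pretrained_backbone=False)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
trainer = DetectionTrainer(model, loader, loader, None, optimizer)
```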
Miscellaneous¶
- holocron.trainer.freeze_bn(mod: Module) Module [source]¶
Prevents parameters and stats from updating in BatchNorm layers that are frozen
- Parameters:
mod (torch.nn.Module) – model to train
- Returns:
model
- Return type:
torch.nn.Module
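A hedged sketch of the intended effect: once a BatchNorm layer's parameters are frozen, freeze_bn keeps its running statistics from drifting even while the rest of the model trains.

```python
import torch
from torch import nn

from holocron.trainer import freeze_bn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Freeze every parameter, then stop the BN running stats from updating too
for p in model.parameters():
    p.requires_grad_(False)
model = freeze_bn(model.train())

before = model[1].running_mean.clone()
model(torch.rand(4, 3, 32, 32))
assert torch.equal(before, model[1].running_mean)  # stats unchanged
```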
- holocron.trainer.freeze_model(model: Module, last_frozen_layer: str | None = None, frozen_bn_stat_update: bool = False) Module [source]¶
Freeze a specific range of model layers
- Parameters:
model (torch.nn.Module) – model to train
last_frozen_layer (str, optional) – last layer to freeze. Assumes layers have been registered in forward order
frozen_bn_stat_update (bool, optional) – force stats update in BN layers that are frozen
- Returns:
model
- Return type:
torch.nn.Module
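A minimal sketch: freeze everything up to a named child module and leave the rest trainable. The layer names below are illustrative; matching relies on the forward-order assumption stated above.

```python
from collections import OrderedDict

from torch import nn

from holocron.trainer import freeze_model

model = nn.Sequential(OrderedDict([
    ("stem", nn.Conv2d(3, 8, 3, padding=1)),
    ("body", nn.Conv2d(8, 8, 3, padding=1)),
    ("head", nn.Conv2d(8, 10, 1)),
]))

# Freeze "stem" and "body"; "head" keeps requires_grad=True
model = freeze_model(model, last_frozen_layer="body")
assert not model.stem.weight.requires_grad
assert model.head.weight.requires_grad
```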