holocron.trainer¶
holocron.trainer provides some basic objects for training purposes.
- class holocron.trainer.Trainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, gradient_acc: int = 1, gradient_clip: float | None = None, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
Baseline trainer class.
- Parameters:
model – model to train
train_loader – training loader
val_loader – validation loader
criterion – loss criterion
optimizer – parameter optimizer
gpu – index of the GPU to use
output_file – path where checkpoints will be saved
amp – whether to use automatic mixed precision
skip_nan_loss – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance – number of consecutive batches with NaN loss before stopping the training
gradient_acc – number of batches to accumulate the gradient of before performing the update step
gradient_clip – the gradient clip value
on_epoch_end – callback triggered at the end of an epoch
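For illustration, here is a minimal sketch of instantiating a trainer; model, train_loader, val_loader, criterion and optimizer are placeholders for your own objects, and gpu=0 assumes a CUDA device is available.
>>> from holocron.trainer import Trainer
>>> trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, gpu=0, output_file='./checkpoint.pth', amp=True)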
- check_setup(freeze_until: str | None = None, lr: float = 0.0003, norm_weight_decay: float | None = None, num_it: int = 100, **kwargs: Any) None [source]¶
Check whether you can overfit one batch.
- Parameters:
freeze_until (str, optional) – last layer to freeze
lr (float, optional) – learning rate to be used for training
norm_weight_decay (float, optional) – weight decay to apply to normalization parameters
num_it (int, optional) – number of iterations to perform
kwargs – keyword args of matplotlib.pyplot.show
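For instance, assuming a trainer instance built as above, a quick overfitting sanity check could look like this (the values shown are the defaults):
>>> trainer.check_setup(lr=0.0003, num_it=100)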
- find_lr(freeze_until: str | None = None, start_lr: float = 1e-07, end_lr: float = 1, norm_weight_decay: float | None = None, num_it: int = 100) None [source]¶
Grid search the optimal learning rate for training, as described in “Cyclical Learning Rates for Training Neural Networks”.
- Parameters:
freeze_until (str, optional) – last layer to freeze
start_lr (float, optional) – initial learning rate of the grid search
end_lr (float, optional) – final learning rate of the grid search
norm_weight_decay (float, optional) – weight decay to apply to normalization parameters
num_it (int, optional) – number of iterations to perform
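For instance, assuming a trainer instance is already built, a search over the default range could be run as follows:
>>> trainer.find_lr(start_lr=1e-07, end_lr=1, num_it=100)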
- fit_n_epochs(num_epochs: int, lr: float, freeze_until: str | None = None, sched_type: str = 'onecycle', norm_weight_decay: float | None = None, **kwargs: Any) None [source]¶
Train the model for a given number of epochs.
- Parameters:
num_epochs (int) – number of epochs to train
lr (float) – learning rate to be used by the scheduler
freeze_until (str, optional) – last layer to freeze
sched_type (str, optional) – type of scheduler to use
norm_weight_decay (float, optional) – weight decay to apply to normalization parameters
**kwargs – keyword args passed to the schedulers
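For instance, assuming a trainer instance is already built (the epoch count and learning rate below are arbitrary):
>>> trainer.fit_n_epochs(num_epochs=10, lr=0.0003, sched_type='onecycle')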
- load(state: Dict[str, Any]) None [source]¶
Resume from a trainer state.
- Parameters:
state (dict) – checkpoint dictionary
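A sketch of resuming from a checkpoint previously written by save, assuming the file exists at the default path:
>>> import torch
>>> state = torch.load('./checkpoint.pth', map_location='cpu')
>>> trainer.load(state)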
- plot_recorder(beta: float = 0.95, **kwargs: Any) None [source]¶
Display the results of the LR grid search.
- Parameters:
beta (float, optional) – smoothing factor
kwargs – keyword args of matplotlib.pyplot.show
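Assuming find_lr has already been run on the same trainer instance:
>>> trainer.plot_recorder(beta=0.95)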
- save(output_file: str) None [source]¶
Save a trainer checkpoint.
- Parameters:
output_file – destination file path
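For instance:
>>> trainer.save('./checkpoint.pth')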
Image classification¶
- class holocron.trainer.ClassificationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, gradient_acc: int = 1, gradient_clip: float | None = None, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
Image classification trainer class.
- Parameters:
model – model to train
train_loader – training loader
val_loader – validation loader
criterion – loss criterion
optimizer – parameter optimizer
gpu – index of the GPU to use
output_file – path where checkpoints will be saved
amp – whether to use automatic mixed precision
skip_nan_loss – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance – number of consecutive batches with NaN loss before stopping the training
gradient_acc – number of batches to accumulate the gradient of before performing the update step
gradient_clip – the gradient clip value
on_epoch_end – callback triggered at the end of an epoch
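For illustration, here is a minimal, self-contained sketch using random tensors as stand-in data; the dataset, batch size and hyperparameters are arbitrary.
>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from holocron.models import rexnet1_0x
>>> from holocron.trainer import ClassificationTrainer
>>> model = rexnet1_0x()
>>> images, targets = torch.rand(8, 3, 224, 224), torch.randint(0, 10, (8,))
>>> train_loader = DataLoader(TensorDataset(images, targets), batch_size=4)
>>> val_loader = DataLoader(TensorDataset(images, targets), batch_size=4)
>>> criterion = torch.nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
>>> trainer = ClassificationTrainer(model, train_loader, val_loader, criterion, optimizer)
>>> trainer.fit_n_epochs(num_epochs=1, lr=0.0003)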
- class holocron.trainer.BinaryClassificationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, gradient_acc: int = 1, gradient_clip: float | None = None, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
Image binary classification trainer class.
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (torch.nn.Module) – loss criterion
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
amp (bool, optional) – whether to use automatic mixed precision
skip_nan_loss (bool, optional) – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance (int, optional) – number of consecutive batches with NaN loss before stopping the training
gradient_acc (int, optional) – number of batches to accumulate the gradient of before performing the update step
gradient_clip (float, optional) – the gradient clip value
on_epoch_end (Callable, optional) – callback triggered at the end of an epoch
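A minimal sketch, assuming the model outputs a single logit per sample and pairing it with torch.nn.BCEWithLogitsLoss; model, loaders and optimizer are placeholders for your own objects.
>>> from torch.nn import BCEWithLogitsLoss
>>> from holocron.trainer import BinaryClassificationTrainer
>>> trainer = BinaryClassificationTrainer(model, train_loader, val_loader, BCEWithLogitsLoss(), optimizer)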
Semantic segmentation¶
- class holocron.trainer.SegmentationTrainer(*args: Any, num_classes: int = 10, **kwargs: Any)[source]¶
Semantic segmentation trainer class.
- Parameters:
model – model to train
train_loader – training loader
val_loader – validation loader
criterion – loss criterion
optimizer – parameter optimizer
gpu – index of the GPU to use
output_file – path where checkpoints will be saved
amp – whether to use automatic mixed precision
skip_nan_loss – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance – number of consecutive batches with NaN loss before stopping the training
gradient_acc – number of batches to accumulate the gradient of before performing the update step
gradient_clip – the gradient clip value
on_epoch_end – callback triggered at the end of an epoch
num_classes – number of output classes
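A minimal sketch; model, loaders, criterion and optimizer are placeholders for segmentation-specific objects, and num_classes=21 is an arbitrary example value.
>>> from holocron.trainer import SegmentationTrainer
>>> trainer = SegmentationTrainer(model, train_loader, val_loader, criterion, optimizer, num_classes=21)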
Object detection¶
- class holocron.trainer.DetectionTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth', amp: bool = False, skip_nan_loss: bool = False, nan_tolerance: int = 5, gradient_acc: int = 1, gradient_clip: float | None = None, on_epoch_end: Callable[[Dict[str, float]], Any] | None = None)[source]¶
Object detection trainer class.
- Parameters:
model – model to train
train_loader – training loader
val_loader – validation loader
criterion – loss criterion
optimizer – parameter optimizer
gpu – index of the GPU to use
output_file – path where checkpoints will be saved
amp – whether to use automatic mixed precision
skip_nan_loss – whether the optimizer step should be skipped when the loss is NaN
nan_tolerance – number of consecutive batches with NaN loss before stopping the training
gradient_acc – number of batches to accumulate the gradient of before performing the update step
gradient_clip – the gradient clip value
on_epoch_end – callback triggered at the end of an epoch
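A minimal sketch; model, loaders, criterion and optimizer are placeholders for detection-specific objects, and the training hyperparameters are arbitrary.
>>> from holocron.trainer import DetectionTrainer
>>> trainer = DetectionTrainer(model, train_loader, val_loader, criterion, optimizer)
>>> trainer.fit_n_epochs(num_epochs=10, lr=0.0001)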
Miscellaneous¶
- holocron.trainer.freeze_bn(mod: Module) None [source]¶
Prevents parameters and stats from updating in BatchNorm layers that are frozen.
>>> from holocron.models import rexnet1_0x
>>> from holocron.trainer.utils import freeze_bn
>>> model = rexnet1_0x()
>>> freeze_bn(model)
- Parameters:
mod (torch.nn.Module) – model to train
- holocron.trainer.freeze_model(model: Module, last_frozen_layer: str | None = None, frozen_bn_stat_update: bool = False) None [source]¶
Freeze a specific range of model layers.
>>> from holocron.models import rexnet1_0x
>>> from holocron.trainer.utils import freeze_model
>>> model = rexnet1_0x()
>>> freeze_model(model)
- Parameters:
model (torch.nn.Module) – model to train
last_frozen_layer (str, optional) – last layer to freeze. Assumes layers have been registered in forward order
frozen_bn_stat_update (bool, optional) – force stats update in BN layers that are frozen
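To freeze only the first layers, pass the name of the last layer to freeze; the layer name below is purely illustrative and depends on how your model registers its submodules.
>>> freeze_model(model, last_frozen_layer='features.4')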