holocron.trainer
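holocron.trainer provides basic building blocks for training: a Trainer class, task-specific trainers for image classification, semantic segmentation and object detection, and helpers to freeze model layers.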
- class holocron.trainer.Trainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')
- check_setup(freeze_until: str | None = None, lr: float = 0.0003, num_it: int = 100) → bool
Check whether the model can overfit a single batch
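The idea behind an "overfit one batch" check can be illustrated without PyTorch. The sketch below is a hypothetical stand-in, not holocron's implementation: it fits a linear model y = w * x + b to one fixed batch with plain gradient descent, and the pass criterion (final mean squared error below a hypothetical tolerance `tol`) is an assumption. Its default `lr` of 0.1 suits this toy problem only and is unrelated to the 0.0003 default of check_setup.

```python
from typing import Sequence


def check_setup_sketch(
    xs: Sequence[float],
    ys: Sequence[float],
    lr: float = 0.1,
    num_it: int = 100,
    tol: float = 1e-3,
) -> bool:
    """Hypothetical sanity check: can y = w * x + b fit this single batch?"""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(num_it):
        preds = [w * x + b for x in xs]
        # Gradients of the mean squared error with respect to w and b
        grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
        w -= lr * grad_w
        b -= lr * grad_b
    final_loss = sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n
    # The setup "passes" if the loss on the batch was driven close to zero
    return final_loss < tol
```

A batch that the model cannot represent (for instance points that are not on a line) fails the check however long it trains, which is what makes this test useful for catching a broken model, loss or data pipeline early.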
- fit_n_epochs(num_epochs: int, lr: float, freeze_until: str | None = None, sched_type: str = 'onecycle') → None
Train the model for a given number of epochs
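With the default sched_type='onecycle', the learning rate presumably follows a one-cycle policy peaking at `lr`. The sketch below computes such a schedule step by step; it is an assumption that holocron's schedule has this exact shape, and the `pct_start`, `div_factor` and `final_div_factor` defaults are borrowed from PyTorch's `torch.optim.lr_scheduler.OneCycleLR` (cosine annealing, two phases). `one_cycle_lr` is a hypothetical helper.

```python
import math


def _cos_interp(start: float, end: float, pct: float) -> float:
    # Cosine interpolation: returns `start` at pct=0 and `end` at pct=1
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)


def one_cycle_lr(
    step: int,
    total_steps: int,
    max_lr: float,
    pct_start: float = 0.3,
    div_factor: float = 25.0,
    final_div_factor: float = 1e4,
) -> float:
    """Learning rate at `step` for a two-phase one-cycle schedule (sketch)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    # Last step of the warm-up phase (assumes it spans at least one step)
    warmup_end = pct_start * total_steps - 1
    if step <= warmup_end:
        # Phase 1: ramp up from initial_lr to max_lr
        return _cos_interp(initial_lr, max_lr, step / warmup_end)
    # Phase 2: anneal from max_lr down to min_lr, reached on the final step
    return _cos_interp(max_lr, min_lr, (step - warmup_end) / (total_steps - 1 - warmup_end))
```

For a full training run, `total_steps` would presumably be `num_epochs` times the number of batches in the training loader.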
- load(state: Dict[str, Any]) → None
Resume from a trainer state
- Parameters:
state (dict) – checkpoint dictionary
- lr_find(freeze_until: str | None = None, start_lr: float = 1e-07, end_lr: float = 1, num_it: int = 100) → None
Grid-search the optimal learning rate for training
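A learning-rate search of this kind typically sweeps the learning rate from `start_lr` to `end_lr` over `num_it` iterations while recording the loss at each step. The sketch below only builds the sequence of learning rates; the geometric (exponential) spacing is an assumption, and `lr_sweep` is a hypothetical name.

```python
from typing import List


def lr_sweep(start_lr: float = 1e-7, end_lr: float = 1.0, num_it: int = 100) -> List[float]:
    """Geometrically spaced learning rates from start_lr to end_lr (inclusive)."""
    # Constant multiplicative factor so that iteration num_it - 1 lands on end_lr
    gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
    return [start_lr * gamma**i for i in range(num_it)]
```

Geometric spacing gives every order of magnitude between 1e-7 and 1 the same number of iterations, which a linear sweep would not.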
- plot_recorder(beta: float = 0.95, block: bool = True) → None
Display the results of the LR grid search
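The `beta` argument suggests that the recorded losses are smoothed with an exponential moving average before plotting. The sketch below shows a bias-corrected moving average using that `beta`; whether plot_recorder applies exactly this correction is an assumption, and `smooth_losses` is a hypothetical helper.

```python
from typing import List, Sequence


def smooth_losses(losses: Sequence[float], beta: float = 0.95) -> List[float]:
    """Bias-corrected exponential moving average of a loss curve."""
    avg, smoothed = 0.0, []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        # Divide by (1 - beta**i) to undo the bias from initialising avg at 0
        smoothed.append(avg / (1 - beta**i))
    return smoothed
```

Without the correction, the first smoothed values would be pulled towards zero, making the start of the curve look artificially low.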
- save(output_file: str) → None
Save a trainer checkpoint
- Parameters:
output_file (str) – destination file path
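save and load exchange a plain checkpoint dictionary. The sketch below mimics that round trip with the standard library's `pickle` module (the serialisation format `torch.save` builds on) so that it runs without PyTorch; `save_checkpoint`, `load_checkpoint` and the checkpoint keys are hypothetical, not the ones holocron writes.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def save_checkpoint(state: Dict[str, Any], output_file: str) -> None:
    # Serialise the whole checkpoint dictionary to disk
    with open(output_file, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(output_file: str) -> Dict[str, Any]:
    # Read back the dictionary written by save_checkpoint
    with open(output_file, "rb") as f:
        return pickle.load(f)


# Round trip through a temporary directory (hypothetical keys)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")
    state = {"epoch": 3, "min_loss": 0.42, "model": {"fc.weight": [0.1, -0.2]}}
    save_checkpoint(state, path)
    restored = load_checkpoint(path)
```

Because unpickling can execute arbitrary code, only load checkpoints from trusted sources.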
Image classification
- class holocron.trainer.ClassificationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')
Image classification trainer class
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (torch.nn.Module) – loss criterion
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
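This reference does not state which validation metrics ClassificationTrainer reports; top-1 and top-k accuracy are the usual choices for image classification. The sketch below, a hypothetical helper rather than part of holocron's API, computes top-k accuracy from per-class scores in plain Python.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int], k: int = 1) -> float:
    """Fraction of samples whose target class is among the k highest scores."""
    correct = 0
    for row, target in zip(scores, targets):
        # Class indices sorted by decreasing score, truncated to the top k
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += int(target in topk)
    return correct / len(targets)
```

For example, with scores `[[0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]` and targets `[1, 2, 2]`, top-1 accuracy is 2/3 while top-2 accuracy is 1.0.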
- class holocron.trainer.SegmentationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')
Semantic segmentation trainer class
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (torch.nn.Module) – loss criterion
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
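Semantic segmentation is commonly evaluated with mean intersection-over-union (IoU) computed from a confusion matrix; this reference does not say whether SegmentationTrainer uses it, so treat the following as an illustrative assumption. `mean_iou` is a hypothetical helper working on flattened per-pixel class indices.

```python
from typing import Sequence


def mean_iou(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> float:
    """Mean IoU over classes present in either the prediction or the target."""
    # Confusion matrix: rows are ground-truth classes, columns are predicted classes
    conf = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        conf[t][p] += 1
    ious = []
    for c in range(num_classes):
        tp = conf[c][c]
        fp = sum(conf[r][c] for r in range(num_classes)) - tp
        fn = sum(conf[c]) - tp
        denom = tp + fp + fn
        if denom > 0:  # skip classes absent from both prediction and target
            ious.append(tp / denom)
    return sum(ious) / len(ious)
```

Averaging per-class IoU, rather than counting correct pixels, keeps large background regions from dominating the score.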
- class holocron.trainer.DetectionTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')
Object detection trainer class
- Parameters:
model (torch.nn.Module) – model to train
train_loader (torch.utils.data.DataLoader) – training loader
val_loader (torch.utils.data.DataLoader) – validation loader
criterion (None) – loss criterion
optimizer (torch.optim.Optimizer) – parameter optimizer
gpu (int, optional) – index of the GPU to use
output_file (str, optional) – path where checkpoints will be saved
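The criterion of DetectionTrainer is documented as None, which is consistent with detection models that compute their own losses from the targets; torchvision's detection models, for instance, return a dictionary of loss terms in training mode. Assuming that convention, the training loss is the sum of those terms. `ToyDetector` and `detection_loss` below are hypothetical placeholders.

```python
from typing import Any, Dict, List


class ToyDetector:
    """Hypothetical stand-in for a detection model that returns its own losses."""

    def __call__(self, images: List[Any], targets: List[Dict[str, Any]]) -> Dict[str, float]:
        # A real model would compare its predictions with `targets`;
        # fixed placeholder values keep the sketch self-contained
        return {"obj_loss": 0.5, "clf_loss": 0.25, "bbox_loss": 1.0}


def detection_loss(model: Any, images: List[Any], targets: List[Dict[str, Any]]) -> float:
    # No external criterion: the training loss is the sum of the model's loss terms
    loss_dict = model(images, targets)
    return sum(loss_dict.values())
```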
Miscellaneous
- holocron.trainer.freeze_bn(mod: Module) → Module
Prevent the parameters and running statistics of frozen BatchNorm layers from updating
- Parameters:
mod (torch.nn.Module) – model to train
- Returns:
model
- Return type:
torch.nn.Module
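In PyTorch, a BatchNorm layer in eval mode normalises with its stored running statistics and stops updating them. The toy sketch below shows that mechanism without PyTorch. `ToyBatchNorm` and `freeze_bn_sketch` are hypothetical, and treating a layer as frozen when its parameters do not require gradients is an assumption about how freeze_bn selects layers.

```python
from typing import List


class ToyBatchNorm:
    """Hypothetical 1-D batch-norm stand-in that tracks a running mean."""

    def __init__(self, momentum: float = 0.1) -> None:
        self.momentum = momentum
        self.running_mean = 0.0
        self.training = True  # mirrors nn.Module.training
        self.requires_grad = True  # stands in for its affine parameters' flags

    def __call__(self, batch: List[float]) -> List[float]:
        if self.training:
            mean = sum(batch) / len(batch)
            # Running statistics only move in training mode
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
        else:
            mean = self.running_mean
        return [x - mean for x in batch]


def freeze_bn_sketch(layers: List[ToyBatchNorm]) -> List[ToyBatchNorm]:
    # A frozen layer is switched to eval mode so its running stats stop updating
    for layer in layers:
        if not layer.requires_grad:
            layer.training = False
    return layers
```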
- holocron.trainer.freeze_model(model: Module, last_frozen_layer: str | None = None, frozen_bn_stat_update: bool = False) → Module
Freeze a specific range of model layers
- Parameters:
model (torch.nn.Module) – model to train
last_frozen_layer (str, optional) – last layer to freeze. Assumes layers have been registered in forward order
frozen_bn_stat_update (bool, optional) – force running-statistics updates in frozen BatchNorm layers
- Returns:
model
- Return type:
torch.nn.Module
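Because freeze_model assumes layers are registered in forward order, freezing everything up to a given layer can be expressed as a walk over parameter names in registration order. The sketch below computes which parameters stay trainable for a given `last_frozen_layer`; it works on a plain list of names (such as the names yielded by a module's `named_parameters()`) and is an illustration, not holocron's code. `frozen_mask` is hypothetical, and matching on a dot-separated prefix is an assumption.

```python
from typing import Dict, List, Optional


def frozen_mask(param_names: List[str], last_frozen_layer: Optional[str] = None) -> Dict[str, bool]:
    """Map each parameter name to True if it stays trainable, False if frozen."""
    trainable = {name: True for name in param_names}
    if last_frozen_layer is None:
        return trainable
    layer_reached = False
    for name in param_names:
        # Match "layer" or "layer.xxx", but not "layer10.xxx"
        in_layer = name == last_frozen_layer or name.startswith(last_frozen_layer + ".")
        if layer_reached and not in_layer:
            # Every parameter of the last frozen layer has been visited
            break
        trainable[name] = False
        layer_reached = layer_reached or in_layer
    if not layer_reached:
        raise ValueError(f"Unable to locate layer {last_frozen_layer}")
    return trainable
```

Matching on `last_frozen_layer + "."` avoids the pitfall where `features.1` would also match `features.10.weight`.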