holocron.trainer

holocron.trainer provides basic objects to train PyTorch models, with dedicated trainers for image classification, semantic segmentation and object detection.

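class holocron.trainer.Trainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')

Baseline trainer class

Parameters:
  • model (torch.nn.Module) – model to train

  • train_loader (torch.utils.data.DataLoader) – training loader

  • val_loader (torch.utils.data.DataLoader) – validation loader

  • criterion (torch.nn.Module) – loss criterion

  • optimizer (torch.optim.Optimizer) – parameter optimizer

  • gpu (int, optional) – index of the GPU to use

  • output_file (str, optional) – path where checkpoints will be saved
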
check_setup(freeze_until: str | None = None, lr: float = 0.0003, num_it: int = 100) → bool

Check whether the model can overfit a single training batch, a quick sanity check of the training setup, and return whether it succeeded

Parameters:
  • freeze_until (str, optional) – last layer to freeze

  • lr (float, optional) – learning rate to be used for training

  • num_it (int, optional) – number of iterations to perform
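Overfitting one batch is a standard sanity check: if a model cannot drive the loss down on a single fixed batch, the training setup is likely broken. The framework-free sketch below illustrates the idea on a toy linear model; the helper name `can_overfit_one_batch` and the success criterion (final loss below 1% of the initial loss) are assumptions for illustration, not holocron's implementation.

```python
from typing import List, Tuple


def can_overfit_one_batch(
    batch: List[Tuple[float, float]],
    lr: float = 0.05,
    num_it: int = 500,
    ratio: float = 0.01,
) -> bool:
    """Hypothetical helper: fit y = w * x + b on one fixed batch with plain
    gradient descent and report whether the loss fell below `ratio` times
    its initial value (assumed success criterion)."""
    w, b = 0.0, 0.0
    n = len(batch)

    def mse() -> float:
        # Mean squared error of the current parameters on the batch
        return sum((w * x + b - y) ** 2 for x, y in batch) / n

    initial_loss = mse()
    for _ in range(num_it):
        # Gradients of the mean squared error w.r.t. w and b
        grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in batch) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return mse() < ratio * initial_loss
```

A batch following y = 2x + 1 can be fitted exactly, whereas a batch with two different targets for the same input cannot, so the check fails for it.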

fit_n_epochs(num_epochs: int, lr: float, freeze_until: str | None = None, sched_type: str = 'onecycle') → None

Train the model for a given number of epochs

Parameters:
  • num_epochs (int) – number of epochs to train

  • lr (float) – learning rate to be used by the scheduler

  • freeze_until (str, optional) – last layer to freeze

  • sched_type (str, optional) – type of scheduler to use
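With the default sched_type='onecycle', the learning rate follows a one-cycle policy: it rises to a peak and then anneals far below its starting value. The pure-Python function below is a simplified sketch of that shape, assuming the peak equals lr, cosine interpolation in both phases, and the defaults of PyTorch's torch.optim.lr_scheduler.OneCycleLR (pct_start=0.3, div_factor=25, final_div_factor=1e4). The function name is hypothetical and holocron may configure its scheduler differently.

```python
import math


def one_cycle_lr(step: int, total_steps: int, max_lr: float,
                 pct_start: float = 0.3, div_factor: float = 25.0,
                 final_div_factor: float = 1e4) -> float:
    """Learning rate at `step` (0-based) of a simplified one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = max(1, int(pct_start * total_steps))

    def cos_interp(start: float, end: float, pct: float) -> float:
        # Cosine interpolation: equals `start` at pct=0 and `end` at pct=1
        return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

    if step < warmup_steps:
        # Warm-up phase: initial_lr -> max_lr
        return cos_interp(initial_lr, max_lr, step / warmup_steps)
    # Annealing phase: max_lr -> min_lr, reached at the last step
    decay_steps = max(1, total_steps - 1 - warmup_steps)
    return cos_interp(max_lr, min_lr, (step - warmup_steps) / decay_steps)
```

For 100 steps and max_lr=1e-3, the rate starts at 4e-5, peaks at 1e-3 on step 30 and ends at 4e-9.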

load(state: Dict[str, Any]) → None

Resume from a trainer state

Parameters:

state (dict) – checkpoint dictionary

lr_find(freeze_until: str | None = None, start_lr: float = 1e-07, end_lr: float = 1, num_it: int = 100) → None

Search for a suitable learning rate by increasing it from start_lr to end_lr over num_it training iterations while recording the loss

Parameters:
  • freeze_until (str, optional) – last layer to freeze

  • start_lr (float, optional) – initial learning rate

  • end_lr (float, optional) – final learning rate

  • num_it (int, optional) – number of iterations to perform
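This kind of search is usually called an LR range test. The sketch below lists the learning rates such a test would try, assuming a geometric (log-uniform) progression from start_lr to end_lr, which is the common choice because suitable rates span several orders of magnitude; the spacing used by holocron and the helper name `lr_range` are assumptions.

```python
from typing import List


def lr_range(start_lr: float = 1e-7, end_lr: float = 1.0, num_it: int = 100) -> List[float]:
    """Hypothetical helper: learning rates tried by a geometric LR range test."""
    if num_it < 2:
        raise ValueError("num_it must be at least 2")
    ratio = end_lr / start_lr
    # Each step multiplies the LR by the same factor, so the values are
    # evenly spaced on a log scale between start_lr and end_lr.
    return [start_lr * ratio ** (idx / (num_it - 1)) for idx in range(num_it)]
```

With the defaults start_lr=1e-7 and end_lr=1, eight iterations would try 1e-7, 1e-6, ..., 1.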

plot_recorder(beta: float = 0.95, block: bool = True) → None

Display the loss recorded during the learning rate search (lr_find) as a function of the learning rate

Parameters:
  • beta (float, optional) – smoothing factor

  • block (bool, optional) – whether the plot should block execution
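Raw losses from an LR search are noisy, hence the beta smoothing factor. A common way to apply such a factor, and the assumption made in this sketch, is an exponential moving average with bias correction, so that the first values are not pulled towards zero; holocron may smooth the curve differently, and `smooth_losses` is a hypothetical name.

```python
from typing import List


def smooth_losses(losses: List[float], beta: float = 0.95) -> List[float]:
    """Debiased exponential moving average of a loss curve (assumed smoothing)."""
    smoothed: List[float] = []
    avg = 0.0
    for idx, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        # Divide by (1 - beta**idx) to correct the bias from starting at 0
        smoothed.append(avg / (1 - beta ** idx))
    return smoothed
```

Thanks to the correction, a constant loss curve stays constant and the first smoothed value equals the first recorded loss.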

save(output_file: str) → None

Save a trainer checkpoint

Parameters:

output_file – destination file path

set_device(gpu: int | None = None) → None

Move tensor objects to the target GPU

Parameters:

gpu – index of the target GPU device

to_cuda(x: Tensor, target: Tensor | List[Dict[str, Tensor]]) → Tuple[Tensor, Tensor | List[Dict[str, Tensor]]]

Move the input and target to the GPU
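As the signature shows, the target is either a single tensor or a list of per-image dictionaries of tensors, the latter being the usual shape for detection. The sketch below shows how both cases can be dispatched; to stay runnable without a GPU, the transfer is abstracted as a `move` callable (with PyTorch it would typically be `lambda t: t.cuda(non_blocking=True)`). The helper name `move_batch` is hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple


def move_batch(x: Any, target: Any, move: Callable[[Any], Any]) -> Tuple[Any, Any]:
    """Hypothetical helper mirroring to_cuda's dispatch.

    `target` is either a single tensor-like object or a list of per-image
    dicts mapping names (e.g. "boxes", "labels") to tensor-like objects.
    """
    if isinstance(target, list):
        # Detection-style target: move every tensor of every per-image dict
        moved: List[Dict[str, Any]] = [
            {key: move(value) for key, value in sample.items()} for sample in target
        ]
        return move(x), moved
    # Classification/segmentation-style target: a single tensor
    return move(x), move(target)
```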

Image classification

class holocron.trainer.ClassificationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')

Image classification trainer class

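Parameters:

model, train_loader, val_loader, criterion, optimizer, gpu, output_file – same as for holocron.trainer.Trainer

evaluate() → Dict[str, float]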

Evaluate the model on the validation set

Returns:

evaluation metrics

Return type:

dict
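The keys of the returned dictionary are not listed here. As an illustration of the kind of metric a classification evaluation typically reports, the sketch below computes top-k accuracy from raw class scores; the function is hypothetical and independent of holocron.

```python
from typing import List, Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if not labels:
        raise ValueError("labels must not be empty")
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by decreasing score, truncated to the top k
        top_k: List[int] = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)
```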

class holocron.trainer.SegmentationTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')

Semantic segmentation trainer class

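Parameters:

model, train_loader, val_loader, criterion, optimizer, gpu, output_file – same as for holocron.trainer.Trainer

evaluate(ignore_index: int = 255) → Dict[str, float]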

Evaluate the model on the validation set

Parameters:

ignore_index (int, optional) – index of the class to ignore in evaluation

Returns:

evaluation metrics

Return type:

dict
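Pixels labeled with ignore_index (often unlabeled or boundary areas) are left out of the metrics. The sketch below shows one common way to honour it: skip those pixels while filling a confusion matrix, then derive the mean intersection-over-union. Reporting mean IoU and the name `mean_iou` are assumptions; evaluate may return other metrics.

```python
from typing import List, Sequence


def mean_iou(preds: Sequence[int], targets: Sequence[int], num_classes: int,
             ignore_index: int = 255) -> float:
    """Mean intersection-over-union over flattened pixel labels, skipping ignore_index."""
    # conf[t][p] counts pixels of true class t predicted as class p
    conf: List[List[int]] = [[0] * num_classes for _ in range(num_classes)]
    for pred, target in zip(preds, targets):
        if target == ignore_index:
            continue
        conf[target][pred] += 1
    ious: List[float] = []
    for cls in range(num_classes):
        tp = conf[cls][cls]
        fp = sum(conf[t][cls] for t in range(num_classes)) - tp
        fn = sum(conf[cls]) - tp
        union = tp + fp + fn
        if union > 0:  # classes absent from both predictions and targets are skipped
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else 0.0
```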

class holocron.trainer.DetectionTrainer(model: Module, train_loader: DataLoader, val_loader: DataLoader, criterion: Module, optimizer: Optimizer, gpu: int | None = None, output_file: str = './checkpoint.pth')

Object detection trainer class

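Parameters:

model, train_loader, val_loader, criterion, optimizer, gpu, output_file – same as for holocron.trainer.Trainer

evaluate(iou_threshold: float = 0.5) → Dict[str, float]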

Evaluate the model on the validation set

Parameters:

iou_threshold (float, optional) – IoU threshold for pair assignment

Returns:

evaluation metrics

Return type:

dict
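iou_threshold decides when a predicted box and a ground-truth box form a pair. The sketch below computes the IoU of axis-aligned boxes and performs a greedy one-to-one assignment by decreasing IoU, keeping only pairs that reach the threshold. The (xmin, ymin, xmax, ymax) box format, the greedy strategy and the function names are assumptions; holocron's evaluate may assign pairs differently.

```python
from typing import List, Sequence, Set, Tuple

Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax), assumed format


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def assign_pairs(preds: Sequence[Box], gts: Sequence[Box],
                 iou_threshold: float = 0.5) -> List[Tuple[int, int]]:
    """Greedily match (pred_idx, gt_idx) pairs by decreasing IoU, each box used at most once."""
    candidates = sorted(
        ((box_iou(p, g), i, j) for i, p in enumerate(preds) for j, g in enumerate(gts)),
        reverse=True,
    )
    used_preds: Set[int] = set()
    used_gts: Set[int] = set()
    pairs: List[Tuple[int, int]] = []
    for iou, i, j in candidates:
        if iou < iou_threshold:
            break  # remaining candidates have an even lower IoU
        if i not in used_preds and j not in used_gts:
            used_preds.add(i)
            used_gts.add(j)
            pairs.append((i, j))
    return pairs
```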

Miscellaneous

holocron.trainer.freeze_bn(mod: Module) → Module

Prevent the parameters and running statistics of frozen BatchNorm layers from being updated

Parameters:

mod (torch.nn.Module) – model to train

Returns:

model

Return type:

torch.nn.Module

holocron.trainer.freeze_model(model: Module, last_frozen_layer: str | None = None, frozen_bn_stat_update: bool = False) → Module

Freeze the layers of a model, from the first one up to and including last_frozen_layer

Parameters:
  • model (torch.nn.Module) – model to train

  • last_frozen_layer (str, optional) – last layer to freeze. Assumes layers have been registered in forward order

  • frozen_bn_stat_update (bool, optional) – whether frozen BatchNorm layers should keep updating their running statistics

Returns:

model

Return type:

torch.nn.Module
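The forward-order assumption suggests that the layers to freeze are found by walking the parameters in registration order until last_frozen_layer has been passed. The sketch below applies that selection rule to parameter names only (as yielded by torch.nn.Module.named_parameters()), so it runs without PyTorch; the helper name `frozen_param_names` and the prefix-matching rule are assumptions for illustration, not holocron's code.

```python
from typing import List, Sequence


def frozen_param_names(names: Sequence[str], last_frozen_layer: str) -> List[str]:
    """Hypothetical helper: names of the parameters frozen when freezing up to
    and including `last_frozen_layer`, given names in registration order."""
    frozen: List[str] = []
    reached = False
    for name in names:
        # Match the layer itself or its children ("layer1." does not match "layer10")
        in_layer = name == last_frozen_layer or name.startswith(last_frozen_layer + ".")
        if in_layer:
            reached = True
        elif reached:
            break  # first parameter after the last frozen layer: stop
        frozen.append(name)
    if not reached:
        raise ValueError(f"unable to locate layer {last_frozen_layer!r}")
    return frozen
```

In PyTorch, freezing the selected parameters would then amount to calling requires_grad_(False) on each of them.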