# Copyright (C) 2019-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Any, Dict

import torch

from .core import Trainer

__all__ = ["SegmentationTrainer"]

class SegmentationTrainer(Trainer):
    """Semantic segmentation trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
        num_classes (int): number of output classes
        amp (bool, optional): whether to use automatic mixed precision
        skip_nan_loss (bool, optional): whether the optimizer step should be skipped when the loss is NaN
        on_epoch_end (Callable[[Dict[str, float]], Any]): callback triggered at the end of an epoch
    """

    def __init__(self, *args: Any, num_classes: int = 10, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes

    @torch.inference_mode()
    def evaluate(self, ignore_index: int = 255) -> Dict[str, float]:
        """Evaluate the model on the validation set

        Args:
            ignore_index (int, optional): index of the class to ignore in evaluation

        Returns:
            dict: evaluation metrics
        """
        self.model.eval()

        val_loss, mean_iou, num_valid_batches = 0.0, 0.0, 0
        conf_mat = torch.zeros(
            (self.num_classes, self.num_classes),
            dtype=torch.int64,
            device=next(self.model.parameters()).device,
        )
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            _loss, out = self._get_loss(x, target, return_logits=True)

            # Safeguard for NaN loss
            if not torch.isnan(_loss) and not torch.isinf(_loss):
                val_loss += _loss.item()
                num_valid_batches += 1

            # borrowed from https://github.com/pytorch/vision/blob/master/references/segmentation/train.py
            pred = out.argmax(dim=1).flatten()
            target = target.flatten()
            # Keep only valid pixels, then encode each (target, pred) pair as a single index
            # and accumulate the confusion matrix with one bincount call
            k = (target >= 0) & (target < self.num_classes)
            inds = self.num_classes * target[k].to(torch.int64) + pred[k]
            nc = self.num_classes
            conf_mat += torch.bincount(inds, minlength=nc**2).reshape(nc, nc)

        val_loss /= num_valid_batches
        # Global pixel accuracy and mean per-class IoU, both derived from the confusion matrix
        acc_global = (torch.diag(conf_mat).sum() / conf_mat.sum()).item()
        mean_iou = (torch.diag(conf_mat) / (conf_mat.sum(1) + conf_mat.sum(0) - torch.diag(conf_mat))).mean().item()

        return dict(val_loss=val_loss, acc_global=acc_global, mean_iou=mean_iou)

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        return (
            f"Validation loss: {eval_metrics['val_loss']:.4} "
            f"(Acc: {eval_metrics['acc_global']:.2%} | Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
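

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): the snippet below mirrors
# the confusion-matrix bookkeeping used in `SegmentationTrainer.evaluate` on a
# toy prediction/target pair, so the `acc_global` and `mean_iou` formulas can
# be checked in isolation. The tensors, class count, and ignore value are made
# up for the example. In practice, the trainer would be instantiated with the
# arguments listed in the class docstring above and `evaluate()` called on a
# real validation loader.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_classes = 3

    # Toy flattened predictions and ground-truth labels (255 marks an ignored pixel)
    pred = torch.tensor([0, 1, 2, 2, 1, 0, 1, 2])
    target = torch.tensor([0, 1, 2, 1, 1, 0, 255, 2])

    # Keep only valid pixels, then count (target, pred) pairs with a single bincount
    k = (target >= 0) & (target < num_classes)
    inds = num_classes * target[k].to(torch.int64) + pred[k]
    conf_mat = torch.bincount(inds, minlength=num_classes**2).reshape(num_classes, num_classes)

    # Same metric formulas as in `evaluate`
    acc_global = (torch.diag(conf_mat).sum() / conf_mat.sum()).item()
    mean_iou = (
        torch.diag(conf_mat) / (conf_mat.sum(1) + conf_mat.sum(0) - torch.diag(conf_mat))
    ).mean().item()

    print(f"Acc: {acc_global:.2%} | Mean IoU: {mean_iou:.2%}")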