    def set_device(self, gpu: Optional[int] = None) -> None:
        """Move tensor objects to the target GPU

        Args:
            gpu: index of the target GPU device
        """
        if isinstance(gpu, int):
            if not torch.cuda.is_available():
                raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
            if gpu >= torch.cuda.device_count():
                raise ValueError("Invalid device index")
            torch.cuda.set_device(gpu)
            self.model = self.model.cuda()
            if isinstance(self.criterion, torch.nn.Module):
                self.criterion = self.criterion.cuda()

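    # Usage sketch for ``set_device`` (assumes a CUDA-capable machine and a
    # ``trainer`` instance of a concrete subclass; the index below is illustrative):
    #
    #   >>> trainer.set_device(0)  # moves the model (and the criterion, if an nn.Module) to GPU 0
    #   >>> trainer.set_device()   # passing None leaves everything on the CPU
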
    def save(self, output_file: str) -> None:
        """Save a trainer checkpoint

        Args:
            output_file: destination file path
        """
        torch.save(
            dict(
                epoch=self.epoch,
                step=self.step,
                min_loss=self.min_loss,
                optimizer=self.optimizer.state_dict(),
                model=self.model.state_dict(),
            ),
            output_file,
            _use_new_zipfile_serialization=False,
        )

    def load(self, state: Dict[str, Any]) -> None:
        """Resume from a trainer state

        Args:
            state (dict): checkpoint dictionary
        """
        self.start_epoch = state['epoch']
        self.epoch = self.start_epoch
        self.step = state['step']
        self.min_loss = state['min_loss']
        self.optimizer.load_state_dict(state['optimizer'])
        self.model.load_state_dict(state['model'])

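    # Checkpoint round-trip sketch (hypothetical file path; assumes a ``trainer`` instance):
    #
    #   >>> trainer.save('checkpoint.pth')                            # persist epoch, step, min_loss, weights
    #   >>> state = torch.load('checkpoint.pth', map_location='cpu')  # read the state dict back
    #   >>> trainer.load(state)                                       # resume training from the checkpoint
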
    def _fit_epoch(self, mb: ConsoleMasterBar) -> None:
        """Fit a single epoch

        Args:
            mb (fastprogress.master_bar): primary progress bar
        """
        self.model = freeze_bn(self.model.train())

        pb = progress_bar(self.train_loader, parent=mb)
        for x, target in pb:
            x, target = self.to_cuda(x, target)

            # Forward
            batch_loss = self._get_loss(x, target)
            # Backprop
            self._backprop_step(batch_loss)
            # Update LR
            self.scheduler.step()
            pb.comment = f"Training loss: {batch_loss.item():.4}"

            self.step += 1
        self.epoch += 1

    def to_cuda(
        self,
        x: Tensor,
        target: Union[Tensor, List[Dict[str, Tensor]]]
    ) -> Tuple[Tensor, Union[Tensor, List[Dict[str, Tensor]]]]:
        """Move input and target to GPU"""
        if isinstance(self.gpu, int):
            if self.gpu >= torch.cuda.device_count():
                raise ValueError("Invalid device index")
            return self._to_cuda(x, target)  # type: ignore[arg-type]
        else:
            return x, target

    @staticmethod
    def _to_cuda(x: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Move input and target to GPU"""
        x = x.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        return x, target

    def _backprop_step(self, loss: Tensor) -> None:
        # Clean gradients
        self.optimizer.zero_grad()
        # Backpropagate the loss
        loss.backward()
        # Update the params
        self.optimizer.step()

    def _get_loss(self, x: Tensor, target: Tensor) -> Tensor:
        # Forward
        out = self.model(x)
        # Loss computation
        return self.criterion(out, target)

    def _set_params(self) -> None:
        # Pack all trainable parameters into a contiguous memory buffer
        self._params = ContiguousParams([p for p in self.model.parameters() if p.requires_grad])

    def _reset_opt(self, lr: float) -> None:
        """Reset the target params of the optimizer"""
        self.optimizer.defaults['lr'] = lr
        self.optimizer.state = defaultdict(dict)
        self.optimizer.param_groups = []
        self._set_params()
        self.optimizer.add_param_group(dict(params=self._params.contiguous()))  # type: ignore[union-attr]

    @torch.no_grad()
    def evaluate(self):
        raise NotImplementedError

    @staticmethod
    def _eval_metrics_str(eval_metrics):
        raise NotImplementedError

    def _reset_scheduler(self, lr: float, num_epochs: int, sched_type: str = 'onecycle') -> None:
        if sched_type == 'onecycle':
            self.scheduler = OneCycleLR(self.optimizer, lr, num_epochs * len(self.train_loader))
        elif sched_type == 'cosine':
            self.scheduler = CosineAnnealingLR(self.optimizer, num_epochs * len(self.train_loader),
                                               eta_min=lr / 25e4)
        else:
            raise ValueError(f"The following scheduler type is not supported: {sched_type}")

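    # Scheduler sizing sketch: ``_reset_scheduler`` budgets one scheduler step per
    # training batch. A standalone equivalent (``optimizer``, ``train_loader`` and
    # ``num_epochs`` assumed defined elsewhere):
    #
    #   >>> from torch.optim.lr_scheduler import OneCycleLR
    #   >>> scheduler = OneCycleLR(optimizer, max_lr=3e-4, total_steps=num_epochs * len(train_loader))
    #   >>> # _fit_epoch then calls scheduler.step() once per batch
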
    def fit_n_epochs(
        self,
        num_epochs: int,
        lr: float,
        freeze_until: Optional[str] = None,
        sched_type: str = 'onecycle'
    ) -> None:
        """Train the model for a given number of epochs

        Args:
            num_epochs (int): number of epochs to train
            lr (float): learning rate to be used by the scheduler
            freeze_until (str, optional): last layer to freeze
            sched_type (str, optional): type of scheduler to use
        """
        self.model = freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(lr)
        # Scheduler
        self._reset_scheduler(lr, num_epochs, sched_type)

        mb = master_bar(range(num_epochs))
        for _ in mb:
            self._fit_epoch(mb)
            # Check whether ops invalidated the buffer
            self._params.assert_buffer_is_valid()  # type: ignore[union-attr]
            eval_metrics = self.evaluate()

            # master bar
            mb.main_bar.comment = f"Epoch {self.start_epoch + self.epoch}/{self.start_epoch + num_epochs}"
            mb.write(f"Epoch {self.start_epoch + self.epoch}/{self.start_epoch + num_epochs} - "
                     f"{self._eval_metrics_str(eval_metrics)}")

            if eval_metrics['val_loss'] < self.min_loss:
                print(f"Validation loss decreased {self.min_loss:.4} --> "
                      f"{eval_metrics['val_loss']:.4}: saving state...")
                self.min_loss = eval_metrics['val_loss']
                self.save(self.output_file)

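    # Training sketch (assumes a fully-constructed ``trainer``; hyperparameters illustrative):
    #
    #   >>> trainer.fit_n_epochs(num_epochs=10, lr=3e-4, sched_type='onecycle')
    #   >>> # whenever the validation loss improves, a checkpoint is written to trainer.output_file
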
    def lr_find(
        self,
        freeze_until: Optional[str] = None,
        start_lr: float = 1e-7,
        end_lr: float = 1,
        num_it: int = 100
    ) -> None:
        """Grid search for the optimal learning rate for the training

        Args:
            freeze_until (str, optional): last layer to freeze
            start_lr (float, optional): initial learning rate
            end_lr (float, optional): final learning rate
            num_it (int, optional): number of iterations to perform
        """
        self.model = freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(start_lr)
        gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
        scheduler = MultiplicativeLR(self.optimizer, lambda step: gamma)

        self.lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)]
        self.loss_recorder = []

        for batch_idx, (x, target) in enumerate(self.train_loader):
            x, target = self.to_cuda(x, target)

            # Forward
            batch_loss = self._get_loss(x, target)
            self._backprop_step(batch_loss)
            # Update LR
            scheduler.step()

            # Record
            self.loss_recorder.append(batch_loss.item())
            # Stop after the number of iterations
            if batch_idx + 1 == num_it:
                break

    def plot_recorder(self, beta: float = 0.95, block: bool = True) -> None:
        """Display the results of the LR grid search

        Args:
            beta (float, optional): smoothing factor
            block (bool, optional): whether the plot should block execution
        """
        if len(self.lr_recorder) != len(self.loss_recorder) or len(self.lr_recorder) == 0:
            raise AssertionError("Please run the `lr_find` method first")

        # Exp moving average of loss
        smoothed_losses = []
        avg_loss = 0.
        for idx, loss in enumerate(self.loss_recorder):
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1)))

        plt.plot(self.lr_recorder[10:-5], smoothed_losses[10:-5])
        plt.xscale('log')
        plt.xlabel('Learning Rate')
        plt.ylabel('Training loss')
        plt.grid(True, linestyle='--', axis='x')
        plt.show(block=block)

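    # LR search sketch: run the grid search, then inspect the smoothed loss curve
    # (assumes a ``trainer`` instance and a working matplotlib backend). Note that
    # the plot drops the first 10 and last 5 points, so ``num_it`` should
    # comfortably exceed 15:
    #
    #   >>> trainer.lr_find(start_lr=1e-7, end_lr=1, num_it=100)
    #   >>> trainer.plot_recorder()  # pick an LR slightly below the curve's minimum
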
    def check_setup(self, freeze_until: Optional[str] = None, lr: float = 3e-4, num_it: int = 100) -> bool:
        """Check whether you can overfit one batch

        Args:
            freeze_until (str, optional): last layer to freeze
            lr (float, optional): learning rate to be used for training
            num_it (int, optional): number of iterations to perform

        Returns:
            bool: True if the loss decreased at every iteration
        """
        self.model = freeze_model(self.model.train(), freeze_until)
        # Update param groups & LR
        self._reset_opt(lr)

        prev_loss = math.inf

        x, target = next(iter(self.train_loader))
        x, target = self.to_cuda(x, target)

        for _ in range(num_it):
            # Forward
            batch_loss = self._get_loss(x, target)
            # Backprop
            self._backprop_step(batch_loss)

            # Check that loss decreases
            if batch_loss.item() > prev_loss:
                return False
            prev_loss = batch_loss.item()

        return True

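    # Sanity-check sketch: verify that the setup can overfit a single batch
    # before launching a full run (assumes a ``trainer`` instance):
    #
    #   >>> if not trainer.check_setup(lr=3e-4, num_it=100):
    #   ...     print("Loss increased while overfitting one batch: check data, model and LR")
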
class ClassificationTrainer(Trainer):
    """Image classification trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
    """

    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set

        Returns:
            dict: evaluation metrics
        """
        self.model.eval()

        val_loss, top1, top5, num_samples = 0., 0, 0, 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            # Forward
            out = self.model(x)
            # Loss computation
            val_loss += self.criterion(out, target).item()

            pred = out.topk(5, dim=1)[1]
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            top1 += correct[:, 0].sum().item()
            top5 += correct.any(dim=1).sum().item()

            num_samples += x.shape[0]

        val_loss /= len(self.val_loader)

        return dict(val_loss=val_loss, acc1=top1 / num_samples, acc5=top5 / num_samples)

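# End-to-end classification sketch. The model comes from torchvision; the two
# loaders are hypothetical stand-ins for your own ``DataLoader`` objects:
#
#   >>> import torch
#   >>> from torchvision.models import resnet18
#   >>> model = resnet18(num_classes=10)
#   >>> criterion = torch.nn.CrossEntropyLoss()
#   >>> optimizer = torch.optim.Adam(model.parameters())
#   >>> trainer = ClassificationTrainer(model, train_loader, val_loader, criterion,
#   ...                                 optimizer, gpu=0, output_file='model.pth')
#   >>> trainer.fit_n_epochs(num_epochs=10, lr=3e-4)
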
class SegmentationTrainer(Trainer):
    """Semantic segmentation trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (torch.nn.Module): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
    """

    @torch.no_grad()
    def evaluate(self, ignore_index: int = 255) -> Dict[str, float]:
        """Evaluate the model on the validation set

        Args:
            ignore_index (int, optional): index of the class to ignore in evaluation

        Returns:
            dict: evaluation metrics
        """
        self.model.eval()

        val_loss, mean_iou = 0., 0.
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            # Forward
            out = self.model(x)
            # Loss computation
            val_loss += self.criterion(out, target).item()

            pred = out.argmax(dim=1)
            tmp_iou, num_seg = 0, 0
            for class_idx in torch.unique(target):
                if class_idx != ignore_index:
                    inter = (pred[target == class_idx] == class_idx).sum().item()
                    tmp_iou += inter / ((pred == class_idx) | (target == class_idx)).sum().item()
                    num_seg += 1
            mean_iou += tmp_iou / num_seg

        val_loss /= len(self.val_loader)
        mean_iou /= len(self.val_loader)

        return dict(val_loss=val_loss, mean_iou=mean_iou)

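# Per-class IoU sketch, mirroring the computation in ``evaluate`` above on toy
# tensors (standalone; values illustrative):
#
#   >>> import torch
#   >>> pred = torch.tensor([[0, 1], [1, 1]])
#   >>> target = torch.tensor([[0, 1], [0, 1]])
#   >>> inter = (pred[target == 1] == 1).sum().item()       # 2 pixels correctly labelled 1
#   >>> union = ((pred == 1) | (target == 1)).sum().item()  # 3 pixels labelled 1 in either map
#   >>> inter / union                                       # IoU for class 1: ~0.67
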
def assign_iou(
    gt_boxes: Tensor,
    pred_boxes: Tensor,
    iou_threshold: float = 0.5
) -> Tuple[List[int], List[int]]:
    """Assigns boxes by IoU"""
    iou = box_iou(gt_boxes, pred_boxes)
    iou = iou.max(dim=1)
    gt_kept = iou.values >= iou_threshold
    assign_unique = torch.unique(iou.indices[gt_kept])
    # Filter
    if iou.indices[gt_kept].shape[0] == assign_unique.shape[0]:
        return torch.arange(gt_boxes.shape[0])[gt_kept], iou.indices[gt_kept]  # type: ignore[return-value]
    else:
        gt_indices, pred_indices = [], []
        for pred_idx in assign_unique:
            selection = iou.values[gt_kept][iou.indices[gt_kept] == pred_idx].argmax()
            gt_indices.append(torch.arange(gt_boxes.shape[0])[gt_kept][selection].item())
            pred_indices.append(iou.indices[gt_kept][selection].item())
        return gt_indices, pred_indices  # type: ignore[return-value]

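# ``assign_iou`` sketch on toy boxes (xyxy format, as expected by torchvision's
# ``box_iou``; coordinates illustrative):
#
#   >>> import torch
#   >>> gt = torch.tensor([[0., 0., 10., 10.], [20., 20., 30., 30.]])
#   >>> preds = torch.tensor([[1., 1., 10., 10.], [100., 100., 110., 110.]])
#   >>> assign_iou(gt, preds, iou_threshold=0.5)
#   >>> # only the first GT box is assigned (to prediction 0); the second GT box
#   >>> # has no prediction with IoU >= 0.5, and prediction 1 matches nothing
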
class DetectionTrainer(Trainer):
    """Object detection trainer class

    Args:
        model (torch.nn.Module): model to train
        train_loader (torch.utils.data.DataLoader): training loader
        val_loader (torch.utils.data.DataLoader): validation loader
        criterion (None): loss criterion
        optimizer (torch.optim.Optimizer): parameter optimizer
        gpu (int, optional): index of the GPU to use
        output_file (str, optional): path where checkpoints will be saved
    """

    @staticmethod
    def _to_cuda(  # type: ignore[override]
        x: List[Tensor],
        target: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Dict[str, Tensor]]]:
        """Move input and target to GPU"""
        x = [_x.cuda(non_blocking=True) for _x in x]
        target = [{k: v.cuda(non_blocking=True) for k, v in t.items()} for t in target]
        return x, target

    def _backprop_step(self, loss: Tensor, grad_clip: float = .1) -> None:
        # Clean gradients
        self.optimizer.zero_grad()
        # Backpropagate the loss
        loss.backward()
        # Safeguard against gradient explosion
        if isinstance(grad_clip, float):
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
        # Update the params
        self.optimizer.step()

    def _get_loss(self, x: List[Tensor], target: List[Dict[str, Tensor]]) -> Tensor:  # type: ignore[override]
        # Forward & loss computation
        loss_dict = self.model(x, target)
        return sum(loss_dict.values())  # type: ignore[return-value]

    @staticmethod
    def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str:
        return (f"Loc error: {eval_metrics['loc_err']:.2%} | Clf error: {eval_metrics['clf_err']:.2%} | "
                f"Det error: {eval_metrics['det_err']:.2%}")

    @torch.no_grad()
    def evaluate(self, iou_threshold: float = 0.5) -> Dict[str, float]:
        """Evaluate the model on the validation set

        Args:
            iou_threshold (float, optional): IoU threshold for pair assignment

        Returns:
            dict: evaluation metrics
        """
        self.model.eval()

        loc_assigns = 0
        correct, clf_error, loc_fn, loc_fp, num_samples = 0, 0, 0, 0, 0
        for x, target in self.val_loader:
            x, target = self.to_cuda(x, target)

            detections = self.model(x)

            for dets, t in zip(detections, target):
                if t['boxes'].shape[0] > 0 and dets['boxes'].shape[0] > 0:
                    gt_indices, pred_indices = assign_iou(t['boxes'], dets['boxes'], iou_threshold)
                    loc_assigns += len(gt_indices)
                    _correct = (t['labels'][gt_indices] == dets['labels'][pred_indices]).sum().item()
                else:
                    gt_indices, pred_indices = [], []
                    _correct = 0
                correct += _correct
                clf_error += len(gt_indices) - _correct
                loc_fn += t['boxes'].shape[0] - len(gt_indices)
                loc_fp += dets['boxes'].shape[0] - len(pred_indices)

            num_samples += sum(t['boxes'].shape[0] for t in target)

        nb_preds = num_samples - loc_fn + loc_fp
        # Localization
        loc_err = 1 - 2 * loc_assigns / (nb_preds + num_samples) if nb_preds + num_samples > 0 else 1.
        # Classification
        clf_err = 1 - correct / loc_assigns if loc_assigns > 0 else 1.
        # End-to-end
        det_err = 1 - 2 * correct / (nb_preds + num_samples) if nb_preds + num_samples > 0 else 1.
        return dict(loc_err=loc_err, clf_err=clf_err, det_err=det_err, val_loss=loc_err)

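# Detection sketch (assumes a torchvision-style detection model such as Faster
# R-CNN, which returns a loss dict in train mode and a list of
# {'boxes', 'labels', 'scores'} dicts in eval mode; names illustrative):
#
#   >>> trainer = DetectionTrainer(model, train_loader, val_loader, None,
#   ...                            optimizer, gpu=0, output_file='detector.pth')
#   >>> metrics = trainer.evaluate(iou_threshold=0.5)
#   >>> print(trainer._eval_metrics_str(metrics))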