# Copyright (C) 2019-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Tuple

import numpy as np
import torch
from torch import Tensor
from torch.nn.functional import one_hot

__all__ = ['Mixup']
class Mixup(torch.nn.Module):
    """Implements a batch collate function with MixUp strategy from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_

    Args:
        num_classes: number of expected classes
        alpha: mixup factor

    Example::
        >>> import torch
        >>> from torch.utils.data._utils.collate import default_collate
        >>> from holocron.utils.data import Mixup
        >>> mix = Mixup(num_classes=10, alpha=0.4)
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size, collate_fn=lambda b: mix(*default_collate(b)))
    """

    def __init__(self, num_classes: int, alpha: float = 0.2) -> None:
        super().__init__()
        self.num_classes = num_classes
        if alpha < 0:
            raise ValueError("`alpha` only takes non-negative values")
        self.alpha = alpha

    def forward(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        # Convert targets to one-hot encoding so they can be linearly interpolated
        targets = one_hot(targets, num_classes=self.num_classes).to(dtype=inputs.dtype)

        # Sample the mixing coefficient lambda from Beta(alpha, alpha)
        if self.alpha == 0:
            return inputs, targets
        lam = np.random.beta(self.alpha, self.alpha)

        # Shuffle the batch to pick the samples to mix with
        batch_size = inputs.size()[0]
        index = torch.randperm(batch_size)

        # Interpolate inputs and targets in-place: x = lam * x + (1 - lam) * x[perm]
        mixed_input, mixed_target = inputs[index, :], targets[index]
        mixed_input.mul_(1 - lam)
        inputs.mul_(lam).add_(mixed_input)
        mixed_target.mul_(1 - lam)
        targets.mul_(lam).add_(mixed_target)

        return inputs, targets
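

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal standalone check of
# the mixing behaviour above, using a synthetic batch of random images and
# integer labels. The shapes, class count, and alpha value are assumptions
# chosen for demonstration only.
if __name__ == "__main__":
    mix = Mixup(num_classes=10, alpha=0.4)

    # Fake batch: 4 RGB images of size 32x32 with integer class labels
    images = torch.rand(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4,))

    mixed_images, mixed_targets = mix(images, labels)
    # Inputs keep their shape; targets become soft one-hot vectors
    print(mixed_images.shape)   # torch.Size([4, 3, 32, 32])
    print(mixed_targets.shape)  # torch.Size([4, 10])
    # Each soft target row is a convex combination of two one-hot vectors,
    # so it still sums to 1
    print(mixed_targets.sum(dim=1))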