Source code for holocron.models.classification.repvgg
# Copyright (C) 2021-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn

from holocron.nn import GlobalAvgPool2d

from ..presets import IMAGENETTE
from ..utils import fuse_conv_bn, load_pretrained_params

__all__ = [
    'RepVGG',
    'RepBlock',
    'repvgg_a0',
    'repvgg_a1',
    'repvgg_a2',
    'repvgg_b0',
    'repvgg_b1',
    'repvgg_b2',
    'repvgg_b3',
]

# Per-architecture configuration: the Imagenette preset, the expected input shape and the checkpoint URL
# (None when no checkpoint has been released for that architecture)
default_cfgs: Dict[str, Dict[str, Any]] = {
    'repvgg_a0': {
        **IMAGENETTE,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_a0_224-150f4b9d.pt',
    },
    'repvgg_a1': {
        **IMAGENETTE,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_a1_224-870b9e4b.pt',
    },
    'repvgg_a2': {
        **IMAGENETTE,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_a2_224-7051289a.pt',
    },
    'repvgg_b0': {
        **IMAGENETTE,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_b0_224-7e9c3fc7.pth',
    },
    'repvgg_b1': {
        **IMAGENETTE,
        'input_shape': (3, 224, 224),
        'url': None,
    },
    'repvgg_b2': {
        **IMAGENETTE,
        'input_shape': (3, 224, 224),
        'url': None,
    },
    'repvgg_b3': {
        **IMAGENETTE,
        'input_shape': (3, 224, 224),
        'url': None,
    },
}


class RepBlock(nn.Module):
    """Multi-branch RepVGG block that can be fused into a single 3x3 convolution.

    At training time, the block sums a (3x3 conv + norm) branch, a (1x1 conv + norm) branch and,
    optionally, an identity branch made of a single ``nn.BatchNorm2d``, then applies the activation.
    Calling :meth:`reparametrize` replaces all branches with one equivalent 3x3 ``nn.Conv2d``.

    Args:
        inplanes: number of input channels
        planes: number of output channels
        stride: stride of both the 3x3 and 1x1 convolutions
        identity: whether to add the identity (BatchNorm-only) branch, which requires ``inplanes == planes``
        act_layer: the activation module applied to the sum of branches (defaults to ``nn.ReLU(inplace=True)``)
        norm_layer: the normalization layer constructor used after each convolution (defaults to ``nn.BatchNorm2d``)

    Raises:
        ValueError: if ``identity`` is True and ``inplanes != planes``
    """

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        identity: bool = True,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
    ) -> None:
        super().__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if act_layer is None:
            act_layer = nn.ReLU(inplace=True)

        # Each convolution is always followed by a normalization layer, so it carries no bias of its own
        self.branches: Union[nn.Conv2d, nn.ModuleList] = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(inplanes, planes, 3, padding=1, bias=False, stride=stride),
                    norm_layer(planes),
                ),
                nn.Sequential(
                    nn.Conv2d(inplanes, planes, 1, padding=0, bias=False, stride=stride),
                    norm_layer(planes),
                ),
            ]
        )
        self.activation = act_layer

        if identity:
            if inplanes != planes:
                raise ValueError("The number of input and output channels must be identical if identity is used")
            # Kept as a BatchNorm2d regardless of `norm_layer`: `reparametrize` reads its running statistics
            self.branches.append(nn.BatchNorm2d(planes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Once reparametrized, `branches` is a single fused Conv2d
        if isinstance(self.branches, nn.Conv2d):
            out = self.branches(x)
        else:
            out = sum(branch(x) for branch in self.branches)
        return self.activation(out)

    def reparametrize(self) -> None:
        """Reparametrize the block by fusing convolutions and BN in each branch, then fusing all branches

        After this call, ``self.branches`` is a single 3x3 ``nn.Conv2d`` (with bias) equivalent to the
        sum of the former branches.

        Raises:
            AssertionError: if the block has already been reparametrized
        """
        if not isinstance(self.branches, nn.ModuleList):
            raise AssertionError("the block has already been reparametrized")
        inplanes = self.branches[0][0].weight.data.shape[1]
        planes = self.branches[0][0].weight.data.shape[0]
        # Instantiate the equivalent Conv 3x3
        rep = nn.Conv2d(inplanes, planes, 3, padding=1, bias=True, stride=self.branches[0][0].stride)

        # Fuse convolutions with their BN
        fused_k3, fused_b3 = fuse_conv_bn(*self.branches[0])
        fused_k1, fused_b1 = fuse_conv_bn(*self.branches[1])

        # Conv 3x3
        rep.weight.data = fused_k3
        rep.bias.data = fused_b3  # type: ignore[union-attr]

        # Conv 1x1 (padding 0) is equivalent to a 3x3 kernel (padding 1) whose only non-zero tap is the center
        rep.weight.data[..., 1:2, 1:2] += fused_k1
        rep.bias.data += fused_b1  # type: ignore[union-attr]

        # Identity (BatchNorm only): y = scale * (x - running_mean) + bias, with scale = weight / sqrt(var + eps)
        if len(self.branches) == 3:
            bn = self.branches[2]
            scale_factor = bn.weight.data / (bn.running_var + bn.eps).sqrt()
            # Identity is mapped as a diagonal matrix relatively to the out/in channel dimensions
            rep.weight.data[range(planes), range(inplanes), 1, 1] += scale_factor
            rep.bias.data += bn.bias.data  # type: ignore[union-attr]
            rep.bias.data -= scale_factor * bn.running_mean  # type: ignore[union-attr]

        # Update main branch & delete the others
        self.branches = rep


class RepVGG(nn.Sequential):
    """Implements a reparametrized version of VGG as described in
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    The model is made of ``features`` (one ``nn.Sequential`` per stage), a global average ``pool`` and a
    linear ``head``. Each stage starts with a stride-2 :class:`RepBlock` without identity branch, followed by
    ``num_blocks[i]`` stride-1 :class:`RepBlock` with an identity branch (i.e. ``num_blocks[i] + 1`` blocks).

    Args:
        num_blocks: number of stride-1 (identity) blocks appended after the downsampling block of each stage
        planes: base number of output channels of each stage, before width multipliers (at least 2 stages)
        width_multiplier: multiplier for the output channels of all stages apart from the last
            (the first stage uses ``min(1, width_multiplier)``)
        final_width_multiplier: multiplier for the output channels of the last stage
        num_classes: number of output classes
        in_channels: number of input channels
        act_layer: the activation layer to use (a single instance shared by all blocks)
        norm_layer: the normalization layer to use

    Raises:
        AssertionError: if ``num_blocks`` and ``planes`` differ in length, or fewer than 2 stages are given
    """

    def __init__(
        self,
        num_blocks: List[int],
        planes: List[int],
        width_multiplier: float,
        final_width_multiplier: float,
        num_classes: int = 10,
        in_channels: int = 3,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
    ) -> None:
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if act_layer is None:
            act_layer = nn.ReLU(inplace=True)

        if len(num_blocks) != len(planes):
            raise AssertionError("the length of `num_blocks` and `planes` are expected to be the same")
        # The first and last stages get their own multipliers: with a single stage, the head input width
        # (last-stage multiplier) would not match the channels actually produced (first-stage multiplier)
        if len(planes) < 2:
            raise AssertionError("at least 2 stages are expected in `planes`")

        _stages: List[nn.Sequential] = []
        # Assign the width multipliers
        chans = [in_channels, int(min(1, width_multiplier) * planes[0])]
        chans.extend([int(width_multiplier * chan) for chan in planes[1:-1]])
        chans.append(int(final_width_multiplier * planes[-1]))

        # Build the layers
        # NOTE(review): the paper counts the downsampling layer in each stage's depth; here it is added on top
        # of `num_blocks`. Changing this would alter the module layout (and state dict keys) of released models.
        for nb_blocks, in_chan, out_chan in zip(num_blocks, chans[:-1], chans[1:]):
            _layers = [RepBlock(in_chan, out_chan, 2, False, act_layer, norm_layer)]
            _layers.extend([RepBlock(out_chan, out_chan, 1, True, act_layer, norm_layer) for _ in range(nb_blocks)])
            _stages.append(nn.Sequential(*_layers))

        super().__init__(
            OrderedDict(
                [
                    ('features', nn.Sequential(*_stages)),
                    ('pool', GlobalAvgPool2d(flatten=True)),
                    ('head', nn.Linear(chans[-1], num_classes)),
                ]
            )
        )

    def reparametrize(self) -> None:
        """Reparametrize every block of the feature extractor in place, fusing each block's branches into a
        single 3x3 convolution"""
        self.features: nn.Sequential
        for stage in self.features:
            for block in stage:
                block.reparametrize()


def _repvgg(
    arch: str,
    pretrained: bool,
    progress: bool,
    num_blocks: List[int],
    out_chans: List[int],
    a: float,
    b: float,
    **kwargs: Any,
) -> RepVGG:
    """Build a RepVGG model and optionally load its pretrained parameters

    Args:
        arch: key of the architecture in ``default_cfgs``
        pretrained: whether to load the parameters from the URL registered for ``arch``
        progress: whether to display a progress bar while downloading
        num_blocks: number of identity blocks appended after the downsampling block of each stage
        out_chans: base output channels of each stage (before width multipliers)
        a: width multiplier of all stages but the last
        b: width multiplier of the last stage
        kwargs: keyword arguments forwarded to :class:`RepVGG`

    Returns:
        the RepVGG model, with its ``default_cfg`` attribute set
    """
    # Build the model
    model = RepVGG(num_blocks, out_chans, a, b, **kwargs)
    model.default_cfg = default_cfgs[arch]  # type: ignore[assignment]
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, default_cfgs[arch]['url'], progress)

    return model
def repvgg_a0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-A0 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Built with ``num_blocks=[1, 2, 4, 14, 1]``, a width multiplier of 0.75 for all stages but the last,
    and 2.5 for the last stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Imagenette
            (checkpoint registered in ``default_cfgs['repvgg_a0']``)
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword arguments forwarded to the ``RepVGG`` constructor (e.g. ``num_classes``)

    Returns:
        torch.nn.Module: classification model
    """
    return _repvgg('repvgg_a0', pretrained, progress, [1, 2, 4, 14, 1], [64, 64, 128, 256, 512], .75, 2.5, **kwargs)
def repvgg_a1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-A1 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Built with ``num_blocks=[1, 2, 4, 14, 1]``, a width multiplier of 1 for all stages but the last,
    and 2.5 for the last stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Imagenette
            (checkpoint registered in ``default_cfgs['repvgg_a1']``)
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword arguments forwarded to the ``RepVGG`` constructor (e.g. ``num_classes``)

    Returns:
        torch.nn.Module: classification model
    """
    return _repvgg('repvgg_a1', pretrained, progress, [1, 2, 4, 14, 1], [64, 64, 128, 256, 512], 1, 2.5, **kwargs)
def repvgg_a2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-A2 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Built with ``num_blocks=[1, 2, 4, 14, 1]``, a width multiplier of 1.5 for all stages but the last,
    and 2.75 for the last stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Imagenette
            (checkpoint registered in ``default_cfgs['repvgg_a2']``)
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword arguments forwarded to the ``RepVGG`` constructor (e.g. ``num_classes``)

    Returns:
        torch.nn.Module: classification model
    """
    return _repvgg('repvgg_a2', pretrained, progress, [1, 2, 4, 14, 1], [64, 64, 128, 256, 512], 1.5, 2.75, **kwargs)
def repvgg_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B0 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Built with ``num_blocks=[1, 4, 6, 16, 1]``, a width multiplier of 1 for all stages but the last,
    and 2.5 for the last stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Imagenette
            (checkpoint registered in ``default_cfgs['repvgg_b0']``)
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword arguments forwarded to the ``RepVGG`` constructor (e.g. ``num_classes``)

    Returns:
        torch.nn.Module: classification model
    """
    return _repvgg('repvgg_b0', pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 1, 2.5, **kwargs)
def repvgg_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B1 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Built with ``num_blocks=[1, 4, 6, 16, 1]``, a width multiplier of 2 for all stages but the last,
    and 4 for the last stage.

    Args:
        pretrained (bool): If True, attempts to load pretrained parameters. NOTE: no checkpoint URL is
            registered for this architecture (``default_cfgs['repvgg_b1']['url']`` is None), so what happens
            then depends on ``load_pretrained_params``.
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword arguments forwarded to the ``RepVGG`` constructor (e.g. ``num_classes``)

    Returns:
        torch.nn.Module: classification model
    """
    return _repvgg('repvgg_b1', pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 2, 4, **kwargs)
def repvgg_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B2 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Built with ``num_blocks=[1, 4, 6, 16, 1]``, a width multiplier of 2.5 for all stages but the last,
    and 5 for the last stage.

    Args:
        pretrained (bool): If True, attempts to load pretrained parameters. NOTE: no checkpoint URL is
            registered for this architecture (``default_cfgs['repvgg_b2']['url']`` is None), so what happens
            then depends on ``load_pretrained_params``.
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword arguments forwarded to the ``RepVGG`` constructor (e.g. ``num_classes``)

    Returns:
        torch.nn.Module: classification model
    """
    return _repvgg('repvgg_b2', pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 2.5, 5, **kwargs)
def repvgg_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B3 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Built with ``num_blocks=[1, 4, 6, 16, 1]``, a width multiplier of 3 for all stages but the last,
    and 5 for the last stage.

    Args:
        pretrained (bool): If True, attempts to load pretrained parameters. NOTE: no checkpoint URL is
            registered for this architecture (``default_cfgs['repvgg_b3']['url']`` is None), so what happens
            then depends on ``load_pretrained_params``.
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword arguments forwarded to the ``RepVGG`` constructor (e.g. ``num_classes``)

    Returns:
        torch.nn.Module: classification model
    """
    return _repvgg('repvgg_b3', pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 3, 5, **kwargs)