Source code for holocron.models.classification.rexnet
# Copyright (C) 2019-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from collections import OrderedDict
from math import ceil
from typing import Any, Callable, Dict, Optional

import torch.nn as nn

from holocron.nn import GlobalAvgPool2d, init

from ..presets import IMAGENET
from ..utils import conv_sequence, load_pretrained_params

__all__ = ['SEBlock', 'ReXBlock', 'ReXNet', 'rexnet1_0x', 'rexnet1_3x', 'rexnet1_5x', 'rexnet2_0x', 'rexnet2_2x']


# Per-architecture configuration: the IMAGENET preset entries, the expected input shape and
# the checkpoint URL (``None`` when no checkpoint is registered for the architecture).
default_cfgs: Dict[str, Dict[str, Any]] = {
    'rexnet1_0x': {
        **IMAGENET,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_0x_224-ab7b9733.pth',
    },
    'rexnet1_3x': {
        **IMAGENET,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_3x_224-95479104.pth',
    },
    'rexnet1_5x': {
        **IMAGENET,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet1_5x_224-c42a16ac.pth',
    },
    'rexnet2_0x': {
        **IMAGENET,
        'input_shape': (3, 224, 224),
        'url': 'https://github.com/frgfm/Holocron/releases/download/v0.1.2/rexnet2_0x_224-c8802402.pth',
    },
    'rexnet2_2x': {
        **IMAGENET,
        'input_shape': (3, 224, 224),
        'url': None,
    },
}


class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescales the input by a pooled, sigmoid-activated channel descriptor.

    Args:
        channels: number of channels of the input (and of the produced gate)
        se_ratio: divisor applied to ``channels`` (integer division) to get the bottleneck width
        act_layer: activation handed to the first ``conv_sequence`` call
        norm_layer: normalization factory handed to the first ``conv_sequence`` call; when it is
            ``None`` the first convolution is requested with a bias
        drop_layer: handed to both ``conv_sequence`` calls
    """

    def __init__(
        self,
        channels: int,
        se_ratio: int = 12,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        squeezed_channels = channels // se_ratio
        self.pool = GlobalAvgPool2d(flatten=False)
        self.conv = nn.Sequential(
            # Reduction: channels -> channels // se_ratio with a 1x1 convolution
            *conv_sequence(
                channels, squeezed_channels, act_layer, norm_layer, drop_layer,
                kernel_size=1, stride=1, bias=(norm_layer is None),
            ),
            # Expansion back to `channels`, without normalization and with a sigmoid activation
            *conv_sequence(
                squeezed_channels, channels, nn.Sigmoid(), None, drop_layer,
                kernel_size=1, stride=1,
            ),
        )

    def forward(self, x):
        """Return ``x`` multiplied by the gate computed from its pooled version."""
        gate = self.conv(self.pool(x))
        return x * gate


class ReXBlock(nn.Module):
    """Building block of ReXNet: optional 1x1 expansion, depthwise 3x3 convolution, optional
    :class:`SEBlock`, activation and 1x1 projection.

    Args:
        in_channels: number of input channels
        channels: number of output channels
        t: expansion factor; the 1x1 expansion is skipped when it equals 1
        stride: stride of the depthwise convolution
        use_se: whether an :class:`SEBlock` is inserted after the depthwise convolution
        se_ratio: reduction ratio passed to the :class:`SEBlock`
        act_layer: activation placed before the projection and passed to the :class:`SEBlock`
            (defaults to ``nn.ReLU6(inplace=True)``)
        norm_layer: normalization factory (defaults to ``nn.BatchNorm2d``)
        drop_layer: handed to every ``conv_sequence`` call and to the :class:`SEBlock`
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        t: int,
        stride: int,
        use_se: bool = True,
        se_ratio: int = 12,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        if act_layer is None:
            act_layer = nn.ReLU6(inplace=True)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # The partial residual connection is only used when the resolution is preserved and
        # the output has at least as many channels as the input
        self.use_shortcut = stride == 1 and in_channels <= channels
        self.in_channels = in_channels
        self.out_channels = channels

        _layers = []
        if t != 1:
            # Pointwise expansion (always followed by SiLU, independently of `act_layer`)
            dw_channels = in_channels * t
            _layers.extend(
                conv_sequence(
                    in_channels, dw_channels, nn.SiLU(inplace=True), norm_layer, drop_layer,
                    kernel_size=1, stride=1, bias=(norm_layer is None),
                )
            )
        else:
            dw_channels = in_channels

        # Depthwise convolution (one group per channel), no activation
        _layers.extend(
            conv_sequence(
                dw_channels, dw_channels, None, norm_layer, drop_layer,
                kernel_size=3, stride=stride, padding=1, bias=(norm_layer is None), groups=dw_channels,
            )
        )

        if use_se:
            _layers.append(SEBlock(dw_channels, se_ratio, act_layer, norm_layer, drop_layer))

        _layers.append(act_layer)
        # Pointwise projection to the output width, no activation
        _layers.extend(
            conv_sequence(
                dw_channels, channels, None, norm_layer, drop_layer,
                kernel_size=1, stride=1, bias=(norm_layer is None),
            )
        )
        self.conv = nn.Sequential(*_layers)

    def forward(self, x):
        """Apply the block; with the shortcut enabled, ``x`` is added in place to the first
        ``in_channels`` channels of the output."""
        out = self.conv(x)
        if self.use_shortcut:
            out[:, :self.in_channels] += x

        return out


class ReXNet(nn.Sequential):
    """ReXNet classification model, organised as ``features`` -> ``pool`` -> ``head``.

    Mostly adapted from https://github.com/clovaai/rexnet/blob/master/rexnetv1.py

    Args:
        width_mult: multiplier applied to the channel widths
        depth_mult: multiplier applied to the number of blocks of each stage (rounded up)
        num_classes: number of output units of the final linear layer
        in_channels: number of channels of the input
        in_planes: base width of the first block
        final_planes: total width increase spread linearly over all blocks
        use_se: whether blocks from the third stage onwards use an :class:`SEBlock`
        se_ratio: reduction ratio of the :class:`SEBlock` modules
        dropout_ratio: dropout probability applied before the final linear layer
        bn_momentum: accepted for interface compatibility; not read by this constructor
        act_layer: activation of the stem and of the penultimate convolution (defaults to
            ``nn.SiLU(inplace=True)``); it is not forwarded to the :class:`ReXBlock` modules
        norm_layer: normalization factory of the stem and of the penultimate convolution
            (defaults to ``nn.BatchNorm2d``); it is not forwarded to the :class:`ReXBlock` modules
        drop_layer: handed to the stem and penultimate ``conv_sequence`` calls; it is not
            forwarded to the :class:`ReXBlock` modules
    """

    def __init__(
        self,
        width_mult: float = 1.0,
        depth_mult: float = 1.0,
        num_classes: int = 1000,
        in_channels: int = 3,
        in_planes: int = 16,
        final_planes: int = 180,
        use_se: bool = True,
        se_ratio: int = 12,
        dropout_ratio: float = 0.2,
        bn_momentum: float = 0.9,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        # NOTE: the parent constructor is called once, at the end, when all layers are built:
        # nothing below touches `self` before that point.
        if act_layer is None:
            act_layer = nn.SiLU(inplace=True)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # Number of blocks per stage, and stride of the first block of each stage
        num_blocks = [ceil(element * depth_mult) for element in [1, 2, 2, 3, 3, 5]]
        stage_strides = [1, 2, 2, 2, 1, 2]
        # Per-block strides: only the first block of a stage uses the stage stride
        strides = []
        for stage_stride, stage_depth in zip(stage_strides, num_blocks):
            strides.extend([stage_stride] + [1] * (stage_depth - 1))
        depth = sum(num_blocks)

        # For width multipliers below 1, the base widths are divided by the multiplier so that
        # the multiplication below brings them back to their nominal value
        stem_channel = 32 / width_mult if width_mult < 1.0 else 32
        inplanes = in_planes / width_mult if width_mult < 1.0 else in_planes

        # The following channel configuration is a simple instance to make each layer become an expand layer
        chans = [int(round(width_mult * stem_channel))]
        chans.extend([int(round(width_mult * (inplanes + idx * final_planes / depth))) for idx in range(depth)])

        # No SE in the first two stages
        ses = [False] * (num_blocks[0] + num_blocks[1]) + [use_se] * sum(num_blocks[2:])

        # Stem
        _layers = conv_sequence(
            in_channels, chans[0], act_layer, norm_layer, drop_layer,
            kernel_size=3, stride=2, padding=1, bias=(norm_layer is None),
        )

        # Only the very first block uses an expansion factor of 1, all the others use 6
        t = 1
        for in_c, c, s, se in zip(chans[:-1], chans[1:], strides, ses):
            _layers.append(ReXBlock(in_channels=in_c, channels=c, t=t, stride=s, use_se=se, se_ratio=se_ratio))
            t = 6

        # Penultimate pointwise convolution
        pen_channels = int(width_mult * 1280)
        _layers.extend(
            conv_sequence(
                chans[-1], pen_channels, act_layer, norm_layer, drop_layer,
                kernel_size=1, stride=1, padding=0, bias=(norm_layer is None),
            )
        )

        super().__init__(
            OrderedDict([
                ('features', nn.Sequential(*_layers)),
                ('pool', GlobalAvgPool2d(flatten=True)),
                ('head', nn.Sequential(nn.Dropout(dropout_ratio), nn.Linear(pen_channels, num_classes))),
            ])
        )

        # Init all layers
        init.init_module(self, nonlinearity='relu')


def _rexnet(
    arch: str,
    pretrained: bool,
    progress: bool,
    width_mult: float,
    depth_mult: float,
    **kwargs: Any
) -> ReXNet:
    """Build a :class:`ReXNet`, attach its default configuration and optionally load a checkpoint.

    Args:
        arch: key of the architecture in ``default_cfgs``
        pretrained: if True, ``load_pretrained_params`` is called with the URL registered for ``arch``
        progress: forwarded to ``load_pretrained_params``
        width_mult: width multiplier of the model
        depth_mult: depth multiplier of the model
        **kwargs: extra keyword arguments of the :class:`ReXNet` constructor

    Returns:
        the instantiated model, with its ``default_cfg`` attribute set
    """
    # Build the model
    model = ReXNet(width_mult, depth_mult, **kwargs)
    model.default_cfg = default_cfgs[arch]  # type: ignore[assignment]
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, default_cfgs[arch]['url'], progress)

    return model
def rexnet1_0x(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ReXNet:
    """ReXNet-1.0x from
    `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network"
    <https://arxiv.org/pdf/2007.00992.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the model builder

    Returns:
        torch.nn.Module: classification model
    """
    # Width and depth multipliers of this variant
    width_mult, depth_mult = 1, 1
    return _rexnet('rexnet1_0x', pretrained, progress, width_mult, depth_mult, **kwargs)
def rexnet1_3x(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ReXNet:
    """ReXNet-1.3x from
    `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network"
    <https://arxiv.org/pdf/2007.00992.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the model builder

    Returns:
        torch.nn.Module: classification model
    """
    # Width and depth multipliers of this variant
    width_mult, depth_mult = 1.3, 1
    return _rexnet('rexnet1_3x', pretrained, progress, width_mult, depth_mult, **kwargs)
def rexnet1_5x(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ReXNet:
    """ReXNet-1.5x from
    `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network"
    <https://arxiv.org/pdf/2007.00992.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the model builder

    Returns:
        torch.nn.Module: classification model
    """
    # Width and depth multipliers of this variant
    width_mult, depth_mult = 1.5, 1
    return _rexnet('rexnet1_5x', pretrained, progress, width_mult, depth_mult, **kwargs)
def rexnet2_0x(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ReXNet:
    """ReXNet-2.0x from
    `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network"
    <https://arxiv.org/pdf/2007.00992.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the model builder

    Returns:
        torch.nn.Module: classification model
    """
    # Width and depth multipliers of this variant
    width_mult, depth_mult = 2, 1
    return _rexnet('rexnet2_0x', pretrained, progress, width_mult, depth_mult, **kwargs)
def rexnet2_2x(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ReXNet:
    """ReXNet-2.2x from
    `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network"
    <https://arxiv.org/pdf/2007.00992.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the model builder

    Returns:
        torch.nn.Module: classification model
    """
    # NOTE(review): the `default_cfgs` entry of this architecture has `'url': None`, so with
    # `pretrained=True` the loader receives no checkpoint URL -- confirm how
    # `load_pretrained_params` treats that case.
    # Width and depth multipliers of this variant
    width_mult, depth_mult = 2.2, 1
    return _rexnet('rexnet2_2x', pretrained, progress, width_mult, depth_mult, **kwargs)