Source code for holocron.models.classification.convnext
# Copyright (C) 2022-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from collections import OrderedDict
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.ops.stochastic_depth import StochasticDepth

from holocron.nn import GlobalAvgPool2d

from ..checkpoints import Checkpoint, _handle_legacy_pretrained
from ..utils import _checkpoint, _configure_model, conv_sequence
from .resnet import _ResBlock

__all__ = [
    "ConvNeXt",
    "ConvNeXt_Atto_Checkpoint",
    "convnext_atto",
    "convnext_base",
    "convnext_femto",
    "convnext_large",
    "convnext_nano",
    "convnext_pico",
    "convnext_small",
    "convnext_tiny",
    "convnext_xl",
]


class LayerNorm2d(nn.LayerNorm):
    """Compatibility wrapper of LayerNorm on 2D tensors

    The input is expected to be 4-dimensional (the permutations below use four axes).
    Axis 1 is moved to the last position, ``nn.LayerNorm`` is applied, and the original
    axis order is restored.
    """

    def forward(self, x: Tensor) -> Tensor:
        # (0, 1, 2, 3) -> (0, 2, 3, 1): normalize over what was axis 1, then undo the permutation
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class LayerScale(nn.Module):
    """Learnable channel-wise scaling

    Args:
        chans: number of entries of the learnable scaling vector
        scale: initial value of every entry of the scaling vector
    """

    def __init__(self, chans: int, scale: float = 1e-6) -> None:
        super().__init__()
        self.register_parameter("weight", nn.Parameter(scale * torch.ones(chans)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reshape the weight to (1, chans, 1, ..., 1) so that it broadcasts along axis 1 of `x`
        return x * self.weight.reshape(1, -1, *((1,) * (x.ndim - 2)))


class Bottlenext(_ResBlock):
    """Residual block made of a 7x7 depth-wise convolution followed by two 1x1 convolutions,
    a `LayerScale` and a `StochasticDepth` layer.

    Args:
        inplanes: number of input (and output) channels
        act_layer: activation module used after the first 1x1 convolution (defaults to ``nn.GELU()``)
        norm_layer: normalization layer factory used after the depth-wise convolution
            (defaults to `LayerNorm2d` with ``eps=1e-6``)
        drop_layer: forwarded as is to every ``conv_sequence`` call
        chan_expansion: multiplier applied to ``inplanes`` for the hidden 1x1 convolution
        stochastic_depth_prob: probability passed to ``StochasticDepth`` (in "row" mode)
        layer_scale: initial value of the `LayerScale` weights
    """

    def __init__(
        self,
        inplanes: int,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        chan_expansion: int = 4,
        stochastic_depth_prob: float = 0.1,
        layer_scale: float = 1e-6,
    ) -> None:
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
        if act_layer is None:
            act_layer = nn.GELU()

        # NOTE(review): the two trailing `None` values are positional arguments of `_ResBlock`,
        # whose meaning is defined in `.resnet` -- confirm there before changing them.
        super().__init__(
            [
                # Depth-conv (groups = in_channels): spatial awareness
                *conv_sequence(
                    inplanes,
                    inplanes,
                    None,
                    norm_layer,
                    drop_layer,
                    kernel_size=7,
                    padding=3,
                    stride=1,
                    bias=True,
                    groups=inplanes,
                ),
                # 1x1 conv: channel awareness
                *conv_sequence(
                    inplanes,
                    inplanes * chan_expansion,
                    act_layer,
                    None,
                    drop_layer,
                    kernel_size=1,
                    stride=1,
                    bias=True,
                ),
                # 1x1 conv: channel mapping
                *conv_sequence(
                    inplanes * chan_expansion,
                    inplanes,
                    None,
                    None,
                    drop_layer,
                    kernel_size=1,
                    stride=1,
                    bias=True,
                ),
                LayerScale(inplanes, layer_scale),
                StochasticDepth(stochastic_depth_prob, "row"),
            ],
            None,
            None,
        )


class ConvNeXt(nn.Sequential):
    """ConvNeXt classifier, organized as ``features`` -> ``pool`` -> ``head``.

    Args:
        num_blocks: number of `Bottlenext` blocks in each stage
        planes: number of channels of each stage (paired element-wise with ``num_blocks``)
        num_classes: number of output features of the final linear layer
        in_channels: number of channels of the input tensor
        conv_layer: convolution layer used in the stem (defaults to ``nn.Conv2d``)
        act_layer: activation module passed to every block (defaults to ``nn.GELU()``)
        norm_layer: normalization layer factory for the stem and the blocks
            (defaults to `LayerNorm2d` with ``eps=1e-6``)
        drop_layer: forwarded as is to ``conv_sequence`` and to every block
        stochastic_depth_prob: stochastic depth probability of the last block; the
            probability of earlier blocks grows linearly from 0 with their index
    """

    def __init__(
        self,
        num_blocks: List[int],
        planes: List[int],
        num_classes: int = 10,
        in_channels: int = 3,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        stochastic_depth_prob: float = 0.0,
    ) -> None:
        if conv_layer is None:
            conv_layer = nn.Conv2d
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
        if act_layer is None:
            act_layer = nn.GELU()
        self.dilation = 1

        # Patchify-like stem
        layers = conv_sequence(
            in_channels,
            planes[0],
            None,
            norm_layer,
            drop_layer,
            conv_layer,
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
        )

        block_idx = 0
        tot_blocks = sum(num_blocks)
        # A single-block network would make `tot_blocks - 1` zero: clamp the denominator so that
        # the schedule degenerates to a probability of 0 instead of raising ZeroDivisionError.
        sd_denom = max(tot_blocks - 1.0, 1.0)
        # Each stage is paired with the width of the next one (the last stage with itself);
        # unpacking keeps this working for any sequence type, not only lists.
        next_planes = [*planes[1:], planes[-1]]
        for _num_blocks, _planes, _oplanes in zip(num_blocks, planes, next_planes):
            # adjust stochastic depth probability based on the depth of the stage block
            sd_probs = [stochastic_depth_prob * (block_idx + _idx) / sd_denom for _idx in range(_num_blocks)]
            stage: List[nn.Module] = [
                Bottlenext(_planes, act_layer, norm_layer, drop_layer, stochastic_depth_prob=sd_prob)
                for sd_prob in sd_probs
            ]
            # Downsample (stride-2 conv) only when the channel count changes between stages
            # NOTE(review): this uses LayerNorm2d/nn.Conv2d directly rather than the
            # `norm_layer`/`conv_layer` arguments -- confirm this is intentional.
            if _planes != _oplanes:
                stage.append(
                    nn.Sequential(
                        LayerNorm2d(_planes),
                        nn.Conv2d(_planes, _oplanes, kernel_size=2, stride=2),
                    )
                )
            layers.append(nn.Sequential(*stage))
            block_idx += _num_blocks

        super().__init__(
            OrderedDict([
                ("features", nn.Sequential(*layers)),
                ("pool", GlobalAvgPool2d(flatten=True)),
                (
                    "head",
                    nn.Sequential(
                        nn.LayerNorm(planes[-1], eps=1e-6),
                        nn.Linear(planes[-1], num_classes),
                    ),
                ),
            ])
        )

        # Init all layers
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


def _convnext(
    checkpoint: Union[Checkpoint, None],
    progress: bool,
    num_blocks: List[int],
    out_chans: List[int],
    **kwargs: Any,
) -> ConvNeXt:
    """Instantiate a `ConvNeXt` and hand it to ``_configure_model``.

    Args:
        checkpoint: checkpoint forwarded to ``_configure_model`` (may be None)
        progress: forwarded to ``_configure_model`` as the ``progress`` keyword
        num_blocks: number of blocks of each stage
        out_chans: number of channels of each stage
        kwargs: extra keyword arguments of the `ConvNeXt` constructor

    Returns:
        the value returned by ``_configure_model`` for the freshly built model
    """
    # Build the model
    model = ConvNeXt(num_blocks, out_chans, **kwargs)
    return _configure_model(model, checkpoint, progress=progress)
def convnext_atto(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-Atto variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` with ``ConvNeXt_Atto_Checkpoint.DEFAULT`` as fallback
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.ConvNeXt_Atto_Checkpoint
        :members:
    """
    fallback_checkpoint = ConvNeXt_Atto_Checkpoint.DEFAULT.value
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, fallback_checkpoint)
    # Atto layout: 4 stages of (2, 2, 6, 2) blocks, widths doubling from 40 channels
    stage_depths = [2, 2, 6, 2]
    stage_widths = [40, 80, 160, 320]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_femto(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-Femto variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # Femto layout: 4 stages of (2, 2, 6, 2) blocks, widths doubling from 48 channels
    stage_depths = [2, 2, 6, 2]
    stage_widths = [48, 96, 192, 384]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_pico(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-Pico variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # Pico layout: 4 stages of (2, 2, 6, 2) blocks, widths doubling from 64 channels
    stage_depths = [2, 2, 6, 2]
    stage_widths = [64, 128, 256, 512]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_nano(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-Nano variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # Nano layout: 4 stages of (2, 2, 8, 2) blocks, widths doubling from 80 channels
    stage_depths = [2, 2, 8, 2]
    stage_widths = [80, 160, 320, 640]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_tiny(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-T from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # Tiny layout: 4 stages of (3, 3, 9, 3) blocks, widths doubling from 96 channels
    stage_depths = [3, 3, 9, 3]
    stage_widths = [96, 192, 384, 768]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_small(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-S from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # Small layout: 4 stages of (3, 3, 27, 3) blocks, widths doubling from 96 channels
    stage_depths = [3, 3, 27, 3]
    stage_widths = [96, 192, 384, 768]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_base(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-B from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # Base layout: 4 stages of (3, 3, 27, 3) blocks, widths doubling from 128 channels
    stage_depths = [3, 3, 27, 3]
    stage_widths = [128, 256, 512, 1024]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_large(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-L from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # Large layout: 4 stages of (3, 3, 27, 3) blocks, widths doubling from 192 channels
    stage_depths = [3, 3, 27, 3]
    stage_widths = [192, 384, 768, 1536]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)
def convnext_xl(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-XL from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: legacy switch, resolved together with ``checkpoint`` by
            ``_handle_legacy_pretrained`` (no fallback checkpoint is supplied here)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    resolved_checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    # XL layout: 4 stages of (3, 3, 27, 3) blocks, widths doubling from 256 channels
    stage_depths = [3, 3, 27, 3]
    stage_widths = [256, 512, 1024, 2048]
    return _convnext(resolved_checkpoint, progress, stage_depths, stage_widths, **kwargs)