# Source code for holocron.models.segmentation.unetpp
# Copyright (C) 2020-2022, François-Guillaume Fernandez.
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import sys
from typing import Any, Callable, Dict, List, Optional

import torch.nn as nn
from torch import Tensor

from ...nn.init import init_module
from ..utils import conv_sequence, load_pretrained_params
from .unet import UpPath, down_path

__all__ = ['UNetp', 'unetp', 'UNetpp', 'unetpp']


default_cfgs: Dict[str, Dict[str, Any]] = {
    'unetp': {'arch': 'UNetp', 'layout': [64, 128, 256, 512], 'url': None},
    'unetpp': {'arch': 'UNetpp', 'layout': [64, 128, 256, 512], 'url': None},
}


class UNetp(nn.Module):
    """Implements a UNet+ architecture

    Args:
        layout: number of channels after each contracting block
        in_channels: number of channels in the input tensor
        num_classes: number of output classes
        act_layer: activation layer
        norm_layer: normalization layer
        drop_layer: dropout layer
        conv_layer: convolutional layer
    """

    def __init__(
        self,
        layout: List[int],
        in_channels: int = 3,
        num_classes: int = 10,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        if act_layer is None:
            act_layer = nn.ReLU(inplace=True)

        # Contracting path: one down block per layout entry; the first block skips pooling
        self.encoder = nn.ModuleList([])
        _layout = [in_channels] + layout
        _pool = False
        for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
            self.encoder.append(
                down_path(in_chan, out_chan, _pool, 1, act_layer, norm_layer, drop_layer, conv_layer)
            )
            _pool = True

        # Bottleneck between the contracting and expansive paths
        self.bridge = nn.Sequential(
            nn.MaxPool2d((2, 2)),
            *conv_sequence(layout[-1], 2 * layout[-1], act_layer, norm_layer, drop_layer, conv_layer,
                           kernel_size=3, padding=1),
            *conv_sequence(2 * layout[-1], layout[-1], act_layer, norm_layer, drop_layer, conv_layer,
                           kernel_size=3, padding=1),
        )

        # Expansive path: at depth d (from the bottom), d upsampling cells are stacked
        self.decoder = nn.ModuleList([])
        _layout = [layout[-1]] + layout[1:][::-1]
        for left_chan, up_chan, num_cells in zip(layout[::-1], _layout, range(1, len(layout) + 1)):
            self.decoder.append(nn.ModuleList([
                UpPath(left_chan + up_chan, left_chan, True, 1, act_layer, norm_layer, drop_layer, conv_layer)
                for _ in range(num_cells)
            ]))

        # Classifier: 1x1 conv mapping the finest feature map to class logits
        self.classifier = nn.Conv2d(layout[0], num_classes, 1)

        init_module(self, 'relu')

    def forward(self, x: Tensor) -> Tensor:
        xs: List[Tensor] = []
        # Contracting path: each encoder consumes the previous stage's output
        for encoder in self.encoder:
            xs.append(encoder(xs[-1] if len(xs) > 0 else x))

        xs.append(self.bridge(xs[-1]))

        # Nested expansive path: at each step, the deepest remaining feature is
        # consumed (popped) and the intermediate skip features are updated in place
        for j in range(len(self.decoder)):
            for i in range(len(xs) - 1):
                up_feat = xs[i + 1] if (i + 2) < len(xs) else xs.pop()
                xs[i] = self.decoder[-1 - i][j](xs[i], up_feat)

        # Classifier
        return self.classifier(xs.pop())


class UNetpp(nn.Module):
    """Implements a UNet++ architecture

    Args:
        layout: number of channels after each contracting block
        in_channels: number of channels in the input tensor
        num_classes: number of output classes
        act_layer: activation layer
        norm_layer: normalization layer
        drop_layer: dropout layer
        conv_layer: convolutional layer
    """

    def __init__(
        self,
        layout: List[int],
        in_channels: int = 3,
        num_classes: int = 10,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        if act_layer is None:
            act_layer = nn.ReLU(inplace=True)

        # Contracting path: one down block per layout entry; the first block skips pooling
        self.encoder = nn.ModuleList([])
        _layout = [in_channels] + layout
        _pool = False
        for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
            self.encoder.append(
                down_path(in_chan, out_chan, _pool, 1, act_layer, norm_layer, drop_layer, conv_layer)
            )
            _pool = True

        # Bottleneck between the contracting and expansive paths
        self.bridge = nn.Sequential(
            nn.MaxPool2d((2, 2)),
            *conv_sequence(layout[-1], 2 * layout[-1], act_layer, norm_layer, drop_layer, conv_layer,
                           kernel_size=3, padding=1),
            *conv_sequence(2 * layout[-1], layout[-1], act_layer, norm_layer, drop_layer, conv_layer,
                           kernel_size=3, padding=1),
        )

        # Expansive path: unlike UNet+, cell idx at a given depth receives the
        # dense concatenation of all (idx + 1) previous skip features at that depth
        self.decoder = nn.ModuleList([])
        _layout = [layout[-1]] + layout[1:][::-1]
        for left_chan, up_chan, num_cells in zip(layout[::-1], _layout, range(1, len(layout) + 1)):
            self.decoder.append(nn.ModuleList([
                UpPath(up_chan + (idx + 1) * left_chan, left_chan, True, 1,
                       act_layer, norm_layer, drop_layer, conv_layer)
                for idx in range(num_cells)
            ]))

        # Classifier: 1x1 conv mapping the finest feature map to class logits
        self.classifier = nn.Conv2d(layout[0], num_classes, 1)

        init_module(self, 'relu')

    def forward(self, x: Tensor) -> Tensor:
        # Each entry of xs accumulates the list of feature maps produced at one depth
        xs: List[List[Tensor]] = []
        # Contracting path
        for encoder in self.encoder:
            xs.append([encoder(xs[-1][0] if len(xs) > 0 else x)])

        xs.append([self.bridge(xs[-1][-1])])

        # Nested expansive path: pass the full list of same-depth features to each
        # cell (dense skip connections); the deepest level is popped once exhausted
        for j in range(len(self.decoder)):
            for i in range(len(xs) - 1):
                up_feat = xs[i + 1][j] if (i + 2) < len(xs) else xs.pop()[-1]
                xs[i].append(self.decoder[-1 - i][j](xs[i], up_feat))

        # Classifier
        return self.classifier(xs.pop()[-1])


def _unet(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> nn.Module:
    """Builds a U-Net variant from its config entry and optionally loads pretrained weights.

    Args:
        arch: key into ``default_cfgs`` identifying the architecture
        pretrained: whether to load pretrained parameters
        progress: whether to display a download progress bar
        kwargs: extra keyword arguments forwarded to the model constructor

    Returns:
        the instantiated segmentation model
    """
    # Retrieve the correct model class from this module's namespace
    unet_type = sys.modules[__name__].__dict__[default_cfgs[arch]['arch']]
    # Build the model
    model = unet_type(default_cfgs[arch]['layout'], **kwargs)
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, default_cfgs[arch]['url'], progress)

    return model
def unetp(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UNetp:
    """UNet+ from
    `"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation"
    <https://arxiv.org/pdf/1912.05074.pdf>`_

    .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unetp.png
        :align: center

    Args:
        pretrained: If True, returns a model pre-trained on PASCAL VOC2012
        progress: If True, displays a progress bar of the download to stderr

    Returns:
        semantic segmentation model
    """
    return _unet('unetp', pretrained, progress, **kwargs)  # type: ignore[return-value]
def unetpp(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UNetpp:
    """UNet++ from
    `"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation"
    <https://arxiv.org/pdf/1912.05074.pdf>`_

    .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unetpp.png
        :align: center

    Args:
        pretrained: If True, returns a model pre-trained on PASCAL VOC2012
        progress: If True, displays a progress bar of the download to stderr

    Returns:
        semantic segmentation model
    """
    return _unet('unetpp', pretrained, progress, **kwargs)  # type: ignore[return-value]