# Copyright (C) 2019-2024, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Callable, Dict, List

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from .. import functional as F

__all__ = [
    "SPP",
    "BlurPool2d",
    "ConcatDownsample2d",
    "ConcatDownsample2dJit",
    "GlobalAvgPool2d",
    "GlobalMaxPool2d",
    "ZPool",
]
class ConcatDownsample2d(nn.Module):
    """Loss-less downsampling layer described in `"YOLO9000: Better, Faster, Stronger"
    <https://pjreddie.com/media/files/papers/YOLO9000.pdf>`_.

    Rather than discarding spatial information, adjacent positions are stacked
    onto the channel dimension, so no information is lost by the downsampling.

    Args:
        scale_factor (int): spatial scaling factor
    """

    def __init__(self, scale_factor: int) -> None:
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x: Tensor) -> Tensor:
        # Thin module wrapper: the actual work lives in the functional API
        return F.concat_downsample2d(x, self.scale_factor)
# TorchScript-compiled variant of ConcatDownsample2d. It is a plain scripted class
# (not an nn.Module), hence `__call__` instead of `forward`.
@torch.jit.script
class ConcatDownsample2dJit(object):
    """Implements a loss-less downsampling operation described in `"YOLO9000: Better, Faster, Stronger"
    <https://pjreddie.com/media/files/papers/YOLO9000.pdf>`_ by stacking adjacent information on the
    channel dimension.

    Args:
        scale_factor (int): spatial scaling factor
    """

    def __init__(self, scale_factor: int) -> None:
        # Stored as a plain attribute; TorchScript infers it as int from the annotation
        self.scale_factor = scale_factor

    def __call__(self, x: Tensor) -> Tensor:
        # Delegates to the functional implementation shared with the eager module
        return F.concat_downsample2d(x, self.scale_factor)
class GlobalAvgPool2d(nn.Module):
    """Fast implementation of global average pooling from `"TResNet: High Performance GPU-Dedicated
    Architecture" <https://arxiv.org/pdf/2003.13630.pdf>`_.

    The spatial dimensions are collapsed with a single view + mean rather than an
    adaptive pooling kernel.

    Args:
        flatten (bool, optional): whether spatial dimensions should be squeezed
    """

    def __init__(self, flatten: bool = False) -> None:
        super().__init__()
        self.flatten = flatten

    def forward(self, x: Tensor) -> Tensor:
        batch, channels = x.size(0), x.size(1)
        # Fold H x W into a single axis and average over it
        pooled = x.view(batch, channels, -1).mean(dim=-1)
        # Either keep the (N, C) vector or restore singleton spatial dims
        return pooled if self.flatten else pooled.view(batch, channels, 1, 1)

    def extra_repr(self) -> str:
        return "" if not self.flatten else "flatten=True"
class GlobalMaxPool2d(nn.Module):
    """Fast implementation of global max pooling from `"TResNet: High Performance GPU-Dedicated
    Architecture" <https://arxiv.org/pdf/2003.13630.pdf>`_.

    The spatial dimensions are collapsed with a single view + max rather than an
    adaptive pooling kernel.

    Args:
        flatten (bool, optional): whether spatial dimensions should be squeezed
    """

    def __init__(self, flatten: bool = False) -> None:
        super().__init__()
        self.flatten = flatten

    def forward(self, x: Tensor) -> Tensor:
        batch, channels = x.size(0), x.size(1)
        # Fold H x W into a single axis and take the max over it
        pooled = x.view(batch, channels, -1).max(dim=-1).values
        # Either keep the (N, C) vector or restore singleton spatial dims
        return pooled if self.flatten else pooled.view(batch, channels, 1, 1)

    def extra_repr(self) -> str:
        return "" if not self.flatten else "flatten=True"
class BlurPool2d(nn.Module):
    """Ross Wightman's `implementation
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/blur_pool.py>`_ of
    blur pooling module as described in `"Making Convolutional Networks Shift-Invariant Again"
    <https://arxiv.org/pdf/1904.11486.pdf>`_.

    .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/blurpool.png
        :align: center

    Args:
        channels (int): Number of input channels
        kernel_size (int, optional): binomial filter size for blurring. currently supports 3 (default) and 5.
        stride (int, optional): downsampling filter stride

    Returns:
        torch.Tensor: the transformed tensor.

    Raises:
        AssertionError: if ``kernel_size`` is not greater than 1
    """

    def __init__(self, channels: int, kernel_size: int = 3, stride: int = 2) -> None:
        super().__init__()
        if kernel_size <= 1:
            # A 1x1 (or smaller) binomial filter cannot blur anything
            raise AssertionError(f"kernel_size should be greater than 1, received {kernel_size}")
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        # "Same"-style padding: (stride - 1 + dilation * (kernel_size - 1)) // 2 with dilation=1,
        # i.e. the formula of timm's `get_padding` helper (inlined here)
        pad_size = [(stride - 1 + kernel_size - 1) // 2] * 4
        self.padding = nn.ReflectionPad2d(pad_size)  # type: ignore[arg-type]
        # Binomial coefficients (row of Pascal's triangle) approximate a Gaussian blur;
        # kept as a plain tensor for torchscript compat
        self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.kernel_size - 1)).coeffs)
        # Per-device filter cache, lazily filled on first forward (DataParallel compat)
        self.kernel: Dict[str, Tensor] = {}

    def _create_filter(self, like: Tensor) -> Tensor:
        """Build the blur filter matching the dtype/device of ``like``."""
        # Outer product of the 1D binomial coefficients gives the 2D blur kernel
        blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
        # One identical filter per channel for the depthwise (grouped) convolution
        return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)

    def _apply(self, fn: Callable[[nn.Module], None]) -> nn.Module:
        # Override nn.Module._apply: device/dtype may change, so drop the cached filters.
        self.kernel = {}
        # Propagate the return value so that .to()/.cuda()/.float() keep returning the module
        # (the previous version discarded it, making e.g. `module.to(device)` return None)
        return super()._apply(fn)

    def forward(self, input_tensor: Tensor) -> Tensor:
        device_key = str(input_tensor.device)
        blur_filter = self.kernel.get(device_key)
        if blur_filter is None:
            # First forward on this device: build the filter once and cache it
            # (the previous version never stored it, rebuilding the filter on every call)
            blur_filter = self._create_filter(input_tensor)
            self.kernel[device_key] = blur_filter
        return nn.functional.conv2d(
            self.padding(input_tensor), blur_filter, stride=self.stride, groups=input_tensor.shape[1]
        )

    def extra_repr(self) -> str:
        return f"{self.channels}, kernel_size={self.kernel_size}, stride={self.stride}"
class SPP(nn.ModuleList):
    """SPP layer from `"Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition"
    <https://arxiv.org/pdf/1406.4729.pdf>`_.

    The input is concatenated with one max-pooled copy of itself per kernel size;
    stride 1 and half-kernel padding keep the spatial resolution unchanged for odd kernels.

    Args:
        kernel_sizes (list<int>): kernel sizes of each pooling
    """

    def __init__(self, kernel_sizes: List[int]) -> None:
        pools = [nn.MaxPool2d(size, stride=1, padding=size // 2) for size in kernel_sizes]
        super().__init__(pools)

    def forward(self, x: Tensor) -> Tensor:
        # Stack the identity branch with every pooling branch on the channel axis
        feats = [x]
        for pool in self:
            feats.append(pool(x))
        return torch.cat(feats, dim=1)
class ZPool(nn.Module):
    """Z-pool layer from `"Rotate to Attend: Convolutional Triplet Attention Module"
    <https://arxiv.org/pdf/2010.03045.pdf>`_.

    Args:
        dim: the dimension over which pooling is performed
    """

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        # Thin module wrapper: the actual work lives in the functional API
        return F.z_pool(x, self.dim)