# Copyright (C) 2019-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import cast

import torch
import torch.nn as nn
from torch import Tensor

from .downsample import ZPool

__all__ = ["SAM", "TripletAttention"]
class SAM(nn.Module):
    """SAM layer from `"CBAM: Convolutional Block Attention Module" <https://arxiv.org/pdf/1807.06521.pdf>`_
    modified in `"YOLOv4: Optimal Speed and Accuracy of Object Detection" <https://arxiv.org/pdf/2004.10934.pdf>`_.

    Args:
        in_channels (int): input channels
    """

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, 1)

    def forward(self, x: Tensor) -> Tensor:
        return x * torch.sigmoid(self.conv(x))
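# Usage sketch (added for illustration, not part of the original module): SAM is
# shape-preserving, gating each spatial position with a sigmoid of a pointwise conv.
# The helper name `_demo_sam` and the tensor shapes below are arbitrary assumptions.
def _demo_sam() -> None:
    x = torch.rand(2, 16, 32, 32)  # (N, C, H, W) activation map
    sam = SAM(in_channels=16)
    out = sam(x)
    assert out.shape == x.shape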
class DimAttention(nn.Module):
    """Attention layer across a specific dimension

    Args:
        dim: dimension to compute attention on
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.compress = nn.Sequential(
            ZPool(dim=1),
            nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(1, eps=1e-5, momentum=0.01),
            nn.Sigmoid(),
        )
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        # Rotate the target dimension into the channel slot so that the
        # channel-wise compression (ZPool + 7x7 conv + sigmoid) acts on it
        if self.dim != 1:
            x = x.transpose(self.dim, 1).contiguous()
        out = cast(Tensor, x * self.compress(x))
        # Rotate back to the original layout
        if self.dim != 1:
            out = out.transpose(self.dim, 1).contiguous()
        return out
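# Illustrative check (added, not in the original module): with dim=2, DimAttention
# swaps H into the channel slot, computes a (max, mean)-pooled attention map, then
# swaps back, so the output shape always matches the input. `_demo_dim_attention`
# is a hypothetical helper name.
def _demo_dim_attention() -> None:
    x = torch.rand(2, 8, 16, 16)
    attn = DimAttention(dim=2)  # attend across the height dimension
    assert attn(x).shape == x.shape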
class TripletAttention(nn.Module):
    """Triplet attention layer from `"Rotate to Attend: Convolutional Triplet Attention Module"
    <https://arxiv.org/pdf/2010.03045.pdf>`_. This implementation is based on the
    `one <https://github.com/LandskapeAI/triplet-attention/blob/master/MODELS/triplet_attention.py>`_
    from the paper's authors.
    """

    def __init__(self) -> None:
        super().__init__()
        self.c_branch = DimAttention(dim=1)
        self.h_branch = DimAttention(dim=2)
        self.w_branch = DimAttention(dim=3)

    def forward(self, x: Tensor) -> Tensor:
        x_c = cast(Tensor, self.c_branch(x))
        x_h = cast(Tensor, self.h_branch(x))
        x_w = cast(Tensor, self.w_branch(x))
        # Average the channel, height and width attention branches
        return (x_c + x_h + x_w) / 3
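# Usage sketch (added for illustration): TripletAttention takes no constructor
# arguments and preserves the input shape, so it can be dropped after any
# convolutional block. `_demo_triplet_attention` is a hypothetical helper name.
def _demo_triplet_attention() -> None:
    x = torch.rand(2, 32, 24, 24)
    triplet = TripletAttention()
    out = triplet(x)  # mean of the three rotated attention branches
    assert out.shape == x.shape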