Source code for holocron.nn.modules.attention
import torch
import torch.nn as nn
__all__ = ['SAM']
class SAM(nn.Module):
"""SAM layer from `"CBAM: Convolutional Block Attention Module" <https://arxiv.org/pdf/1807.06521.pdf>`_
modified in `"YOLOv4: Optimal Speed and Accuracy of Object Detection" <https://arxiv.org/pdf/2004.10934.pdf>`_.
Args:
in_channels (int): input channels
"""
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        # Pointwise convolution that collapses the channels into a single-channel spatial attention map
        self.conv = nn.Conv2d(in_channels, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate the input with the sigmoid-activated attention map, broadcast across all channels
        return x * torch.sigmoid(self.conv(x))
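

# Minimal usage sketch (not part of the original module): applying SAM to a feature map.
# The batch size, channel count and spatial size below are illustrative assumptions.
if __name__ == "__main__":
    sam = SAM(in_channels=64)
    feats = torch.rand(2, 64, 32, 32)  # (batch, channels, height, width)
    out = sam(feats)
    print(out.shape)  # torch.Size([2, 64, 32, 32]) -- same shape as the input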