Source code for holocron.nn.modules.conv

# -*- coding: utf-8 -*-

'''
Convolutional modules
'''

import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.functional import pad
from .. import functional as F

__all__ = ['NormConv2d', 'Add2d', 'SlimConv2d']


class _NormConvNd(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode, normalize_slices=False, eps=1e-14):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, transposed, output_padding,
                         groups, bias, padding_mode)
        self.normalize_slices = normalize_slices
        self.eps = eps


[docs] class NormConv2d(_NormConvNd): """Implements the normalized convolution module from `"Normalized Convolutional Neural Network" <https://arxiv.org/pdf/2005.05274v2.pdf>`_. In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, H, W)` and output :math:`(N, C_{out}, H_{out}, W_{out})` can be precisely described as: .. math:: out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star \\frac{input(N_i, k) - \mu(N_i, k)}{\sqrt{\sigma^2(N_i, k) + \epsilon}} where :math:`\star` is the valid 2D cross-correlation operator, :math:`\mu(N_i, k)` and :math:`\sigma²(N_i, k)` are the mean and variance of :math:`input(N_i, k)` over all slices, :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`H` is a height of input planes in pixels, and :math:`W` is width in pixels. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` eps (float, optional): a value added to the denominator for numerical stability. Default: 1e-14 """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', eps=1e-14): kernel_size = _pair(kernel_size) stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode, False, eps) def forward(self, input): if self.padding_mode != 'zeros': return F.norm_conv2d(pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), self.weight, self.bias, self.stride, _pair(0), self.dilation, self.groups, self.eps) return F.norm_conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.eps)
[docs] class Add2d(_NormConvNd): """Implements the adder module from `"AdderNet: Do We Really Need Multiplications in Deep Learning?" <https://arxiv.org/pdf/1912.13200.pdf>`_. In the simplest case, the output value of the layer at position :math:`(m, n)` in channel :math:`c` with filter F of spatial size :math:`(d, d)`, intput size :math:`(C_{in}, H, W)` and output :math:`(C_{out}, H, W)` can be precisely described as: .. math:: out(m, n, c) = - \\sum\\limits_{i=0}^d \\sum\\limits_{j=0}^d \\sum\\limits_{k=0}^{C_{in}} |X(m + i, n + j, k) - F(i, j, k, c)| where :math:`C` denotes a number of channels, :math:`H` is a height of input planes in pixels, and :math:`W` is width in pixels. Args: in_channels (int): Number of channels in the input image out_channels (int): Number of channels produced by the convolution kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` normalize_slices (bool, optional): whether slices should be normalized before performing cross-correlation. Default: False eps (float, optional): a value added to the denominator for numerical stability. Default: 1e-14 """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', normalize_slices=False, eps=1e-14): kernel_size = _pair(kernel_size) stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode, normalize_slices, eps) def forward(self, input): if self.padding_mode != 'zeros': return F.add2d(pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), self.weight, self.bias, self.stride, _pair(0), self.dilation, self.groups, self.normalize_slices, self.eps) return F.add2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.normalize_slices, self.eps)
[docs] class SlimConv2d(nn.Module): """Implements the convolution module from `"SlimConv: Reducing Channel Redundancy in Convolutional Neural Networks by Weights Flipping" <https://arxiv.org/pdf/2003.07469.pdf>`_. First, we compute channel-wise weights as follows: .. math:: z(c) = \\frac{1}{H \\cdot W} \\sum\\limits_{i=1}^H \\sum\\limits_{j=1}^W X_{c,i,j} where :math:`X \\in \\mathbb{R}^{C \\times H \\times W}` is the input tensor, :math:`H` is height in pixels, and :math:`W` is width in pixels. .. math:: w = \\sigma(F_{fc2}(\\delta(F_{fc1}(z)))) where :math:`z \\in \\mathbb{R}^{C}` contains channel-wise statistics, :math:`\\sigma` refers to the sigmoid function, :math:`\\delta` refers to the ReLU function, :math:`F_{fc1}` is a convolution operation with kernel of size :math:`(1, 1)` with :math:`max(C/r, L)` output channels followed by batch normalization, and :math:`F_{fc2}` is a plain convolution operation with kernel of size :math:`(1, 1)` with :math:`C` output channels. We then proceed with reconstructing and transforming both pathways: .. math:: X_{top} = X \\odot w .. math:: X_{bot} = X \\odot \\check{w} where :math:`\\odot` refers to the element-wise multiplication and :math:`\\check{w}` is the channel-wise reverse-flip of :math:`w`. .. math:: T_{top} = F_{top}(X_{top}^{(1)} + X_{top}^{(2)}) .. math:: T_{bot} = F_{bot}(X_{bot}^{(1)} + X_{bot}^{(2)}) where :math:`X^{(1)}` and :math:`X^{(2)}` are the channel-wise first and second halves of :math:`X`, :math:`F_{top}` is a convolution of kernel size :math:`(3, 3)`, and :math:`F_{bot}` is a convolution of kernel size :math:`(1, 1)` reducing channels by half, followed by a convolution of kernel size :math:`(3, 3)`. Finally we fuse both pathways to yield the output: .. math:: Y = T_{top} \\oplus T_{bot} where :math:`\\oplus` is the channel-wise concatenation. Args: in_channels (int): Number of channels in the input image kernel_size (int or tuple): Size of the convolving kernel stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` r (int, optional): squeezing divider. Default: 32 L (int, optional): minimum squeezed channels. Default: 8 """ def __init__(self, in_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', r=32, L=2): super().__init__() self.fc1 = nn.Conv2d(in_channels, max(in_channels // r, L), 1) self.bn = nn.BatchNorm2d(max(in_channels // r, L)) self.fc2 = nn.Conv2d(max(in_channels // r, L), in_channels, 1) self.conv_top = nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size, stride, padding, dilation, groups, bias, padding_mode) self.conv_bot1 = nn.Conv2d(in_channels // 2, in_channels // 4, 1) self.conv_bot2 = nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size, stride, padding, dilation, groups, bias, padding_mode) def forward(self, x): # Channel-wise weights z = x.mean(dim=(2, 3), keepdims=True) z = self.bn(self.fc1(z)) z = self.fc2(torch.relu(z)) w = torch.sigmoid(z) # Compression X_w = x * w X_top = X_w[:, :x.shape[1] // 2] + X_w[:, x.shape[1] // 2:] X_w = x * w.flip(dims=(1,)) X_bot = X_w[:, :x.shape[1] // 2] + X_w[:, x.shape[1] // 2:] # Transform X_top = self.conv_top(X_top) X_bot = self.conv_bot2(self.conv_bot1(X_bot)) # Fuse return torch.cat((X_top, X_bot), dim=1)