holocron.nn

An addition to the torch.nn module of PyTorch that extends the range of neural network building blocks.

Non-linear activations

class holocron.nn.HardMish(inplace: bool = False)[source]

Implements the Hard Mish activation module from “H-Mish”.

This activation is computed as follows:

f(x) = \frac{x}{2} \cdot \min(2, \max(0, x + 2))
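
A minimal usage sketch (assuming the module follows the standard torch.nn.Module calling convention); if the implementation matches the formula above, the final check holds:

>>> import torch
>>> from holocron.nn import HardMish
>>> act = HardMish()
>>> x = torch.randn(8, 16, 32, 32)
>>> out = act(x)  # element-wise activation: the output shape matches the input
>>> # torch.clamp(x + 2, 0, 2) computes min(2, max(0, x + 2))
>>> assert torch.allclose(out, x / 2 * torch.clamp(x + 2, 0, 2))
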
class holocron.nn.NLReLU(inplace: bool = False)[source]

Implements the Natural-Logarithm ReLU activation module from “Natural-Logarithm-Rectified Activation Function in Convolutional Neural Networks”.

This activation is computed as follows:

f(x) = \ln(1 + \beta \cdot \max(0, x))
Parameters:

inplace (bool) – whether the operation should be performed inplace

class holocron.nn.FReLU(in_channels: int, kernel_size: int = 3)[source]

Implements the Funnel activation module from “Funnel Activation for Visual Recognition”.

This activation is computed as follows:

f(x) = \max(T(x), x)

where T is the spatial contextual feature extractor: a convolution with kernel size kernel_size, same padding, and groups equal to the number of input channels, followed by batch normalization.

Parameters:

  • in_channels (int) – number of channels of the input tensor

  • kernel_size (int, optional) – size of the convolution kernel used by T. Default: 3
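
Unlike element-wise activations, FReLU is parameterized by the number of input channels (it carries a depthwise convolution for T). A minimal sketch:

>>> import torch
>>> from holocron.nn import FReLU
>>> act = FReLU(in_channels=64)
>>> x = torch.rand(2, 64, 32, 32)
>>> assert act(x).shape == x.shape  # same padding, so shapes are preserved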

Loss functions

class holocron.nn.FocalLoss(gamma: float = 2.0, **kwargs: Any)[source]

Implementation of Focal Loss as described in “Focal Loss for Dense Object Detection”.

While the weighted cross-entropy is described by:

CE(p_t) = -\alpha_t \log(p_t)

where \alpha_t is the loss weight of class t, and p_t is the predicted probability for class t,

the focal loss introduces a modulating factor:

FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)

where γ is a positive focusing parameter.

Parameters:
  • gamma (float, optional) – exponent parameter of the focal loss

  • weight (torch.Tensor[K], optional) – class weight for loss computation

  • ignore_index (int, optional) – specifies a target value that is ignored and does not contribute to the gradient

  • reduction (str, optional) – type of reduction to apply to the final loss
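
A minimal sketch, assuming the criterion follows the same input/target convention as torch.nn.CrossEntropyLoss (class logits and integer targets):

>>> import torch
>>> from holocron.nn import FocalLoss
>>> criterion = FocalLoss(gamma=2.0)
>>> logits = torch.randn(4, 10, requires_grad=True)  # N=4 samples, K=10 classes
>>> target = torch.randint(0, 10, (4,))              # integer class indices
>>> loss = criterion(logits, target)
>>> loss.backward()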

class holocron.nn.MultiLabelCrossEntropy(*args: Any, **kwargs: Any)[source]

Implementation of the cross-entropy loss for multi-label targets

Parameters:
  • weight (torch.Tensor[K], optional) – class weight for loss computation

  • ignore_index (int, optional) – specifies a target value that is ignored and does not contribute to the gradient

  • reduction (str, optional) – type of reduction to apply to the final loss
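
A minimal sketch; the target format below (a probability tensor of the same shape as the logits) is an assumption based on the multi-label description, not a confirmed API contract:

>>> import torch
>>> from holocron.nn import MultiLabelCrossEntropy
>>> criterion = MultiLabelCrossEntropy()
>>> logits = torch.randn(4, 10)
>>> target = torch.zeros(4, 10)  # assumed: one-/multi-hot targets, same shape as logits
>>> target[torch.arange(4), torch.randint(0, 10, (4,))] = 1.0
>>> loss = criterion(logits, target)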

class holocron.nn.ComplementCrossEntropy(gamma: float = -1, **kwargs: Any)[source]

Implements the complement cross entropy loss from “Imbalanced Image Classification with Complement Cross Entropy”

Parameters:
  • gamma (float, optional) – smoothing factor

  • weight (torch.Tensor[K], optional) – class weight for loss computation

  • ignore_index (int, optional) – specifies a target value that is ignored and does not contribute to the gradient

  • reduction (str, optional) – type of reduction to apply to the final loss

class holocron.nn.MutualChannelLoss(weight: float | List[float] | Tensor | None = None, ignore_index: int = -100, reduction: str = 'mean', xi: int = 2, alpha: float = 1)[source]

Implements the mutual channel loss from “The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification”.

Parameters:
  • weight (torch.Tensor[K], optional) – class weight for loss computation

  • ignore_index (int, optional) – specifies a target value that is ignored and does not contribute to the gradient

  • reduction (str, optional) – type of reduction to apply to the final loss

  • xi (int, optional) – number of features per class

  • alpha (float, optional) – diversity factor

class holocron.nn.DiceLoss(weight: float | List[float] | Tensor | None = None, gamma: float = 1.0, eps: float = 1e-08)[source]

Implements the dice loss from “V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation”

Parameters:
  • weight (torch.Tensor[K], optional) – class weight for loss computation

  • gamma (float, optional) – recall/precision control param

  • eps (float, optional) – small value added to avoid division by zero

class holocron.nn.PolyLoss(*args: Any, eps: float = 2.0, **kwargs: Any)[source]

Implements the Poly1 loss from “PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions”.

Parameters:
  • weight (torch.Tensor[K], optional) – class weight for loss computation

  • eps (float, optional) – epsilon 1 from the paper

  • ignore_index (int, optional) – specifies a target value that is ignored and does not contribute to the gradient. Default: -100

  • reduction (str, optional) – type of reduction to apply to the final loss. Default: 'mean'

Loss wrappers

class holocron.nn.ClassBalancedWrapper(criterion: Module, num_samples: Tensor, beta: float = 0.99)[source]

Implementation of the class-balanced loss as described in “Class-Balanced Loss Based on Effective Number of Samples”.

Given a loss function L, the class-balanced loss is described by:

CB(p, y) = \frac{1 - \beta}{1 - \beta^{n_y}} \cdot L(p, y)

where p is the predicted probability for class y, n_y is the number of training samples for class y, and \beta is the exponential factor.

Parameters:
  • criterion (torch.nn.Module) – loss module

  • num_samples (torch.Tensor[K]) – number of samples for each class

  • beta (float, optional) – rebalancing exponent
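
The wrapper takes any criterion module plus per-class sample counts. A sketch rebalancing a focal loss on an imbalanced 3-class problem:

>>> import torch
>>> from holocron.nn import FocalLoss, ClassBalancedWrapper
>>> num_samples = torch.tensor([1000.0, 100.0, 10.0])  # heavily imbalanced class counts
>>> criterion = ClassBalancedWrapper(FocalLoss(gamma=2.0), num_samples, beta=0.99)
>>> logits = torch.randn(8, 3, requires_grad=True)
>>> target = torch.randint(0, 3, (8,))
>>> loss = criterion(logits, target)  # class y is scaled by (1 - beta) / (1 - beta^n_y)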

Convolution layers

class holocron.nn.NormConv2d(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', eps: float = 1e-14)[source]

Implements the normalized convolution module from “Normalized Convolutional Neural Network”.

In the simplest case, the output value of the layer with input size (N, C_{in}, H, W) and output (N, C_{out}, H_{out}, W_{out}) can be precisely described as:

out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{k=0}^{C_{in}-1} weight(C_{out_j}, k) \star \frac{input(N_i, k) - \mu(N_i, k)}{\sqrt{\sigma^2(N_i, k) + \epsilon}}

where \star is the valid 2D cross-correlation operator, \mu(N_i, k) and \sigma^2(N_i, k) are the mean and variance of input(N_i, k) over all slices, N is a batch size, C denotes a number of channels, H is a height of input planes in pixels, and W is width in pixels.

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution. Default: 1

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 0

  • dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool, optional) – If True, adds a learnable bias to the output. Default: True

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

  • eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-14
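
Since the constructor mirrors torch.nn.Conv2d, the module can be tried as a drop-in replacement; a minimal sketch:

>>> import torch
>>> from holocron.nn import NormConv2d
>>> conv = NormConv2d(3, 16, 3, padding=1)
>>> out = conv(torch.rand(2, 3, 32, 32))
>>> assert out.shape == (2, 16, 32, 32)  # same shape arithmetic as nn.Conv2d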

class holocron.nn.Add2d(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', normalize_slices: bool = False, eps: float = 1e-14)[source]

Implements the adder module from “AdderNet: Do We Really Need Multiplications in Deep Learning?”.

In the simplest case, the output value of the layer at position (m, n) in channel c, with filter F of spatial size (d, d), input size (C_{in}, H, W) and output size (C_{out}, H, W), can be precisely described as:

out(m, n, c) = -\sum_{i=0}^{d} \sum_{j=0}^{d} \sum_{k=0}^{C_{in}} |X(m + i, n + j, k) - F(i, j, k, c)|

where C denotes a number of channels, H is a height of input planes in pixels, and W is width in pixels.

Add2D schema
Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution. Default: 1

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 0

  • dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool, optional) – If True, adds a learnable bias to the output. Default: True

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

  • normalize_slices (bool, optional) – whether slices should be normalized before performing cross-correlation. Default: False

  • eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-14
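
Add2d exposes the same constructor arguments as a regular convolution; only the forward computation changes (negated L1 distance instead of a multiply-accumulate). A sketch:

>>> import torch
>>> from holocron.nn import Add2d
>>> adder = Add2d(3, 16, 3, padding=1, normalize_slices=True)
>>> out = adder(torch.rand(2, 3, 32, 32))
>>> assert out.shape == (2, 16, 32, 32)  # shape follows standard conv arithmetic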

class holocron.nn.SlimConv2d(in_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', r: int = 32, L: int = 2)[source]

Implements the convolution module from “SlimConv: Reducing Channel Redundancy in Convolutional Neural Networks by Weights Flipping”.

First, we compute channel-wise weights as follows:

z(c) = \frac{1}{H \cdot W} \sum_{i=1}^{H} \sum_{j=1}^{W} X_{c,i,j}

where X \in \mathbb{R}^{C \times H \times W} is the input tensor, H is its height in pixels, and W is its width in pixels.

w = \sigma(F_{fc2}(\delta(F_{fc1}(z))))

where z \in \mathbb{R}^C contains the channel-wise statistics, \sigma refers to the sigmoid function, \delta refers to the ReLU function, F_{fc1} is a convolution operation with a kernel of size (1, 1) and max(C/r, L) output channels followed by batch normalization, and F_{fc2} is a plain convolution operation with a kernel of size (1, 1) and C output channels.

We then proceed with reconstructing and transforming both pathways:

X_{top} = X \odot w
X_{bot} = X \odot \check{w}

where \odot refers to element-wise multiplication and \check{w} is the channel-wise reverse-flip of w.

T_{top} = F_{top}(X_{top}^{(1)} + X_{top}^{(2)})
T_{bot} = F_{bot}(X_{bot}^{(1)} + X_{bot}^{(2)})

where X^{(1)} and X^{(2)} are the channel-wise first and second halves of X, F_{top} is a convolution of kernel size (3, 3), and F_{bot} is a convolution of kernel size (1, 1) reducing channels by half, followed by a convolution of kernel size (3, 3).

Finally, we fuse both pathways to yield the output:

Y = T_{top} \oplus T_{bot}

where \oplus is the channel-wise concatenation.

SlimConv2D schema
Parameters:
  • in_channels (int) – Number of channels in the input image

  • kernel_size (int or tuple) – Size of the convolving kernel

  • stride (int or tuple, optional) – Stride of the convolution. Default: 1

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 0

  • dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool, optional) – If True, adds a learnable bias to the output. Default: True

  • padding_mode (string, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

  • r (int, optional) – squeezing divider. Default: 32

  • L (int, optional) – minimum squeezed channels. Default: 2
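
Note that there is no out_channels argument: the two pathways are concatenated after channel reduction, so the output channel count is derived from in_channels (3/4 of it in the paper's design). A shape-probing sketch:

>>> import torch
>>> from holocron.nn import SlimConv2d
>>> conv = SlimConv2d(64, 3, padding=1)
>>> out = conv(torch.rand(2, 64, 32, 32))
>>> print(out.shape)  # expect fewer channels than the 64 input channels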

class holocron.nn.PyConv2d(in_channels: int, out_channels: int, kernel_size: int, num_levels: int = 2, padding: int = 0, groups: List[int] | None = None, **kwargs: Any)[source]

Implements the convolution module from “Pyramidal Convolution: Rethinking Convolutional Neural Networks for Visual Recognition”.

PyConv2D schema
Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int) – Size of the convolving kernel

  • num_levels (int, optional) – number of stacks in the pyramid

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 0

  • groups (list(int), optional) – Number of blocked connections from input channels to output channels for each level. Default: None

class holocron.nn.Involution2d(in_channels: int, kernel_size: int, padding: int = 0, stride: int = 1, groups: int = 1, dilation: int = 1, reduction_ratio: float = 1)[source]

Implements the convolution module from “Involution: Inverting the Inherence of Convolution for Visual Recognition”, adapted from the proposed PyTorch implementation in the paper.

Involution2d schema
Parameters:
  • in_channels (int) – Number of channels in the input image

  • kernel_size (int) – Size of the convolving kernel

  • padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 0

  • stride (int, optional) – Stride of the convolution. Default: 1

  • groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1

  • dilation (int, optional) – Spacing between kernel elements. Default: 1

  • reduction_ratio (float, optional) – reduction ratio of the channels used to generate the kernel. Default: 1
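
Involution preserves the channel count, so with kernel_size=3 and padding=1 the output shape should match the input; a minimal sketch:

>>> import torch
>>> from holocron.nn import Involution2d
>>> inv = Involution2d(32, kernel_size=3, padding=1)
>>> out = inv(torch.rand(2, 32, 32, 32))
>>> assert out.shape == (2, 32, 32, 32)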

Regularization layers

class holocron.nn.DropBlock2d(p: float = 0.1, block_size: int = 7, inplace: bool = False)[source]

Implements the DropBlock module from “DropBlock: A regularization method for convolutional networks”

DropBlock schema
Parameters:
  • p (float, optional) – probability of dropping activation value

  • block_size (int, optional) – size of each block that is expanded from the sampled mask

  • inplace (bool, optional) – whether the operation should be done inplace
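
Like other dropout variants, the module should only drop blocks in training mode; a sketch (the eval-mode identity below is the standard dropout convention, assumed here):

>>> import torch
>>> from holocron.nn import DropBlock2d
>>> reg = DropBlock2d(p=0.1, block_size=7)
>>> x = torch.rand(2, 16, 32, 32)
>>> reg.train()
>>> out = reg(x)  # contiguous spatial blocks are zeroed out
>>> reg.eval()
>>> assert torch.equal(reg(x), x)  # assumed identity at inference, like Dropout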

Downsampling

class holocron.nn.ConcatDownsample2d(scale_factor: int)[source]

Implements a loss-less downsampling operation described in “YOLO9000: Better, Faster, Stronger” by stacking adjacent information on the channel dimension.

Parameters:

scale_factor (int) – spatial scaling factor
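
The operation trades spatial resolution for channels without discarding any value: with a scale factor s, an (N, C, H, W) input becomes (N, C·s², H/s, W/s). A sketch:

>>> import torch
>>> from holocron.nn import ConcatDownsample2d
>>> down = ConcatDownsample2d(scale_factor=2)
>>> out = down(torch.rand(2, 16, 32, 32))
>>> assert out.shape == (2, 64, 16, 16)  # 16 * 2**2 channels, halved spatial dims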

class holocron.nn.GlobalAvgPool2d(flatten: bool = False)[source]

Fast implementation of global average pooling from “TResNet: High Performance GPU-Dedicated Architecture”

Parameters:

flatten (bool, optional) – whether spatial dimensions should be squeezed

class holocron.nn.GlobalMaxPool2d(flatten: bool = False)[source]

Fast implementation of global max pooling from “TResNet: High Performance GPU-Dedicated Architecture”

Parameters:

flatten (bool, optional) – whether spatial dimensions should be squeezed
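
Both global pooling modules reduce the spatial dimensions to a single value per channel; flatten controls whether the trailing 1×1 dimensions are squeezed. A sketch covering both:

>>> import torch
>>> from holocron.nn import GlobalAvgPool2d, GlobalMaxPool2d
>>> x = torch.rand(2, 512, 7, 7)
>>> assert GlobalAvgPool2d()(x).shape == (2, 512, 1, 1)
>>> assert GlobalAvgPool2d(flatten=True)(x).shape == (2, 512)
>>> assert GlobalMaxPool2d(flatten=True)(x).shape == (2, 512)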

class holocron.nn.BlurPool2d(channels: int, kernel_size: int = 3, stride: int = 2)[source]

Ross Wightman’s implementation of the blur pooling module as described in “Making Convolutional Networks Shift-Invariant Again”.

BlurPool schema
Parameters:
  • channels (int) – Number of input channels

  • kernel_size (int, optional) – binomial filter size for blurring. Currently supports 3 (default) and 5.

  • stride (int, optional) – downsampling filter stride

Returns:

the transformed tensor.

Return type:

torch.Tensor
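
A sketch; with the default stride of 2, the spatial dimensions are halved while the channel count (required at construction to build the depthwise blur filter) is preserved:

>>> import torch
>>> from holocron.nn import BlurPool2d
>>> pool = BlurPool2d(channels=64, kernel_size=3, stride=2)
>>> out = pool(torch.rand(2, 64, 32, 32))
>>> assert out.shape == (2, 64, 16, 16)  # anti-aliased downsampling by 2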

class holocron.nn.SPP(kernel_sizes: List[int])[source]

SPP layer from “Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition”.

Parameters:

kernel_sizes (list(int)) – kernel sizes of each pooling branch
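
A sketch; per the paper, the layer pools the input at every kernel size and aggregates the results, so the output size depends on the channel count and kernel sizes rather than on H and W:

>>> import torch
>>> from holocron.nn import SPP
>>> spp = SPP(kernel_sizes=[1, 2, 4])
>>> out = spp(torch.rand(2, 64, 32, 32))
>>> print(out.shape)  # fixed-size output for any input spatial resolution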

class holocron.nn.ZPool(dim: int = 1)[source]

Z-pool layer from “Rotate to Attend: Convolutional Triplet Attention Module”.

Parameters:

dim (int, optional) – dimension to pool. Default: 1
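
Z-pool concatenates max- and mean-pooled slices along the chosen dimension, which therefore collapses to size 2; a sketch under that assumption:

>>> import torch
>>> from holocron.nn import ZPool
>>> zpool = ZPool(dim=1)
>>> out = zpool(torch.rand(2, 64, 32, 32))
>>> assert out.shape == (2, 2, 32, 32)  # assumed: [max, mean] stacked over dim 1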

Attention

class holocron.nn.SAM(in_channels: int)[source]

SAM layer from “CBAM: Convolutional Block Attention Module” modified in “YOLOv4: Optimal Speed and Accuracy of Object Detection”.

Parameters:

in_channels (int) – input channels

class holocron.nn.LambdaLayer(in_channels: int, out_channels: int, dim_k: int, n: int | None = None, r: int | None = None, num_heads: int = 4, dim_u: int = 1)[source]

Lambda layer from “LambdaNetworks: Modeling long-range interactions without attention”. The implementation was adapted from lucidrains’.

LambdaLayer schema
Parameters:
  • in_channels (int) – input channels

  • out_channels (int) – output channels

  • dim_k (int) – key dimension

  • n (int, optional) – number of input pixels

  • r (int, optional) – receptive field for relative positional encoding

  • num_heads (int, optional) – number of attention heads

  • dim_u (int, optional) – intra-depth dimension
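
A sketch, assuming local positional context via r (one of n or r presumably needs to be set) and an out_channels divisible by num_heads:

>>> import torch
>>> from holocron.nn import LambdaLayer
>>> lam = LambdaLayer(in_channels=32, out_channels=32, dim_k=16, r=7, num_heads=4)
>>> out = lam(torch.rand(2, 32, 32, 32))
>>> print(out.shape)  # expected (2, 32, 32, 32): spatial dims preserved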

class holocron.nn.TripletAttention[source]

Triplet attention layer from “Rotate to Attend: Convolutional Triplet Attention Module”. This implementation is based on the one from the paper’s authors.
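
The layer takes no constructor arguments and, like most attention gates, should preserve the input shape; a minimal sketch:

>>> import torch
>>> from holocron.nn import TripletAttention
>>> attn = TripletAttention()
>>> x = torch.rand(2, 64, 32, 32)
>>> assert attn(x).shape == x.shape  # attention only re-weights activations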