# Source code for holocron.nn.modules.lambda_layer

# Copyright (C) 2019-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import einsum, nn

# Public API of this module (names exported by `from ... import *`)
__all__ = ["LambdaLayer"]


class LambdaLayer(nn.Module):
    """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
    <https://openreview.net/pdf?id=xTJEN-ggl1b>`_.

    The implementation was adapted from `lucidrains'
    <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`_.

    .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/lambdalayer.png
        :align: center

    Args:
        in_channels (int): input channels
        out_channels (int): output channels, must be divisible by ``num_heads``
        dim_k (int): key dimension
        n (int, optional): number of input pixels (H * W); required when ``r`` is not specified,
            and every input must then have exactly ``n`` pixels
        r (int, optional): receptive field (odd kernel size) for relative positional encoding;
            when specified, positional lambdas are computed locally with a 3D convolution
        num_heads (int, optional): number of attention heads
        dim_u (int, optional): intra-depth dimension

    Raises:
        AssertionError: if ``out_channels`` is not divisible by ``num_heads``, if ``r`` is even,
            or if neither ``r`` nor ``n`` is specified
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dim_k: int,
        n: Optional[int] = None,
        r: Optional[int] = None,
        num_heads: int = 4,
        dim_u: int = 1,
    ) -> None:
        super().__init__()
        self.u = dim_u
        self.num_heads = num_heads

        if out_channels % num_heads != 0:
            raise AssertionError("values dimension must be divisible by number of heads for multi-head query")
        dim_v = out_channels // num_heads

        # Project input and context to get queries, keys & values
        self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
        self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
        self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)

        # Queries and values are batch-normalized; keys are softmax-normalized in forward instead
        self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = r is not None
        if r is not None:
            if r % 2 != 1:
                raise AssertionError("Receptive kernel size should be odd")
            # Padding that keeps the H x W spatial size for an odd kernel
            self.padding = r // 2
            # Relative positional embeddings, used as a (1, r, r) conv3d kernel in forward
            self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
        else:
            if n is None:
                raise AssertionError("You must specify the total sequence length (h x w)")
            # Positional embeddings for every (query pixel, context pixel) pair
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the lambda layer.

        Args:
            x: input tensor of shape (B, in_channels, H, W)

        Returns:
            output tensor of shape (B, out_channels, H, W)

        Raises:
            RuntimeError: without local contexts (``r`` unspecified), if H * W differs from ``n``
        """
        b, _, h, w = x.shape

        if not self.local_contexts and h * w != self.pos_emb.shape[0]:
            # torch.einsum treats size-1 dims as broadcastable, so without this check a mismatch
            # either fails with an opaque einsum error or (e.g. n=1) silently ignores positions
            raise RuntimeError(
                f"Expected an input with H * W = {self.pos_emb.shape[0]} pixels, "
                f"received {h} x {w} = {h * w}"
            )

        # Project inputs & context to retrieve queries, keys and values, then normalize q & v
        q = self.norm_q(self.to_q(x))
        k = self.to_k(x)
        v = self.norm_v(self.to_v(x))

        # B x (num_heads * dim_k) x H x W -> B x num_heads x dim_k x (H * W)
        q = q.reshape(b, self.num_heads, -1, h * w)
        # B x (dim_k * dim_u) x H x W -> B x dim_u x dim_k x (H * W)
        k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
        # B x (dim_v * dim_u) x H x W -> B x dim_u x dim_v x (H * W)
        v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
        dim_v = v.shape[2]

        # Normalize keys over context positions
        k = k.softmax(dim=-1)

        # Content function
        content_lambda = einsum("b u k m, b u v m -> b k v", k, v)
        y_content = einsum("b h k n, b k v -> b n h v", q, content_lambda)

        # Position function
        if self.local_contexts:
            # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
            v = v.reshape(b, self.u, dim_v, h, w)
            # -> B x dim_k x dim_v x H x W
            position_lambda = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
            y_position = einsum("b h k n, b k v n -> b n h v", q, position_lambda.flatten(3))
        else:
            position_lambda = einsum("n m k u, b u v m -> b n k v", self.pos_emb, v)
            y_position = einsum("b h k n, b n k v -> b n h v", q, position_lambda)

        y = y_content + y_position
        # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
        return y.permute(0, 2, 3, 1).reshape(b, self.num_heads * dim_v, h, w)