# Copyright (C) 2019-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import einsum, nn

__all__ = ["LambdaLayer"]

class LambdaLayer(nn.Module):
    """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
    <https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
    <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`_.

    .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/lambdalayer.png
        :align: center

    Args:
        in_channels (int): input channels
        out_channels (int): output channels
        dim_k (int): key dimension
        n (int, optional): number of input pixels
        r (int, optional): receptive field for relative positional encoding
        num_heads (int, optional): number of attention heads
        dim_u (int, optional): intra-depth dimension
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dim_k: int,
        n: Optional[int] = None,
        r: Optional[int] = None,
        num_heads: int = 4,
        dim_u: int = 1,
    ) -> None:
        super().__init__()
        self.u = dim_u
        self.num_heads = num_heads

        if out_channels % num_heads != 0:
            raise AssertionError("values dimension must be divisible by number of heads for multi-head query")
        dim_v = out_channels // num_heads

        # Project input and context to get queries, keys & values
        self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
        self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
        self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)

        self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = r is not None
        if r is not None:
            if r % 2 != 1:
                raise AssertionError("Receptive kernel size should be odd")
            self.padding = r // 2
            self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))  # type: ignore[attr-defined]
        else:
            if n is None:
                raise AssertionError("You must specify the total sequence length (h x w)")
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))  # type: ignore[attr-defined]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape

        # Project inputs & context to retrieve queries, keys and values
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # Normalize queries & values
        q = self.norm_q(q)
        v = self.norm_v(v)

        # B x (num_heads * dim_k) x H x W -> B x num_heads x dim_k x (H * W)
        q = q.reshape(b, self.num_heads, -1, h * w)
        # B x (dim_k * dim_u) x H x W -> B x dim_u x dim_k x (H * W)
        k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
        # B x (dim_v * dim_u) x H x W -> B x dim_u x dim_v x (H * W)
        v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)

        # Normalized keys
        k = k.softmax(dim=-1)

        # Content function
        λc = einsum("b u k m, b u v m -> b k v", k, v)
        Yc = einsum("b h k n, b k v -> b n h v", q, λc)

        # Position function
        if self.local_contexts:
            # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
            v = v.reshape(b, self.u, v.shape[2], h, w)
            λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
            Yp = einsum("b h k n, b k v n -> b n h v", q, λp.flatten(3))
        else:
            λp = einsum("n m k u, b u v m -> b n k v", self.pos_emb, v)
            Yp = einsum("b h k n, b n k v -> b n h v", q, λp)
        Y = Yc + Yp

        # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
        out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)
        return out
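

# Minimal usage sketch (illustration only, not part of the original module).
# The channel and spatial sizes below are arbitrary assumptions. With the global
# positional embedding, `n` must equal the number of input pixels (H * W);
# with local contexts, an odd receptive field `r` is used instead of `n`.
if __name__ == "__main__":
    x = torch.rand(2, 32, 16, 16)

    # Global positional embedding: n = 16 * 16 input pixels
    lambda_global = LambdaLayer(in_channels=32, out_channels=64, dim_k=16, n=16 * 16, num_heads=4)
    print(lambda_global(x).shape)  # torch.Size([2, 64, 16, 16])

    # Local contexts: relative positional encoding over a 7x7 receptive field
    lambda_local = LambdaLayer(in_channels=32, out_channels=64, dim_k=16, r=7, num_heads=4)
    print(lambda_local(x).shape)  # torch.Size([2, 64, 16, 16])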