Source code for holocron.models.classification.tridentnet

# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

from ..presets import IMAGENETTE
from ..utils import conv_sequence, load_pretrained_params
from .resnet import ResNet, _ResBlock

__all__ = ["Tridentneck", "tridentnet50"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "tridentnet50": {
        **IMAGENETTE.__dict__,
        "input_shape": (3, 224, 224),
        "url": "https://github.com/frgfm/Holocron/releases/download/v0.1.2/tridentnet50_224-98b4ce9c.pth",
    },
}


class TridentConv2d(nn.Conv2d):
    num_branches: int = 3

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        if self.dilation[0] != 1 and self.dilation[0] != self.num_branches:
            raise ValueError(f"expected dilation to either be 1 or {self.num_branches}.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[1] % self.num_branches != 0:
            raise ValueError("expected number of channels of input tensor to be a multiple of `num_branches`.")

        # Dilation for each chunk
        dilations = [1] * self.num_branches if self.dilation[0] == 1 else [1 + idx for idx in range(self.num_branches)]

        # Use shared weight to apply the convolution
        return torch.cat(
            [
                F.conv2d(
                    _x,
                    self.weight,
                    self.bias,
                    self.stride,
                    tuple(dilation * p for p in self.padding),  # type: ignore[misc]
                    (dilation,) * len(self.dilation),
                    self.groups,
                )
                for _x, dilation in zip(torch.chunk(x, self.num_branches, 1), dilations)
            ],
            1,
        )


class Tridentneck(_ResBlock):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 3,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if act_layer is None:
            act_layer = nn.ReLU(inplace=True)

        width = int(planes * (base_width / 64.0)) * groups
        # Concatenate along the channel axis and enlarge BN to leverage parallelization
        super().__init__(
            [
                *conv_sequence(
                    inplanes,
                    width,
                    act_layer,
                    norm_layer,
                    drop_layer,
                    TridentConv2d,
                    bn_channels=3 * width,
                    kernel_size=1,
                    stride=1,
                    bias=(norm_layer is None),
                    dilation=1,
                    **kwargs,
                ),
                *conv_sequence(
                    width,
                    width,
                    act_layer,
                    norm_layer,
                    drop_layer,
                    TridentConv2d,
                    bn_channels=3 * width,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=groups,
                    bias=(norm_layer is None),
                    dilation=3,
                    **kwargs,
                ),
                *conv_sequence(
                    width,
                    planes * self.expansion,
                    None,
                    norm_layer,
                    drop_layer,
                    TridentConv2d,
                    bn_channels=3 * planes * self.expansion,
                    kernel_size=1,
                    stride=1,
                    bias=(norm_layer is None),
                    dilation=1,
                    **kwargs,
                ),
            ],
            downsample,
            act_layer,
        )


def _tridentnet(
    arch: str,
    pretrained: bool,
    progress: bool,
    num_blocks: List[int],
    out_chans: List[int],
    **kwargs: Any,
) -> ResNet:
    # Build the model
    model = ResNet(Tridentneck, num_blocks, out_chans, num_repeats=3, **kwargs)  # type: ignore[arg-type]
    model.default_cfg = default_cfgs[arch]  # type: ignore[assignment]
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, default_cfgs[arch]["url"], progress)

    return model


[docs] def tridentnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: """TridentNet-50 from `"Scale-Aware Trident Networks for Object Detection" <https://arxiv.org/pdf/1901.01892.pdf>`_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr kwargs: keyword args of _tridentnet Returns: torch.nn.Module: classification model """ return _tridentnet("tridentnet50", pretrained, progress, [3, 4, 6, 3], [64, 128, 256, 512], **kwargs)