# Source code for holocron.models.classification.sknet

# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn

from holocron.nn import GlobalAvgPool2d

from ..checkpoints import Checkpoint, _handle_legacy_pretrained
from ..presets import IMAGENETTE
from ..utils import _checkpoint, _configure_model, conv_sequence
from .resnet import ResNet, _ResBlock

__all__ = ["SKBottleneck", "SKConv2d", "SKNet50_Checkpoint", "SoftAttentionLayer", "sknet50", "sknet101", "sknet152"]


# Per-architecture preset configuration: Imagenette preset attributes, 224x224 RGB inputs,
# and a legacy weight URL (only sknet50 has one).
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch: {
        **IMAGENETTE.__dict__,
        "input_shape": (3, 224, 224),
        "url": url,
    }
    for arch, url in (
        ("sknet50", "https://github.com/frgfm/Holocron/releases/download/v0.1.3/sknet50_224-5d2160f2.pth"),
        ("sknet101", None),
        ("sknet152", None),
    )
}


class SoftAttentionLayer(nn.Sequential):
    """Channel attention head with a sigmoid-activated output.

    The layer chains a global average pooling (``flatten=False``, so spatial dims are kept), a 1x1
    ``conv_sequence`` reducing the channels to ``max(channels // sa_ratio, 32)``, and a second 1x1
    ``conv_sequence`` using ``nn.Sigmoid`` as activation that expands them to ``channels * out_multiplier``.

    Args:
        channels: number of input channels
        sa_ratio: reduction ratio of the hidden layer (the hidden width is never below 32)
        out_multiplier: multiplier applied to ``channels`` for the output width
            (SKConv2d passes its number of branches here)
        act_layer: activation module of the hidden layer
        norm_layer: normalization layer factory of the hidden layer; the hidden conv only has a bias
            when no normalization is used
        drop_layer: dropout layer factory, forwarded to both conv sequences
    """

    def __init__(
        self,
        channels: int,
        sa_ratio: int = 16,
        out_multiplier: int = 1,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        hidden_chans = max(channels // sa_ratio, 32)
        pool = GlobalAvgPool2d(flatten=False)
        # Squeeze: reduce the channel dimension (optional norm / activation / dropout)
        squeeze = conv_sequence(
            channels,
            hidden_chans,
            act_layer,
            norm_layer,
            drop_layer,
            kernel_size=1,
            stride=1,
            bias=norm_layer is None,
        )
        # Excite: expand to channels * out_multiplier with a sigmoid activation and no normalization
        excite = conv_sequence(
            hidden_chans,
            channels * out_multiplier,
            nn.Sigmoid(),
            None,
            drop_layer,
            kernel_size=1,
            stride=1,
        )
        super().__init__(pool, *squeeze, *excite)


class SKConv2d(nn.Module):
    """Selective Kernel convolution.

    Runs ``m`` parallel 3x3 convolution branches whose dilation (and padding) is 1, 2, ..., m, so every
    branch keeps the same spatial size. The branch outputs are summed and fed to a SoftAttentionLayer, whose
    output is reshaped into per-branch, per-channel logits. These are softmax-normalized across branches and
    used to take a weighted sum of the branch outputs.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels of each branch (and of the block)
        m: number of branches
        sa_ratio: reduction ratio of the attention layer
        act_layer: activation module used in the branches and in the attention layer
        norm_layer: normalization layer factory; branch convs only have a bias when it is None
        drop_layer: dropout layer factory
        kwargs: extra keyword arguments forwarded to every branch's ``conv_sequence`` (e.g. ``groups``, ``stride``)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        m: int = 2,
        sa_ratio: int = 16,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        branches = []
        for branch_idx in range(m):
            # Matching dilation and padding keeps the 3x3 output size identical across branches
            dil = branch_idx + 1
            layers = conv_sequence(
                in_channels,
                out_channels,
                act_layer,
                norm_layer,
                drop_layer,
                kernel_size=3,
                bias=(norm_layer is None),
                dilation=dil,
                padding=dil,
                **kwargs,
            )
            branches.append(nn.Sequential(*layers))
        self.path_convs = nn.ModuleList(branches)
        # The attention layer outputs m * out_channels values: one logit per (branch, channel)
        self.sa = SoftAttentionLayer(out_channels, sa_ratio, m, act_layer, norm_layer, drop_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stack branches on a new dim 1: (N, M, C, H, W) for (N, C, H, W) inputs
        branch_outputs = [conv(x) for conv in self.path_convs]
        paths = torch.stack(branch_outputs, dim=1)
        batch_size, num_paths, channels = paths.shape[:3]

        logits = self.sa(paths.sum(dim=1)).view(batch_size, num_paths, channels, 1, 1)
        # NOTE(review): self.sa's last conv_sequence uses nn.Sigmoid as activation, so these logits are
        # presumably in (0, 1), which limits how selective the softmax can be (for m=2 each weight stays
        # roughly within [0.27, 0.73]). Kept as-is since released checkpoints depend on it.
        weights = logits.softmax(dim=1)
        return (weights * paths).sum(dim=1)


class SKBottleneck(_ResBlock):
    """Bottleneck residual block whose middle convolution is a selective kernel convolution.

    Layout: 1x1 conv (inplanes -> width), SKConv2d with 2 branches (width -> width, carries the stride),
    1x1 conv (width -> planes * expansion, no activation), followed by the residual logic of ``_ResBlock``.

    Args:
        inplanes: number of input channels
        planes: base number of channels of the block
        stride: stride of the SKConv2d layer
        downsample: optional module applied to the identity path
        groups: number of convolution groups of the SKConv2d layer (also scales ``width``)
        base_width: width per group reference (``width = int(planes * base_width / 64) * groups``)
        dilation: accepted for interface compatibility but not used; SKConv2d sets its own dilations
        act_layer: activation module
        norm_layer: normalization layer factory
        drop_layer: dropout layer factory
        conv_layer: convolution layer factory for the 1x1 convolutions
        kwargs: extra keyword arguments forwarded to the 1x1 ``conv_sequence`` calls
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 32,
        base_width: int = 64,
        dilation: int = 1,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        width = int(planes * (base_width / 64.0)) * groups
        use_bias = norm_layer is None
        # 1x1 reduction
        reduce_layers = conv_sequence(
            inplanes,
            width,
            act_layer,
            norm_layer,
            drop_layer,
            conv_layer,
            kernel_size=1,
            stride=1,
            bias=use_bias,
            **kwargs,
        )
        # Selective kernel convolution (2 branches, attention reduction ratio 16)
        sk_layer = SKConv2d(width, width, 2, 16, act_layer, norm_layer, drop_layer, groups=groups, stride=stride)
        # 1x1 expansion without activation (applied after the residual addition by _ResBlock)
        expand_layers = conv_sequence(
            width,
            planes * self.expansion,
            None,
            norm_layer,
            drop_layer,
            conv_layer,
            kernel_size=1,
            stride=1,
            bias=use_bias,
            **kwargs,
        )
        super().__init__([*reduce_layers, sk_layer, *expand_layers], downsample, act_layer)


def _sknet(
    checkpoint: Union[Checkpoint, None],
    progress: bool,
    num_blocks: List[int],
    out_chans: List[int],
    **kwargs: Any,
) -> ResNet:
    """Build a ResNet made of SKBottleneck blocks and configure it with an optional checkpoint.

    Args:
        checkpoint: checkpoint to load, or None to keep the default initialization
        progress: whether to display a download progress bar
        num_blocks: number of blocks in each stage
        out_chans: base number of channels of each stage
        kwargs: extra keyword arguments forwarded to ``ResNet``

    Returns:
        the configured model
    """
    return _configure_model(
        ResNet(SKBottleneck, num_blocks, out_chans, **kwargs),  # type: ignore[arg-type]
        checkpoint,
        progress=progress,
    )


class SKNet50_Checkpoint(Enum):
    """Registered checkpoints of :func:`sknet50`."""

    # Weights trained on Imagenette (see `train_args`)
    IMAGENETTE = _checkpoint(
        arch="sknet50",
        url="https://github.com/frgfm/Holocron/releases/download/v0.2.1/sknet50_224-e2349031.pth",
        acc1=0.9437,
        acc5=0.9954,
        sha256="e2349031c838a4661cd729dbc7825605c9e0c966bd89bbcc9b39f0e324894d1f",
        size=141253623,
        num_params=35224394,
        commit="6e32c5b578711a2ef3731a8f8c61760ed9f03e58",
        train_args=(
            "./imagenette2-320/ --arch sknet50 --batch-size 64 --mixup-alpha 0.2 --amp --device 0 --epochs 100"
            " --lr 1e-3 --label-smoothing 0.1 --random-erase 0.1 --train-crop-size 176 --val-resize-size 232"
            " --opt adamw --weight-decay 5e-2"
        ),
    )
    # Same value as IMAGENETTE, so Enum registers DEFAULT as an alias of that member
    DEFAULT = IMAGENETTE
def sknet50(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """SKNet-50 from `"Selective Kernel Networks" <https://arxiv.org/pdf/1903.06586.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on Imagenette (``SKNet50_Checkpoint.DEFAULT``)
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _sknet

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.SKNet50_Checkpoint
        :members:
    """
    # Resolve the legacy `pretrained` flag against the default checkpoint
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        SKNet50_Checkpoint.DEFAULT.value,
    )
    # ResNet-50 stage layout: 3 * (3 + 4 + 6 + 3) + 2 = 50 layers
    return _sknet(checkpoint, progress, [3, 4, 6, 3], [64, 128, 256, 512], **kwargs)
def sknet101(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """SKNet-101 from `"Selective Kernel Networks" <https://arxiv.org/pdf/1903.06586.pdf>`_

    Args:
        pretrained (bool): If True, requests pre-trained weights; note that no default checkpoint is
            registered for this architecture
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _sknet

    Returns:
        torch.nn.Module: classification model
    """
    # Resolve the legacy `pretrained` flag the same way sknet50 does. The previous
    # `checkpoint if pretrained else None` silently dropped an explicitly passed checkpoint
    # whenever `pretrained` was left to its default (False).
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        None,
    )
    # ResNet-101 stage layout: 3 * (3 + 4 + 23 + 3) + 2 = 101 layers
    return _sknet(checkpoint, progress, [3, 4, 23, 3], [64, 128, 256, 512], **kwargs)
def sknet152(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """SKNet-152 from `"Selective Kernel Networks" <https://arxiv.org/pdf/1903.06586.pdf>`_

    Args:
        pretrained (bool): If True, requests pre-trained weights; note that no default checkpoint is
            registered for this architecture
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _sknet

    Returns:
        torch.nn.Module: classification model
    """
    # Resolve the legacy `pretrained` flag the same way sknet50 does. The previous
    # `checkpoint if pretrained else None` silently dropped an explicitly passed checkpoint
    # whenever `pretrained` was left to its default (False).
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        None,
    )
    # ResNet-152 stage layout: 3 * (3 + 8 + 36 + 3) + 2 = 152 layers
    # (the previous [3, 8, 86, 3] built a 302-layer network)
    return _sknet(checkpoint, progress, [3, 8, 36, 3], [64, 128, 256, 512], **kwargs)