# Source code for holocron.models.classification.resnet

# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch.nn as nn
from torch import Tensor

from holocron.nn import GlobalAvgPool2d, init

from ..checkpoints import Checkpoint, _handle_legacy_pretrained
from ..presets import IMAGENET, IMAGENETTE
from ..utils import _checkpoint, _configure_model, conv_sequence

__all__ = [
    "BasicBlock",
    "Bottleneck",
    "ResNeXt50_32x4d_Checkpoint",
    "ResNet",
    "ResNet18_Checkpoint",
    "ResNet34_Checkpoint",
    "ResNet50D_Checkpoint",
    "ResNet50_Checkpoint",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet50d",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
]


def _preset_cfg(preset: Any, side: int, url: Optional[str] = None) -> Dict[str, Any]:
    """Copy the attributes of ``preset`` and attach a ``(3, side, side)`` input shape and a weights URL."""
    return dict(preset.__dict__, input_shape=(3, side, side), url=url)


# Per-architecture default configuration: the attributes of a dataset preset,
# the expected input shape and the URL of the weights (None when none is set).
default_cfgs: Dict[str, Dict[str, Any]] = {
    "resnet18": _preset_cfg(IMAGENET, 224),
    "resnet34": _preset_cfg(IMAGENET, 224),
    "resnet50": _preset_cfg(
        IMAGENETTE,
        256,
        "https://github.com/frgfm/Holocron/releases/download/v0.1.2/resnet50_256-5e6206e0.pth",
    ),
    "resnet101": _preset_cfg(IMAGENET, 224),
    "resnet152": _preset_cfg(IMAGENET, 224),
    "resnext50_32x4d": _preset_cfg(IMAGENET, 224),
    "resnext101_32x8d": _preset_cfg(IMAGENET, 224),
    "resnet50d": _preset_cfg(
        IMAGENETTE,
        224,
        "https://github.com/frgfm/Holocron/releases/download/v0.1.3/resnet50d_224-e315ba9d.pt",
    ),
}


class _ResBlock(nn.Module):
    """Generic residual block: a main branch whose output is summed with a shortcut.

    Args:
        convs: layers chained (in order) to form the main branch
        downsample: optional module applied to the input to produce the shortcut
        act_layer: optional activation module applied to the sum
    """

    # Overridden by the concrete block classes
    expansion: int = 1

    def __init__(
        self, convs: List[nn.Module], downsample: Optional[nn.Module] = None, act_layer: Optional[nn.Module] = None
    ) -> None:
        super().__init__()

        self.conv = nn.Sequential(*convs)
        self.downsample = downsample

        # The `activation` attribute only exists when an activation module was provided
        if isinstance(act_layer, nn.Module):
            self.activation = act_layer

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        # Without a downsampling module, the input itself is the shortcut
        shortcut = x if self.downsample is None else self.downsample(x)
        # In-place sum of both branches
        out += shortcut

        if not hasattr(self, "activation"):
            return out
        return self.activation(out)


class BasicBlock(_ResBlock):
    """Residual block whose main branch is two 3x3 convolution sequences.

    The first sequence carries the stride and the activation; the second one
    has stride 1 and no activation (the activation is applied after the sum).
    ``base_width`` is accepted but not used by this block.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        # Settings shared by both 3x3 convolutions
        conv_kwargs: Dict[str, Any] = dict(
            kernel_size=3,
            padding=dilation,
            groups=groups,
            bias=(norm_layer is None),
            dilation=dilation,
            **kwargs,
        )
        strided_conv = conv_sequence(
            inplanes, planes, act_layer, norm_layer, drop_layer, conv_layer, stride=stride, **conv_kwargs
        )
        closing_conv = conv_sequence(
            planes, planes, None, norm_layer, drop_layer, conv_layer, stride=1, **conv_kwargs
        )

        super().__init__([*strided_conv, *closing_conv], downsample, act_layer)


class Bottleneck(_ResBlock):
    """Residual block whose main branch is a 1x1 -> 3x3 -> 1x1 convolution stack.

    The inner width is ``int(planes * (base_width / 64.0)) * groups`` and the
    block outputs ``planes * expansion`` channels. The stride, dilation and
    groups only apply to the 3x3 convolution; the last sequence has no
    activation (the activation is applied after the sum).
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        width = int(planes * (base_width / 64.0)) * groups
        # A bias is only used when no normalization constructor is given
        use_bias = norm_layer is None

        squeeze_layers = conv_sequence(
            inplanes,
            width,
            act_layer,
            norm_layer,
            drop_layer,
            conv_layer,
            kernel_size=1,
            stride=1,
            bias=use_bias,
            **kwargs,
        )
        spatial_layers = conv_sequence(
            width,
            width,
            act_layer,
            norm_layer,
            drop_layer,
            conv_layer,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            groups=groups,
            bias=use_bias,
            dilation=dilation,
            **kwargs,
        )
        expand_layers = conv_sequence(
            width,
            planes * self.expansion,
            None,
            norm_layer,
            drop_layer,
            conv_layer,
            kernel_size=1,
            stride=1,
            bias=use_bias,
            **kwargs,
        )

        super().__init__([*squeeze_layers, *spatial_layers, *expand_layers], downsample, act_layer)


class ChannelRepeat(nn.Module):
    """Tile the input ``chan_repeats`` times along dimension 1, leaving other dimensions as they are.

    Args:
        chan_repeats: number of copies along dimension 1
    """

    def __init__(self, chan_repeats: int = 1) -> None:
        super().__init__()
        self.chan_repeats = chan_repeats

    def forward(self, x: Tensor) -> Tensor:
        # One repetition factor per dimension; only the channel axis (index 1) is tiled
        factors = [1 for _ in range(x.dim())]
        factors[1] = self.chan_repeats
        return x.repeat(*factors)


class ResNet(nn.Sequential):
    """ResNet-style classifier organised as ``features`` -> ``pool`` -> ``head``.

    Args:
        block: residual block class used in every stage
        num_blocks: number of blocks of each stage
        planes: base number of channels of each stage
        num_classes: number of output units of the linear head
        in_channels: number of channels of the input tensor
        zero_init_residual: if True, zero the scale of the normalization layer closing the main branch
            of every residual block
        width_per_group: forwarded to the blocks as ``base_width``
        conv_layer: convolution constructor used by the stem (defaults to ``nn.Conv2d``)
        act_layer: activation module (defaults to ``nn.ReLU(inplace=True)``)
        norm_layer: normalization constructor (defaults to ``nn.BatchNorm2d``)
        drop_layer: optional constructor forwarded to ``conv_sequence``
        deep_stem: if True, the stem is three 3x3 convolutions instead of a single 7x7 one
        stem_pool: if True, a max pooling layer is appended to the stem
        avg_downsample: if True, shortcuts downsample with an average pooling followed by a 1x1 convolution
        num_repeats: when greater than 1, the stem output is repeated that many times along the channel axis
        block_args: extra keyword arguments of the blocks, either one dict shared by all stages
            or one dict per stage
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        num_blocks: List[int],
        planes: List[int],
        num_classes: int = 10,
        in_channels: int = 3,
        zero_init_residual: bool = False,
        width_per_group: int = 64,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        deep_stem: bool = False,
        stem_pool: bool = True,
        avg_downsample: bool = False,
        num_repeats: int = 1,
        block_args: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
    ) -> None:
        if conv_layer is None:
            conv_layer = nn.Conv2d
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if act_layer is None:
            act_layer = nn.ReLU(inplace=True)

        in_planes = 64
        # Deep stem from ResNet-C
        if deep_stem:
            layers = [
                *conv_sequence(
                    in_channels,
                    in_planes // 2,
                    act_layer,
                    norm_layer,
                    drop_layer,
                    conv_layer,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=(norm_layer is None),
                ),
                *conv_sequence(
                    in_planes // 2,
                    in_planes // 2,
                    act_layer,
                    norm_layer,
                    drop_layer,
                    conv_layer,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=(norm_layer is None),
                ),
                *conv_sequence(
                    in_planes // 2,
                    in_planes,
                    act_layer,
                    norm_layer,
                    drop_layer,
                    conv_layer,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=(norm_layer is None),
                ),
            ]
        else:
            layers = conv_sequence(
                in_channels,
                in_planes,
                act_layer,
                norm_layer,
                drop_layer,
                conv_layer,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=(norm_layer is None),
            )
        if stem_pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Optional tensor repetitions along channel axis (mainly for TridentNet)
        if num_repeats > 1:
            layers.append(ChannelRepeat(num_repeats))

        # Consecutive convolutional blocks: the first stage keeps the resolution, the next ones use stride 2
        stride = 1
        # Block args
        if block_args is None:
            block_args = {"groups": 1}
        if not isinstance(block_args, list):
            block_args = [block_args] * len(num_blocks)
        # NOTE: `conv_layer` is only used by the stem; it is not forwarded to the stages.
        for _num_blocks, _planes, _block_args in zip(num_blocks, planes, block_args):
            layers.append(
                self._make_layer(
                    block,
                    _num_blocks,
                    in_planes,
                    _planes,
                    stride,
                    width_per_group,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_layer=drop_layer,
                    avg_downsample=avg_downsample,
                    num_repeats=num_repeats,
                    block_args=_block_args,
                )
            )
            in_planes = block.expansion * _planes
            stride = 2

        super().__init__(
            OrderedDict([
                ("features", nn.Sequential(*layers)),
                ("pool", GlobalAvgPool2d(flatten=True)),
                ("head", nn.Linear(num_repeats * in_planes, num_classes)),
            ])
        )
        # Plain attribute, assigned once the module has been initialized
        self.dilation = 1

        # Init all layers
        init.init_module(self, nonlinearity="relu")

        # Zero the scale of the normalization layer that closes each residual branch.
        # The blocks store their main branch as the flat `conv` sequential (there is no
        # `convs` attribute), so the layer is located by walking that sequential.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, (Bottleneck, BasicBlock)):
                    self._zero_init_last_norm(m)

    @staticmethod
    def _zero_init_last_norm(res_block: nn.Module) -> None:
        """Zero the scale of the normalization layer that ends the main branch of ``res_block``.

        Layers are scanned from the end of ``res_block.conv``; layers without a tensor ``weight``
        are skipped. If the first weighted layer found is not a 1-D scale (i.e. the last
        convolution is reached first), nothing is modified.
        """
        for layer in reversed(list(res_block.conv.children())):
            weight = getattr(layer, "weight", None)
            if not isinstance(weight, Tensor):
                continue
            # A 1-D weight is a per-channel scale; a higher-rank one is a convolution kernel
            if weight.ndim == 1:
                nn.init.zeros_(weight)
            return

    @staticmethod
    def _make_layer(
        block: Type[Union[BasicBlock, Bottleneck]],
        num_blocks: int,
        in_planes: int,
        planes: int,
        stride: int = 1,
        width_per_group: int = 64,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        avg_downsample: bool = False,
        num_repeats: int = 1,
        block_args: Optional[Dict[str, Any]] = None,
    ) -> nn.Sequential:
        """Build one stage made of ``num_blocks`` residual blocks.

        Only the first block receives ``stride`` and the shortcut projection; the following
        blocks use stride 1 and an identity shortcut.

        Args:
            block: residual block class
            num_blocks: number of blocks in the stage
            in_planes: number of channels entering the stage
            planes: base number of channels of the stage
            stride: stride of the first block
            width_per_group: forwarded to the blocks as ``base_width``
            act_layer: activation module forwarded to the blocks
            norm_layer: normalization constructor forwarded to the blocks and the shortcut
            drop_layer: constructor forwarded to the blocks and the shortcut
            conv_layer: convolution constructor of the shortcut projection
            avg_downsample: if True, the shortcut is an average pooling followed by a stride-1 1x1 convolution
            num_repeats: channel multiplier applied to the shortcut projection
            block_args: extra keyword arguments passed to every block

        Returns:
            the blocks of the stage chained in a ``nn.Sequential``
        """
        downsample = None
        # A projection is needed whenever the resolution or the number of channels changes
        if stride != 1 or in_planes != planes * block.expansion:
            shortcut: List[nn.Module] = []
            # Downsampling from ResNet-D: the pooling handles the stride instead of the convolution
            if avg_downsample:
                shortcut.append(nn.AvgPool2d(stride, ceil_mode=True, count_include_pad=False))
            shortcut.extend(
                conv_sequence(
                    num_repeats * in_planes,
                    num_repeats * planes * block.expansion,
                    None,
                    norm_layer,
                    drop_layer,
                    conv_layer,
                    kernel_size=1,
                    stride=1 if avg_downsample else stride,
                    bias=(norm_layer is None),
                )
            )
            downsample = nn.Sequential(*shortcut)
        if block_args is None:
            block_args = {}
        layers = [
            block(
                in_planes,
                planes,
                stride,
                downsample,
                base_width=width_per_group,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_layer=drop_layer,
                **block_args,
            )
        ]
        layers.extend([
            block(
                block.expansion * planes,
                planes,
                1,
                None,
                base_width=width_per_group,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_layer=drop_layer,
                **block_args,
            )
            for _ in range(num_blocks - 1)
        ])

        return nn.Sequential(*layers)


def _resnet(
    checkpoint: Union[Checkpoint, None],
    progress: bool,
    block: Type[Union[BasicBlock, Bottleneck]],
    num_blocks: List[int],
    out_chans: List[int],
    **kwargs: Any,
) -> ResNet:
    """Instantiate a ``ResNet`` and pass it through ``_configure_model``.

    Args:
        checkpoint: checkpoint handed to ``_configure_model`` (may be None)
        progress: forwarded to ``_configure_model`` as ``progress``
        block: residual block class
        num_blocks: number of blocks of each stage
        out_chans: base number of channels of each stage
        kwargs: extra keyword arguments of ``ResNet``

    Returns:
        the value returned by ``_configure_model``
    """
    arch = ResNet(block, num_blocks, out_chans, **kwargs)
    configured = _configure_model(arch, checkpoint, progress=progress)
    return configured


class ResNet18_Checkpoint(Enum):
    """Checkpoint metadata registered for the ``resnet18`` architecture."""

    IMAGENETTE = _checkpoint(
        arch="resnet18",
        url="https://github.com/frgfm/Holocron/releases/download/v0.2.1/resnet18_224-fc07006c.pth",
        acc1=0.9361,
        acc5=0.9946,
        sha256="fc07006c894cac8cf380fed699bc5a68463698753c954632f52bb8595040f781",
        size=44787043,
        num_params=11181642,
        commit="6e32c5b578711a2ef3731a8f8c61760ed9f03e58",
        train_args=(
            "./imagenette2-320/ --arch resnet18 --batch-size 64 --mixup-alpha 0.2 --amp --device 0 --epochs 100"
            " --lr 1e-3 --label-smoothing 0.1 --random-erase 0.1 --train-crop-size 176 --val-resize-size 232"
            " --opt adamw --weight-decay 5e-2"
        ),
    )
    # Alias of the entry used as default checkpoint
    DEFAULT = IMAGENETTE


def resnet18(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNet-18 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained: legacy flag, resolved together with ``checkpoint`` by ``_handle_legacy_pretrained``
            using ``ResNet18_Checkpoint.DEFAULT`` as the default checkpoint
        checkpoint: If specified, loads that checkpoint
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.ResNet18_Checkpoint
        :members:
    """
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        ResNet18_Checkpoint.DEFAULT.value,
    )
    return _resnet(checkpoint, progress, BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512], **kwargs)


class ResNet34_Checkpoint(Enum):
    """Checkpoint metadata registered for the ``resnet34`` architecture."""

    IMAGENETTE = _checkpoint(
        arch="resnet34",
        url="https://github.com/frgfm/Holocron/releases/download/v0.2.1/resnet34_224-412b0792.pth",
        acc1=0.9381,
        acc5=0.9949,
        sha256="412b07927cc1938ee3add8d0f6bb18b42786646182f674d75f1433d086914485",
        size=85267035,
        num_params=21289802,
        commit="6e32c5b578711a2ef3731a8f8c61760ed9f03e58",
        train_args=(
            "./imagenette2-320/ --arch resnet34 --batch-size 64 --mixup-alpha 0.2 --amp --device 0 --epochs 100"
            " --lr 1e-3 --label-smoothing 0.1 --random-erase 0.1 --train-crop-size 176 --val-resize-size 232"
            " --opt adamw --weight-decay 5e-2"
        ),
    )
    # Alias of the entry used as default checkpoint
    DEFAULT = IMAGENETTE


def resnet34(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNet-34 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained: legacy flag, resolved together with ``checkpoint`` by ``_handle_legacy_pretrained``
            using ``ResNet34_Checkpoint.DEFAULT`` as the default checkpoint
        checkpoint: If specified, load that checkpoint on the model
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.ResNet34_Checkpoint
        :members:
    """
    # Same legacy handling as the other builders that ship a checkpoint enum,
    # so that `pretrained` is no longer silently ignored
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        ResNet34_Checkpoint.DEFAULT.value,
    )
    return _resnet(checkpoint, progress, BasicBlock, [3, 4, 6, 3], [64, 128, 256, 512], **kwargs)


class ResNet50_Checkpoint(Enum):
    """Checkpoint metadata registered for the ``resnet50`` architecture."""

    IMAGENETTE = _checkpoint(
        arch="resnet50",
        url="https://github.com/frgfm/Holocron/releases/download/v0.2.1/resnet50_224-5b913f0b.pth",
        acc1=0.9378,
        acc5=0.9954,
        sha256="5b913f0b8148b483ba15541ab600cf354ca42b326e4896c4c3dbc51eb1e80e70",
        size=94384682,
        num_params=23528522,
        commit="6e32c5b578711a2ef3731a8f8c61760ed9f03e58",
        train_args=(
            "./imagenette2-320/ --arch resnet50 --batch-size 64 --mixup-alpha 0.2 --amp --device 0 --epochs 100"
            " --lr 1e-3 --label-smoothing 0.1 --random-erase 0.1 --train-crop-size 176 --val-resize-size 232"
            " --opt adamw --weight-decay 5e-2"
        ),
    )
    # Alias of the entry used as default checkpoint
    DEFAULT = IMAGENETTE


def resnet50(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNet-50 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained: legacy flag, resolved together with ``checkpoint`` by ``_handle_legacy_pretrained``
            using ``ResNet50_Checkpoint.DEFAULT`` as the default checkpoint
        checkpoint: If specified, load that checkpoint on the model
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.ResNet50_Checkpoint
        :members:
    """
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        ResNet50_Checkpoint.DEFAULT.value,
    )
    return _resnet(checkpoint, progress, Bottleneck, [3, 4, 6, 3], [64, 128, 256, 512], **kwargs)


class ResNet50D_Checkpoint(Enum):
    """Checkpoint metadata registered for the ``resnet50d`` architecture."""

    IMAGENETTE = _checkpoint(
        arch="resnet50d",
        url="https://github.com/frgfm/Holocron/releases/download/v0.2.1/resnet50d_224-6218d936.pth",
        acc1=0.9465,
        acc5=0.9952,
        sha256="6218d936fa67c0047f1ec65564213db538aa826d84f2df1d4fa3224531376e6c",
        size=94464810,
        num_params=23547754,
        commit="6e32c5b578711a2ef3731a8f8c61760ed9f03e58",
        train_args=(
            "./imagenette2-320/ --arch resnet50d --batch-size 64 --mixup-alpha 0.2 --amp --device 0 --epochs 100"
            " --lr 1e-3 --label-smoothing 0.1 --random-erase 0.1 --train-crop-size 176 --val-resize-size 232"
            " --opt adamw --weight-decay 5e-2"
        ),
    )
    # Alias of the entry used as default checkpoint
    DEFAULT = IMAGENETTE


def resnet50d(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNet-50-D from
    `"Bag of Tricks for Image Classification with Convolutional Neural Networks"
    <https://arxiv.org/pdf/1812.01187.pdf>`_

    Args:
        pretrained: legacy flag, resolved together with ``checkpoint`` by ``_handle_legacy_pretrained``
            using ``ResNet50D_Checkpoint.DEFAULT`` as the default checkpoint
        checkpoint: If specified, load that checkpoint on the model
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.ResNet50D_Checkpoint
        :members:
    """
    # Same legacy handling as the other builders that ship a checkpoint enum,
    # so that `pretrained` is no longer silently ignored
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        ResNet50D_Checkpoint.DEFAULT.value,
    )
    return _resnet(
        checkpoint,
        progress,
        Bottleneck,
        [3, 4, 6, 3],
        [64, 128, 256, 512],
        deep_stem=True,
        avg_downsample=True,
        **kwargs,
    )


def resnet101(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNet-101 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained: accepted for signature parity with the other builders; this function does not read it
        checkpoint: If specified, load that checkpoint on the model
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet

    Returns:
        torch.nn.Module: classification model
    """
    return _resnet(checkpoint, progress, Bottleneck, [3, 4, 23, 3], [64, 128, 256, 512], **kwargs)


def resnet152(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNet-152 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained: accepted for signature parity with the other builders; this function does not read it
        checkpoint: If specified, load that checkpoint on the model
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet

    Returns:
        torch.nn.Module: classification model
    """
    # ResNet-152 stacks (3, 8, 36, 3) bottlenecks: 3 * (3 + 8 + 36 + 3) convolutions + stem + head = 152 layers
    return _resnet(checkpoint, progress, Bottleneck, [3, 8, 36, 3], [64, 128, 256, 512], **kwargs)


class ResNeXt50_32x4d_Checkpoint(Enum):
    """Checkpoint metadata registered for the ``resnext50_32x4d`` architecture."""

    IMAGENETTE = _checkpoint(
        arch="resnext50_32x4d",
        url="https://github.com/frgfm/Holocron/releases/download/v0.2.1/resnext50_32x4d_224-5832c4ce.pth",
        acc1=0.9455,
        acc5=0.9949,
        sha256="5832c4ce33522a9eb7a8b5abe31cf30621721a92d4f99b4b332a007d81d071fe",
        size=92332638,
        num_params=23000394,
        commit="6e32c5b578711a2ef3731a8f8c61760ed9f03e58",
        train_args=(
            "./imagenette2-320/ --arch resnext50_32x4d --batch-size 64 --mixup-alpha 0.2 --amp --device 0 --epochs 100"
            " --lr 1e-3 --label-smoothing 0.1 --random-erase 0.1 --train-crop-size 176 --val-resize-size 232"
            " --opt adamw --weight-decay 5e-2"
        ),
    )
    # Alias of the entry used as default checkpoint
    DEFAULT = IMAGENETTE


def resnext50_32x4d(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNeXt-50 from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained: legacy flag, resolved together with ``checkpoint`` by ``_handle_legacy_pretrained``
            using ``ResNeXt50_32x4d_Checkpoint.DEFAULT`` as the default checkpoint
        checkpoint: If specified, load that checkpoint on the model
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet (``width_per_group`` is overwritten with 4)

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.ResNeXt50_32x4d_Checkpoint
        :members:
    """
    # Same legacy handling as the other builders that ship a checkpoint enum,
    # so that `pretrained` is no longer silently ignored
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        ResNeXt50_32x4d_Checkpoint.DEFAULT.value,
    )
    # 32 groups of width 4
    kwargs["width_per_group"] = 4
    block_args = {"groups": 32}
    return _resnet(
        checkpoint,
        progress,
        Bottleneck,
        [3, 4, 6, 3],
        [64, 128, 256, 512],
        block_args=block_args,
        **kwargs,
    )


def resnext101_32x8d(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ResNet:
    """ResNeXt-101 from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained: accepted for signature parity with the other builders; this function does not read it
        checkpoint: If specified, load that checkpoint on the model
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _resnet (``width_per_group`` is overwritten with 8)

    Returns:
        torch.nn.Module: classification model
    """
    # 32 groups of width 8
    kwargs["width_per_group"] = 8
    block_args = {"groups": 32}
    return _resnet(
        checkpoint,
        progress,
        Bottleneck,
        [3, 4, 23, 3],
        [64, 128, 256, 512],
        block_args=block_args,
        **kwargs,
    )