# Source code for holocron.models.classification.convnext

# Copyright (C) 2022-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from collections import OrderedDict
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.ops.stochastic_depth import StochasticDepth

from holocron.nn import GlobalAvgPool2d

from ..checkpoints import Checkpoint, _handle_legacy_pretrained
from ..utils import _checkpoint, _configure_model, conv_sequence
from .resnet import _ResBlock

__all__ = [
    "ConvNeXt",
    "ConvNeXt_Atto_Checkpoint",
    "convnext_atto",
    "convnext_base",
    "convnext_femto",
    "convnext_large",
    "convnext_nano",
    "convnext_pico",
    "convnext_small",
    "convnext_tiny",
    "convnext_xl",
]


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm for NCHW tensors: normalizes over the channel dimension by
    temporarily moving it to the last axis, where nn.LayerNorm operates."""

    def forward(self, x: Tensor) -> Tensor:
        # nn.LayerNorm normalizes trailing dims, so go NCHW -> NHWC and back
        out = super().forward(x.permute(0, 2, 3, 1))
        return out.permute(0, 3, 1, 2)


class LayerScale(nn.Module):
    """Learnable per-channel scaling of a feature map.

    Args:
        chans: number of channels (dim 1 of the input)
        scale: initial value of every scaling factor
    """

    def __init__(self, chans: int, scale: float = 1e-6) -> None:
        super().__init__()
        # One learnable factor per channel, initialized to a small constant
        self.weight = nn.Parameter(torch.full((chans,), scale))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the (chans,) weight over every non-channel dimension
        bcast_shape = (1, -1) + (1,) * (x.ndim - 2)
        return x * self.weight.view(bcast_shape)


class Bottlenext(_ResBlock):
    """ConvNeXt residual block: a 7x7 depthwise conv for spatial mixing,
    followed by an inverted-bottleneck pointwise MLP (1x1 expand, activation,
    1x1 project), with learnable per-channel scaling (LayerScale) and
    stochastic depth applied to the residual branch.

    Args:
        inplanes: number of input (and output) channels
        act_layer: activation inserted after the expansion conv (defaults to GELU)
        norm_layer: normalization constructor applied after the depthwise conv
            (defaults to LayerNorm2d with eps=1e-6)
        drop_layer: optional dropout constructor forwarded to each conv_sequence
        chan_expansion: hidden-width multiplier of the pointwise MLP
        stochastic_depth_prob: probability of dropping the residual branch ("row" mode)
        layer_scale: initial value of the LayerScale weights
    """

    def __init__(
        self,
        inplanes: int,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        chan_expansion: int = 4,
        stochastic_depth_prob: float = 0.1,
        layer_scale: float = 1e-6,
    ) -> None:
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
        if act_layer is None:
            # GELU holds no parameters, so sharing one instance is safe
            act_layer = nn.GELU()

        super().__init__(
            [
                # Depth-conv (groups = in_channels): spatial awareness
                *conv_sequence(
                    inplanes,
                    inplanes,
                    None,  # no activation on the depthwise conv, norm only
                    norm_layer,
                    drop_layer,
                    kernel_size=7,
                    padding=3,
                    stride=1,
                    bias=True,
                    groups=inplanes,
                ),
                # 1x1 conv: channel awareness (expand by chan_expansion, then activate)
                *conv_sequence(
                    inplanes,
                    inplanes * chan_expansion,
                    act_layer,
                    None,  # no normalization inside the pointwise MLP
                    drop_layer,
                    kernel_size=1,
                    stride=1,
                    bias=True,
                ),
                # 1x1 conv: channel mapping (project back down to inplanes)
                *conv_sequence(
                    inplanes * chan_expansion,
                    inplanes,
                    None,
                    None,
                    drop_layer,
                    kernel_size=1,
                    stride=1,
                    bias=True,
                ),
                # Learnable per-channel scaling of the residual branch
                LayerScale(inplanes, layer_scale),
                # Randomly drops the whole branch per sample during training
                StochasticDepth(stochastic_depth_prob, "row"),
            ],
            # NOTE(review): the trailing (None, None) presumably disable the
            # downsample path and the post-sum activation in _ResBlock — confirm
            None,
            None,
        )


class ConvNeXt(nn.Sequential):
    """ConvNeXt image classifier: a patchify stem, stages of Bottlenext blocks
    separated by strided downsampling convs, then global pooling and a linear head.

    Args:
        num_blocks: number of Bottlenext blocks in each stage
        planes: channel width of each stage
        num_classes: number of output classes
        in_channels: number of input image channels
        conv_layer: convolution constructor used in the stem (defaults to nn.Conv2d)
        act_layer: activation used inside the blocks (defaults to GELU)
        norm_layer: normalization constructor (defaults to LayerNorm2d with eps=1e-6)
        drop_layer: optional dropout constructor forwarded to the blocks
        stochastic_depth_prob: maximum drop probability, ramped linearly over depth
    """

    def __init__(
        self,
        num_blocks: List[int],
        planes: List[int],
        num_classes: int = 10,
        in_channels: int = 3,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        stochastic_depth_prob: float = 0.0,
    ) -> None:
        if conv_layer is None:
            conv_layer = nn.Conv2d
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
        if act_layer is None:
            act_layer = nn.GELU()
        self.dilation = 1

        # Patchify-like stem: non-overlapping 4x4 strided conv
        layers = conv_sequence(
            in_channels,
            planes[0],
            None,
            norm_layer,
            drop_layer,
            conv_layer,
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
        )

        total_depth = sum(num_blocks)
        block_idx = 0
        # The last stage keeps its width, so pair each stage width with the next one
        out_widths = planes[1:] + [planes[-1]]
        for stage_depth, width, out_width in zip(num_blocks, planes, out_widths):
            stage: List[nn.Module] = []
            for local_idx in range(stage_depth):
                # Ramp the drop probability linearly with overall block depth
                sd_prob = stochastic_depth_prob * (block_idx + local_idx) / (total_depth - 1.0)
                stage.append(Bottlenext(width, act_layer, norm_layer, drop_layer, stochastic_depth_prob=sd_prob))
            if width != out_width:
                # Inter-stage downsampling: norm followed by a 2x2 strided conv
                stage.append(
                    nn.Sequential(
                        LayerNorm2d(width),
                        nn.Conv2d(width, out_width, kernel_size=2, stride=2),
                    )
                )
            layers.append(nn.Sequential(*stage))
            block_idx += stage_depth

        super().__init__(
            OrderedDict([
                ("features", nn.Sequential(*layers)),
                ("pool", GlobalAvgPool2d(flatten=True)),
                (
                    "head",
                    nn.Sequential(
                        nn.LayerNorm(planes[-1], eps=1e-6),
                        nn.Linear(planes[-1], num_classes),
                    ),
                ),
            ])
        )

        # Weight init: truncated normal for conv/linear weights, zeroed biases
        for mod in self.modules():
            if isinstance(mod, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(mod.weight, std=0.02)
                if mod.bias is not None:
                    nn.init.zeros_(mod.bias)


def _convnext(
    checkpoint: Union[Checkpoint, None],
    progress: bool,
    num_blocks: List[int],
    out_chans: List[int],
    **kwargs: Any,
) -> ConvNeXt:
    """Build a ConvNeXt and optionally load checkpoint weights into it.

    Args:
        checkpoint: checkpoint whose weights should be loaded, if any
        progress: whether to display a download progress bar
        num_blocks: number of blocks per stage
        out_chans: channel width per stage
        kwargs: forwarded to the ConvNeXt constructor

    Returns:
        the configured classification model
    """
    # _configure_model takes care of weight loading and download progress
    return _configure_model(ConvNeXt(num_blocks, out_chans, **kwargs), checkpoint, progress=progress)


class ConvNeXt_Atto_Checkpoint(Enum):
    """Available checkpoints for the ConvNeXt-Atto architecture."""

    # Trained on ImageNette; see `train_args` for the exact training recipe
    IMAGENETTE = _checkpoint(
        arch="convnext_atto",
        url="https://github.com/frgfm/Holocron/releases/download/v0.2.1/convnext_atto_224-f38217e7.pth",
        acc1=0.8759,
        acc5=0.9832,
        sha256="f38217e7361060e6fe00e8fa95b0e8774150190eed9e55c812bbd3b6ab378ce9",
        size=13535258,
        num_params=3377730,
        commit="d4a59999179b42fc0d3058ac6b76cc41f49dd56e",
        train_args=(
            "./imagenette2-320/ --arch convnext_atto --batch-size 64 --mixup-alpha 0.2 --amp --device 0 --epochs 100"
            " --lr 1e-3 --label-smoothing 0.1 --random-erase 0.1 --train-crop-size 176 --val-resize-size 232"
            " --opt adamw --weight-decay 5e-2"
        ),
    )
    # Alias pointing at the recommended checkpoint
    DEFAULT = IMAGENETTE
def convnext_atto(
    pretrained: bool = False,
    checkpoint: Union[Checkpoint, None] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """ConvNeXt-Atto variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained: If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress: If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model

    .. autoclass:: holocron.models.ConvNeXt_Atto_Checkpoint
        :members:
    """
    checkpoint = _handle_legacy_pretrained(
        pretrained,
        checkpoint,
        ConvNeXt_Atto_Checkpoint.DEFAULT.value,
    )
    return _convnext(checkpoint, progress, [2, 2, 6, 2], [40, 80, 160, 320], **kwargs)
def convnext_femto(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-Femto variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [2, 2, 6, 2], [48, 96, 192, 384], **kwargs)
def convnext_pico(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-Pico variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [2, 2, 6, 2], [64, 128, 256, 512], **kwargs)
def convnext_nano(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-Nano variant of Ross Wightman inspired by
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [2, 2, 8, 2], [80, 160, 320, 640], **kwargs)
def convnext_tiny(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-T from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [3, 3, 9, 3], [96, 192, 384, 768], **kwargs)
def convnext_small(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-S from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [3, 3, 27, 3], [96, 192, 384, 768], **kwargs)
def convnext_base(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-B from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [3, 3, 27, 3], [128, 256, 512, 1024], **kwargs)
def convnext_large(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-L from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [3, 3, 27, 3], [192, 384, 768, 1536], **kwargs)
def convnext_xl(
    pretrained: bool = False, checkpoint: Union[Checkpoint, None] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt-XL from
    `"A ConvNet for the 2020s" <https://arxiv.org/pdf/2201.03545.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNette
        checkpoint: If specified, the model's parameters will be set to the checkpoint's values
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs: keyword args of _convnext

    Returns:
        torch.nn.Module: classification model
    """
    # No released checkpoint for this variant yet
    checkpoint = _handle_legacy_pretrained(pretrained, checkpoint, None)
    return _convnext(checkpoint, progress, [3, 3, 27, 3], [256, 512, 1024, 2048], **kwargs)