Source code for holocron.models.classification.repvgg

# Copyright (C) 2021-2022, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn

from holocron.nn import GlobalAvgPool2d

from ..presets import IMAGENETTE
from ..utils import fuse_conv_bn, load_pretrained_params

# Public API of this module; each exported name is listed exactly once.
__all__ = [
    "RepVGG",
    "RepBlock",
    "repvgg_a0",
    "repvgg_a1",
    "repvgg_a2",
    "repvgg_b0",
    "repvgg_b1",
    "repvgg_b2",
    "repvgg_b3",
]

# Default configuration of each architecture: the shared IMAGENETTE preset, a common
# input shape, and the checkpoint URL (None when no checkpoint is registered).
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch: {
        **IMAGENETTE,
        "input_shape": (3, 224, 224),
        "url": url,
    }
    for arch, url in (
        ("repvgg_a0", "https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_a0_224-150f4b9d.pt"),
        ("repvgg_a1", "https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_a1_224-870b9e4b.pt"),
        ("repvgg_a2", "https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_a2_224-7051289a.pt"),
        ("repvgg_b0", "https://github.com/frgfm/Holocron/releases/download/v0.1.3/repvgg_b0_224-7e9c3fc7.pth"),
        ("repvgg_b1", None),
        ("repvgg_b2", None),
        ("repvgg_b3", None),
    )
}


class RepBlock(nn.Module):
    """Multi-branch convolutional block that can be collapsed into a single 3x3 convolution.

    Until :meth:`reparametrize` is called, ``branches`` is a ``nn.ModuleList`` holding a
    3x3 convolution + normalization branch, a 1x1 convolution + normalization branch and,
    when ``identity`` is True, a ``nn.BatchNorm2d`` applied directly to the input. The forward
    pass sums the outputs of all branches and applies the activation. After
    :meth:`reparametrize`, ``branches`` is a single ``nn.Conv2d``.

    Args:
        inplanes: number of input channels
        planes: number of output channels
        stride: stride of the two convolution branches
        identity: whether to add the normalization-only identity branch
        act_layer: activation module applied to the summed branches (defaults to ``nn.ReLU(inplace=True)``)
        norm_layer: callable building a normalization layer from a channel count (defaults to ``nn.BatchNorm2d``)

    Raises:
        ValueError: if ``identity`` is True while ``inplanes`` and ``planes`` differ
    """

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        identity: bool = True,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
    ) -> None:
        super().__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if act_layer is None:
            act_layer = nn.ReLU(inplace=True)

        # A normalization layer always follows each convolution (norm_layer is never None here),
        # so the convolutions are created without a bias term.
        self.branches: Union[nn.Conv2d, nn.ModuleList] = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(inplanes, planes, 3, padding=1, bias=False, stride=stride),
                    norm_layer(planes),
                ),
                nn.Sequential(
                    nn.Conv2d(inplanes, planes, 1, padding=0, bias=False, stride=stride),
                    norm_layer(planes),
                ),
            ]
        )

        self.activation = act_layer

        if identity:
            if inplanes != planes:
                raise ValueError("The number of input and output channels must be identical if identity is used")
            self.branches.append(nn.BatchNorm2d(planes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the fused convolution, or the sum of all branches, followed by the activation."""

        if isinstance(self.branches, nn.Conv2d):
            # Reparametrized block: a single convolution replaces all branches
            out = self.branches(x)
        else:
            out = sum(branch(x) for branch in self.branches)

        return self.activation(out)

    def reparametrize(self) -> None:
        """Reparametrize the block by fusing convolutions and BN in each branch, then fusing all branches

        Raises:
            AssertionError: if ``branches`` is no longer a ``nn.ModuleList`` (e.g. the block was already
                reparametrized)
        """
        if not isinstance(self.branches, nn.ModuleList):
            raise AssertionError(
                "expected `branches` to be a `nn.ModuleList`; the block may already have been reparametrized"
            )
        conv3 = self.branches[0][0]
        inplanes = conv3.weight.data.shape[1]
        planes = conv3.weight.data.shape[0]
        # Instantiate the equivalent Conv 3x3
        rep = nn.Conv2d(inplanes, planes, 3, padding=1, bias=True, stride=conv3.stride)

        # Fuse convolutions with their BN
        fused_k3, fused_b3 = fuse_conv_bn(*self.branches[0])
        fused_k1, fused_b1 = fuse_conv_bn(*self.branches[1])

        # Conv 3x3
        rep.weight.data = fused_k3
        rep.bias.data = fused_b3  # type: ignore[union-attr]

        # Conv 1x1: its single tap is added to the center of the 3x3 kernel
        rep.weight.data[..., 1:2, 1:2] += fused_k1
        rep.bias.data += fused_b1  # type: ignore[union-attr]

        # Identity
        if len(self.branches) == 3:
            bn = self.branches[2]
            # Per-channel affine transform of the BN, computed from its running statistics
            scale_factor = bn.weight.data / (bn.running_var + bn.eps).sqrt()
            # Identity is mapped as a diagonal matrix relatively to the out/in channel dimensions
            rep.weight.data[range(planes), range(inplanes), 1, 1] += scale_factor
            rep.bias.data += bn.bias.data  # type: ignore[union-attr]
            rep.bias.data -= scale_factor * bn.running_mean  # type: ignore[union-attr]

        # Update main branch & delete the others
        self.branches = rep


class RepVGG(nn.Sequential):
    """Implements a reparametrized version of VGG as described in
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    The model is a sequence of three named modules: ``features`` (one ``nn.Sequential`` of
    ``RepBlock`` per stage), ``pool`` (global average pooling with flattening) and ``head``
    (linear classifier).

    Args:
        num_blocks: list of number of blocks per stage
        planes: list of output channels of each stage
        width_multiplier: multiplier for the output channels of all stages apart from the last
            (capped at 1 for the first stage)
        final_width_multiplier: multiplier for the output channels of the last stage
        num_classes: number of output classes
        in_channels: number of input channels
        act_layer: the activation layer to use
        norm_layer: the normalization layer to use
    """

    def __init__(
        self,
        num_blocks: List[int],
        planes: List[int],
        width_multiplier: float,
        final_width_multiplier: float,
        num_classes: int = 10,
        in_channels: int = 3,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
    ) -> None:

        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        act_layer = nn.ReLU(inplace=True) if act_layer is None else act_layer

        if len(num_blocks) != len(planes):
            raise AssertionError("the length of `num_blocks` and `planes` are expected to be the same")

        # Channel count at every stage boundary, starting with the input channels
        first_width = int(min(1, width_multiplier) * planes[0])
        middle_widths = [int(width_multiplier * width) for width in planes[1:-1]]
        last_width = int(final_width_multiplier * planes[-1])
        widths = [in_channels, first_width, *middle_widths, last_width]

        # Each stage opens with a strided block without identity branch,
        # followed by `depth` stride-1 blocks with an identity branch
        stages: List[nn.Sequential] = []
        for depth, width_in, width_out in zip(num_blocks, widths[:-1], widths[1:]):
            blocks = [RepBlock(width_in, width_out, 2, False, act_layer, norm_layer)]
            for _ in range(depth):
                blocks.append(RepBlock(width_out, width_out, 1, True, act_layer, norm_layer))
            stages.append(nn.Sequential(*blocks))

        named_modules: "OrderedDict[str, nn.Module]" = OrderedDict()
        named_modules["features"] = nn.Sequential(*stages)
        named_modules["pool"] = GlobalAvgPool2d(flatten=True)
        named_modules["head"] = nn.Linear(widths[-1], num_classes)
        super().__init__(named_modules)

    def reparametrize(self) -> None:
        """Reparametrize the model by calling ``reparametrize`` on every block of every stage of ``features``"""
        self.features: nn.Sequential
        for current_stage in self.features:
            for rep_block in current_stage:
                rep_block.reparametrize()


def _repvgg(
    arch: str,
    pretrained: bool,
    progress: bool,
    num_blocks: List[int],
    out_chans: List[int],
    a: float,
    b: float,
    **kwargs: Any,
) -> RepVGG:
    """Instantiate a RepVGG model and optionally load its pretrained parameters.

    Args:
        arch: key of the architecture in ``default_cfgs``
        pretrained: whether to call ``load_pretrained_params`` with the URL registered for ``arch``
        progress: forwarded to ``load_pretrained_params``
        num_blocks: number of blocks per stage
        out_chans: base number of output channels of each stage (before width multipliers)
        a: width multiplier of all stages apart from the last
        b: width multiplier of the last stage
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        the model, with its ``default_cfg`` attribute set to ``default_cfgs[arch]``

    Raises:
        KeyError: if ``arch`` is not a key of ``default_cfgs``
    """

    # Build the model (honor the requested stage widths rather than a hard-coded list)
    model = RepVGG(num_blocks, out_chans, a, b, **kwargs)
    model.default_cfg = default_cfgs[arch]  # type: ignore[assignment]
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, default_cfgs[arch]["url"], progress)

    return model


def repvgg_a0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-A0 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Args:
        pretrained (bool): If True, loads pretrained parameters from the URL registered for
            ``repvgg_a0`` in ``default_cfgs``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        torch.nn.Module: classification model
    """

    return _repvgg("repvgg_a0", pretrained, progress, [1, 2, 4, 14, 1], [64, 64, 128, 256, 512], 0.75, 2.5, **kwargs)
def repvgg_a1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-A1 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Args:
        pretrained (bool): If True, loads pretrained parameters from the URL registered for
            ``repvgg_a1`` in ``default_cfgs``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        torch.nn.Module: classification model
    """

    return _repvgg("repvgg_a1", pretrained, progress, [1, 2, 4, 14, 1], [64, 64, 128, 256, 512], 1, 2.5, **kwargs)
def repvgg_a2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-A2 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Args:
        pretrained (bool): If True, loads pretrained parameters from the URL registered for
            ``repvgg_a2`` in ``default_cfgs``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        torch.nn.Module: classification model
    """

    return _repvgg("repvgg_a2", pretrained, progress, [1, 2, 4, 14, 1], [64, 64, 128, 256, 512], 1.5, 2.75, **kwargs)
def repvgg_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B0 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Args:
        pretrained (bool): If True, loads pretrained parameters from the URL registered for
            ``repvgg_b0`` in ``default_cfgs``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        torch.nn.Module: classification model
    """

    return _repvgg("repvgg_b0", pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 1, 2.5, **kwargs)
def repvgg_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B1 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Args:
        pretrained (bool): If True, requests pretrained parameters from the URL registered for
            ``repvgg_b1`` in ``default_cfgs`` (currently ``None``, i.e. no checkpoint is registered)
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        torch.nn.Module: classification model
    """

    return _repvgg("repvgg_b1", pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 2, 4, **kwargs)
def repvgg_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B2 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Args:
        pretrained (bool): If True, requests pretrained parameters from the URL registered for
            ``repvgg_b2`` in ``default_cfgs`` (currently ``None``, i.e. no checkpoint is registered)
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        torch.nn.Module: classification model
    """

    return _repvgg("repvgg_b2", pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 2.5, 5, **kwargs)
def repvgg_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """RepVGG-B3 from
    `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/pdf/2101.03697.pdf>`_

    Args:
        pretrained (bool): If True, requests pretrained parameters from the URL registered for
            ``repvgg_b3`` in ``default_cfgs`` (currently ``None``, i.e. no checkpoint is registered)
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to the ``RepVGG`` constructor

    Returns:
        torch.nn.Module: classification model
    """

    return _repvgg("repvgg_b3", pretrained, progress, [1, 4, 6, 16, 1], [64, 64, 128, 256, 512], 3, 5, **kwargs)