# Copyright (C) 2020-2022, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch.nn as nn
from torch import Tensor
from holocron.nn import GlobalAvgPool2d, init
from ..presets import IMAGENET, IMAGENETTE
from ..utils import conv_sequence, load_pretrained_params
__all__ = [
"BasicBlock",
"Bottleneck",
"ResNet",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"resnet50d",
]
default_cfgs: Dict[str, Dict[str, Any]] = {
"resnet18": {**IMAGENET, "input_shape": (3, 224, 224), "url": None},
"resnet34": {**IMAGENET, "input_shape": (3, 224, 224), "url": None},
"resnet50": {
**IMAGENETTE,
"input_shape": (3, 256, 256),
"url": "https://github.com/frgfm/Holocron/releases/download/v0.1.2/resnet50_256-5e6206e0.pth",
},
"resnet101": {**IMAGENET, "input_shape": (3, 224, 224), "url": None},
"resnet152": {**IMAGENET, "input_shape": (3, 224, 224), "url": None},
"resnext50_32x4d": {**IMAGENET, "input_shape": (3, 224, 224), "url": None},
"resnext101_32x8d": {**IMAGENET, "input_shape": (3, 224, 224), "url": None},
"resnet50d": {
**IMAGENETTE,
"input_shape": (3, 224, 224),
"url": "https://github.com/frgfm/Holocron/releases/download/v0.1.3/resnet50d_224-e315ba9d.pt",
},
}
class _ResBlock(nn.Module):
expansion: int = 1
def __init__(
self, convs: List[nn.Module], downsample: Optional[nn.Module] = None, act_layer: Optional[nn.Module] = None
) -> None:
super().__init__()
# Main branch
self.conv = nn.Sequential(*convs)
# Shortcut connection
self.downsample = downsample
if isinstance(act_layer, nn.Module):
self.activation = act_layer
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv(x)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
if hasattr(self, "activation"):
out = self.activation(out)
return out
class BasicBlock(_ResBlock):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
act_layer: Optional[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
drop_layer: Optional[Callable[..., nn.Module]] = None,
conv_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> None:
super().__init__(
[
*conv_sequence(
inplanes,
planes,
act_layer,
norm_layer,
drop_layer,
conv_layer,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=(norm_layer is None),
dilation=dilation,
**kwargs,
),
*conv_sequence(
planes,
planes,
None,
norm_layer,
drop_layer,
conv_layer,
kernel_size=3,
stride=1,
padding=dilation,
groups=groups,
bias=(norm_layer is None),
dilation=dilation,
**kwargs,
),
],
downsample,
act_layer,
)
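
# Shape sketch for BasicBlock (illustrative): a stride-1 block maps a
# (N, inplanes, H, W) input to (N, planes, H, W) through two 3x3 convolutions;
# with stride=2 the spatial size is halved, so the caller must supply a matching
# `downsample` module for the shortcut addition to be well-defined.
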
class Bottleneck(_ResBlock):
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
act_layer: Optional[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
drop_layer: Optional[Callable[..., nn.Module]] = None,
conv_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> None:
width = int(planes * (base_width / 64.0)) * groups
super().__init__(
[
*conv_sequence(
inplanes,
width,
act_layer,
norm_layer,
drop_layer,
conv_layer,
kernel_size=1,
stride=1,
bias=(norm_layer is None),
**kwargs,
),
*conv_sequence(
width,
width,
act_layer,
norm_layer,
drop_layer,
conv_layer,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=(norm_layer is None),
dilation=dilation,
**kwargs,
),
*conv_sequence(
width,
planes * self.expansion,
None,
norm_layer,
drop_layer,
conv_layer,
kernel_size=1,
stride=1,
bias=(norm_layer is None),
**kwargs,
),
],
downsample,
act_layer,
)
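
# Width arithmetic for Bottleneck (illustrative): the 3x3 conv runs on
# width = int(planes * base_width / 64) * groups channels and the block outputs
# planes * expansion channels. E.g. with the resnext50_32x4d settings below
# (base_width=4, groups=32), planes=64 gives width = int(64 * 4 / 64) * 32 = 128
# and 64 * 4 = 256 output channels.
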
class ChannelRepeat(nn.Module):
    """Repeats the input tensor along its channel axis (mainly used by TridentNet).

    Args:
        chan_repeats (int): number of repetitions along the channel dimension
    """

    def __init__(self, chan_repeats: int = 1) -> None:
        super().__init__()
        self.chan_repeats = chan_repeats

    def forward(self, x: Tensor) -> Tensor:
        repeats = [1] * x.ndim
        # Repeat the tensor along the channel dimension
        repeats[1] = self.chan_repeats
        return x.repeat(*repeats)
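
# ChannelRepeat shape sketch (illustrative): with chan_repeats=3, a (1, 2, 4, 4)
# input becomes (1, 6, 4, 4), the channels being tiled as [c0, c1, c0, c1, c0, c1].
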
class ResNet(nn.Sequential):
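    """Implements a ResNet as described in
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_,
    with optional tweaks (deep stem, average-pool downsampling) from
    `"Bag of Tricks for Image Classification with Convolutional Neural Networks"
    <https://arxiv.org/pdf/1812.01187.pdf>`_.
    """
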
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
num_blocks: List[int],
planes: List[int],
num_classes: int = 10,
in_channels: int = 3,
zero_init_residual: bool = False,
width_per_group: int = 64,
conv_layer: Optional[Callable[..., nn.Module]] = None,
act_layer: Optional[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
drop_layer: Optional[Callable[..., nn.Module]] = None,
deep_stem: bool = False,
stem_pool: bool = True,
avg_downsample: bool = False,
num_repeats: int = 1,
block_args: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
) -> None:
if conv_layer is None:
conv_layer = nn.Conv2d
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if act_layer is None:
act_layer = nn.ReLU(inplace=True)
self.dilation = 1
in_planes = 64
# Deep stem from ResNet-C
if deep_stem:
_layers = [
*conv_sequence(
in_channels,
in_planes // 2,
act_layer,
norm_layer,
drop_layer,
conv_layer,
kernel_size=3,
stride=2,
padding=1,
bias=(norm_layer is None),
),
*conv_sequence(
in_planes // 2,
in_planes // 2,
act_layer,
norm_layer,
drop_layer,
conv_layer,
kernel_size=3,
stride=1,
padding=1,
bias=(norm_layer is None),
),
*conv_sequence(
in_planes // 2,
in_planes,
act_layer,
norm_layer,
drop_layer,
conv_layer,
kernel_size=3,
stride=1,
padding=1,
bias=(norm_layer is None),
),
]
else:
_layers = conv_sequence(
in_channels,
in_planes,
act_layer,
norm_layer,
drop_layer,
conv_layer,
kernel_size=7,
stride=2,
padding=3,
bias=(norm_layer is None),
)
if stem_pool:
_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# Optional tensor repetitions along channel axis (mainly for TridentNet)
if num_repeats > 1:
_layers.append(ChannelRepeat(num_repeats))
# Consecutive convolutional blocks
stride = 1
# Block args
if block_args is None:
block_args = dict(groups=1)
if not isinstance(block_args, list):
block_args = [block_args] * len(num_blocks)
for _num_blocks, _planes, _block_args in zip(num_blocks, planes, block_args):
_layers.append(
self._make_layer(
block,
_num_blocks,
in_planes,
_planes,
stride,
width_per_group,
act_layer=act_layer,
norm_layer=norm_layer,
drop_layer=drop_layer,
avg_downsample=avg_downsample,
num_repeats=num_repeats,
block_args=_block_args,
)
)
in_planes = block.expansion * _planes
stride = 2
super().__init__(
OrderedDict(
[
("features", nn.Sequential(*_layers)),
("pool", GlobalAvgPool2d(flatten=True)),
("head", nn.Linear(num_repeats * in_planes, num_classes)),
]
)
)
# Init all layers
init.init_module(self, nonlinearity="relu")
        # Init shortcut: zero the last norm layer of each residual main branch so that
        # every block initially behaves as an identity mapping (cf. "Bag of Tricks",
        # https://arxiv.org/pdf/1812.01187.pdf). The main branch is stored as the flat
        # `conv` sequence, whose final layer is the norm layer when drop_layer is None.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, _ResBlock) and isinstance(m.conv[-1], nn.BatchNorm2d):
                    m.conv[-1].weight.data.zero_()
@staticmethod
def _make_layer(
block: Type[Union[BasicBlock, Bottleneck]],
num_blocks: int,
in_planes: int,
planes: int,
stride: int = 1,
width_per_group: int = 64,
act_layer: Optional[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
drop_layer: Optional[Callable[..., nn.Module]] = None,
conv_layer: Optional[Callable[..., nn.Module]] = None,
avg_downsample: bool = False,
num_repeats: int = 1,
block_args: Optional[Dict[str, Any]] = None,
) -> nn.Sequential:
downsample = None
if stride != 1 or in_planes != planes * block.expansion:
# Downsampling from ResNet-D
if avg_downsample:
downsample = nn.Sequential(
nn.AvgPool2d(stride, ceil_mode=True, count_include_pad=False),
*conv_sequence(
num_repeats * in_planes,
num_repeats * planes * block.expansion,
None,
norm_layer,
drop_layer,
conv_layer,
kernel_size=1,
stride=1,
bias=(norm_layer is None),
),
)
else:
downsample = nn.Sequential(
*conv_sequence(
num_repeats * in_planes,
num_repeats * planes * block.expansion,
None,
norm_layer,
drop_layer,
conv_layer,
kernel_size=1,
stride=stride,
bias=(norm_layer is None),
)
)
if block_args is None:
block_args = {}
layers = [
block(
in_planes,
planes,
stride,
downsample,
base_width=width_per_group,
act_layer=act_layer,
norm_layer=norm_layer,
drop_layer=drop_layer,
**block_args,
)
]
for _ in range(num_blocks - 1):
layers.append(
block(
block.expansion * planes,
planes,
1,
None,
base_width=width_per_group,
act_layer=act_layer,
norm_layer=norm_layer,
drop_layer=drop_layer,
**block_args,
)
)
return nn.Sequential(*layers)
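
# Direct construction sketch (illustrative; mirrors the resnet18 layout defined below):
#   model = ResNet(BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512], num_classes=10)
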
def _resnet(
arch: str,
pretrained: bool,
progress: bool,
block: Type[Union[BasicBlock, Bottleneck]],
num_blocks: List[int],
out_chans: List[int],
**kwargs: Any,
) -> ResNet:
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
# Build the model
model = ResNet(block, num_blocks, out_chans, **kwargs)
model.default_cfg = default_cfgs[arch] # type: ignore[assignment]
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"], progress)
return model
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-18 from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
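
    Example::

        >>> # Minimal usage sketch (randomly initialized weights)
        >>> import torch
        >>> model = resnet18(num_classes=10)
        >>> model(torch.rand(1, 3, 224, 224)).shape
        torch.Size([1, 10])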
"""
return _resnet("resnet18", pretrained, progress, BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512], **kwargs)
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-34 from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
"""
return _resnet("resnet34", pretrained, progress, BasicBlock, [3, 4, 6, 3], [64, 128, 256, 512], **kwargs)
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-50 from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
        pretrained (bool): If True, returns a model pre-trained on Imagenette
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
"""
return _resnet("resnet50", pretrained, progress, Bottleneck, [3, 4, 6, 3], [64, 128, 256, 512], **kwargs)
def resnet50d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-50-D from
`"Bag of Tricks for Image Classification with Convolutional Neural Networks"
<https://arxiv.org/pdf/1812.01187.pdf>`_
Args:
        pretrained (bool): If True, returns a model pre-trained on Imagenette
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
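
    Example::

        >>> # Sketch showing the "-D" tweaks: deep 3x3 stem and average-pool downsampling
        >>> import torch
        >>> model = resnet50d(num_classes=10)
        >>> model(torch.rand(1, 3, 224, 224)).shape
        torch.Size([1, 10])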
"""
return _resnet(
"resnet50d",
pretrained,
progress,
Bottleneck,
[3, 4, 6, 3],
[64, 128, 256, 512],
deep_stem=True,
avg_downsample=True,
**kwargs,
)
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-101 from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
"""
return _resnet("resnet101", pretrained, progress, Bottleneck, [3, 4, 23, 3], [64, 128, 256, 512], **kwargs)
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-152 from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
"""
return _resnet("resnet152", pretrained, progress, Bottleneck, [3, 8, 86, 3], [64, 128, 256, 512], **kwargs)
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNeXt-50 from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
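
    Example::

        >>> # "32x4d" means 32 groups of width 4 in each bottleneck; a minimal sketch
        >>> import torch
        >>> model = resnext50_32x4d(num_classes=10)
        >>> model(torch.rand(1, 3, 224, 224)).shape
        torch.Size([1, 10])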
"""
kwargs["width_per_group"] = 4
block_args = dict(groups=32)
return _resnet(
"resnext50_32x4d",
pretrained,
progress,
Bottleneck,
[3, 4, 6, 3],
[64, 128, 256, 512],
block_args=block_args,
**kwargs,
)
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNeXt-101 from
`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: classification model
"""
kwargs["width_per_group"] = 8
block_args = dict(groups=32)
return _resnet(
"resnext101_32x8d",
pretrained,
progress,
Bottleneck,
[3, 4, 23, 3],
[64, 128, 256, 512],
block_args=block_args,
**kwargs,
)
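

if __name__ == "__main__":
    # Quick smoke-test sketch (illustrative, not part of the library API)
    import torch

    model = resnet18(num_classes=10)
    print(model(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 10])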