# -*- coding: utf-8 -*-
"""
Implementation of Res2Net
based on https://github.com/gasvn/Res2Net
"""
import warnings
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, conv1x1, conv3x3
try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:  # helper moved out of torchvision.models.utils in recent torchvision releases
    from torch.hub import load_state_dict_from_url
RESNET_LAYERS = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3],
101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
RES2NEXT_PARAMS = {50: dict(groups=8, width_per_group=4),
101: dict(groups=8, width_per_group=8)}
URLS = {
'res2net50_48w_2s': 'http://mc.nankai.edu.cn/projects/res2net/pretrainmodels/res2net50_48w_2s-afed724a.pth',
'res2net50_26w_4s': 'http://mc.nankai.edu.cn/projects/res2net/pretrainmodels/res2net50_26w_4s-06e79181.pth',
'res2net50_14w_8s': 'http://mc.nankai.edu.cn/projects/res2net/pretrainmodels/res2net50_14w_8s-6527dddc.pth',
'res2net50_26w_6s': 'http://mc.nankai.edu.cn/projects/res2net/pretrainmodels/res2net50_26w_6s-19041792.pth',
'res2net50_26w_8s': 'http://mc.nankai.edu.cn/projects/res2net/pretrainmodels/res2net50_26w_8s-2c7c9f12.pth',
'res2net101_26w_4s': 'http://mc.nankai.edu.cn/projects/res2net/pretrainmodels/res2net101_26w_4s-02a759a1.pth',
'res2next50_4w_4s_8c': 'http://mc.nankai.edu.cn/projects/res2net/pretrainmodels/res2next50_4s-6ef7e7bf.pth'
}
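# Checkpoint keys follow the naming scheme res2net{depth}_{width_per_group}w_{scale}s
# (with a trailing _{groups}c for the Res2NeXt variant), matching the f-strings used below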
class Res2Block(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=4, dilation=1, scale=4, first_block=False, norm_layer=None):
"""Implements a residual block
Args:
inplanes (int): input channel dimensionality
            planes (int): base number of output channels (the block outputs planes * expansion channels)
stride (int): stride used for conv3x3
downsample (torch.nn.Module): module used for downsampling
            groups (int): number of convolution groups
            base_width (int): number of channels per convolution group
dilation (int): dilation rate of conv3x3
scale (int): scaling ratio for cascade convs
            first_block (bool): whether this is the first block of its stage (its identity chunk is average-pooled to match the stride)
norm_layer (torch.nn.Module): norm layer to be used in blocks
"""
super(Res2Block, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
self.conv1 = conv1x1(inplanes, width * scale)
self.bn1 = norm_layer(width * scale)
        # scale == 1: a single 3x3 conv; scale > 1: (scale - 1) cascaded convs plus an identity/pooled branch
nb_branches = max(scale, 2) - 1
if first_block:
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
self.convs = nn.ModuleList([conv3x3(width, width, stride, groups, dilation)
for _ in range(nb_branches)])
self.bns = nn.ModuleList([norm_layer(width) for _ in range(nb_branches)])
self.first_block = first_block
self.scale = scale
self.conv3 = conv1x1(width * scale, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=False)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
# Chunk the feature map
xs = torch.chunk(out, self.scale, dim=1)
        # Running branch output; starting at 0 lets the first chunk be added directly
y = 0
for idx, conv in enumerate(self.convs):
            # First block of a stage: each branch only sees its own chunk; otherwise add it to the previous branch output
if self.first_block:
y = xs[idx]
else:
y += xs[idx]
y = conv(y)
y = self.relu(self.bns[idx](y))
            # Stack branch outputs along the channel dimension
out = torch.cat((out, y), 1) if idx > 0 else y
        # Append the last chunk: identity for regular blocks, average-pooled in the first block of a stage
if self.scale > 1:
if self.first_block:
out = torch.cat((out, self.pool(xs[len(self.convs)])), 1)
else:
out = torch.cat((out, xs[len(self.convs)]), 1)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
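
# Minimal shape-check sketch for Res2Block (assuming the default base_width=4 and scale=4):
# with inplanes == planes * expansion and stride == 1, the block preserves both the channel
# count and the spatial resolution, so no downsample module is required.
#
#     block = Res2Block(256, 64, scale=4)   # later block of a stage (first_block=False)
#     x = torch.rand(2, 256, 56, 56)
#     block(x).shape                        # torch.Size([2, 256, 56, 56])
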
class Res2Net(nn.Module):
"""Implements a Res2Net model as described in https://arxiv.org/pdf/1904.01169.pdf
Args:
block (torch.nn.Module): class constructor to be used for residual blocks
        layers (list<int>): number of residual blocks in each of the four stages
num_classes (int): number of output classes
        zero_init_residual (bool): whether the last batch norm of each residual branch should be zero-initialized
groups (int): number of convolution groups
width_per_group (int): number of channels per group
scale (int): scaling ratio within blocks
replace_stride_with_dilation (list<bool>): whether stride should be traded for dilation
norm_layer (torch.nn.Module): norm layer to be used
"""
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=26, scale=4, replace_stride_with_dilation=None,
norm_layer=None):
super(Res2Net, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.scale = scale
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
                if isinstance(m, Res2Block):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, self.scale, first_block=True, norm_layer=norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
scale=self.scale, first_block=False, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
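
# Sketch of a direct instantiation with the res2net50_26w_4s layout (depth 50, width 26, scale 4);
# the res2net() helper below builds the same configuration:
#
#     model = Res2Net(Res2Block, [3, 4, 6, 3], num_classes=1000, width_per_group=26, scale=4)
#     model(torch.rand(1, 3, 224, 224)).shape   # torch.Size([1, 1000])
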
def res2net(depth, num_classes, width_per_group=26, scale=4, pretrained=False, progress=True, **kwargs):
"""Instantiate a Res2Net model
    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): number of channels per convolution group
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of Res2Net
    Returns:
        model (torch.nn.Module): loaded PyTorch model
"""
if RESNET_LAYERS.get(depth) is None:
raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")
block = Res2Block if depth >= 50 else BasicBlock
    model = Res2Net(block, RESNET_LAYERS.get(depth), num_classes=num_classes,
                    width_per_group=width_per_group, scale=scale, **kwargs)
if pretrained:
state_dict = None
try:
state_dict = load_state_dict_from_url(URLS.get(f"res2net{depth}_{width_per_group}w_{scale}s"),
map_location=torch.device('cpu'),
progress=progress)
except Exception as e:
warnings.warn(f"While downloading state_dict, received:\n{e}\nSkipping weight loading...")
if isinstance(state_dict, dict):
# Remove FC params from dict
for key in ('fc.weight', 'fc.bias'):
state_dict.pop(key, None)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if any(unexpected) or any(not elt.startswith('fc.') for elt in missing):
raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\n"
f"Unexpected parameters: {unexpected}")
return model
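
# Usage sketch for the helper above, assuming the res2net50_26w_4s URL is reachable;
# the classifier weights are discarded, so any num_classes works with pretrained=True:
#
#     model = res2net(50, num_classes=10, width_per_group=26, scale=4, pretrained=True)
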
def res2next(depth, num_classes, width_per_group=4, scale=4, pretrained=False, progress=True, **kwargs):
"""Instantiate a Res2NeXt model
    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): number of channels per convolution group (only used to select the pretrained checkpoint)
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of Res2Net
    Returns:
        model (torch.nn.Module): loaded PyTorch model
"""
    if RESNET_LAYERS.get(depth) is None or RES2NEXT_PARAMS.get(depth) is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")
    block = Res2Block if depth >= 50 else BasicBlock
    kwargs.update(RES2NEXT_PARAMS.get(depth))
model = Res2Net(block, RESNET_LAYERS.get(depth), num_classes=num_classes, scale=scale, **kwargs)
if pretrained:
state_dict = None
try:
model_name = f"res2next{depth}_{width_per_group}w_{scale}s_{kwargs['groups']}c"
state_dict = load_state_dict_from_url(URLS.get(model_name),
map_location=torch.device('cpu'),
progress=progress)
except Exception as e:
warnings.warn(f"While downloading state_dict, received:\n{e}\nSkipping weight loading...")
if isinstance(state_dict, dict):
# Remove FC params from dict
for key in ('fc.weight', 'fc.bias'):
state_dict.pop(key, None)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if any(unexpected) or any(not elt.startswith('fc.') for elt in missing):
raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\n"
f"Unexpected parameters: {unexpected}")
return model
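

# Smoke-test sketch: build both helpers at depth 50 and run a dummy batch
# (pretrained=False, so no weight download is attempted).
if __name__ == '__main__':
    dummy = torch.rand(2, 3, 224, 224)
    for constructor in (res2net, res2next):
        net = constructor(50, num_classes=10, pretrained=False)
        net.eval()
        with torch.no_grad():
            out = net(dummy)
        print(constructor.__name__, tuple(out.shape))  # expected: (2, 10)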