# -*- coding: utf-8 -*-
'''
Functional interface
'''
import torch
import torch.nn.functional as F
__all__ = ['mish', 'nl_relu', 'focal_loss', 'concat_downsample2d', 'multilabel_cross_entropy', 'ls_cross_entropy']
def mish(x):
    """Implements the Mish activation function

    Args:
        x (torch.Tensor): input tensor

    Returns:
        torch.Tensor[x.size()]: output tensor
    """
    return x * torch.tanh(F.softplus(x))
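# Illustrative usage sketch (shapes picked arbitrarily, not part of the module API):
# >>> x = torch.randn(2, 8, 32, 32)
# >>> out = mish(x)  # element-wise x * tanh(softplus(x)), same shape as x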
def nl_relu(x, beta=1., inplace=False):
    """Implements the natural logarithm ReLU activation function

    Args:
        x (torch.Tensor): input tensor
        beta (float): beta used for NReLU
        inplace (bool): whether the operation should be performed inplace

    Returns:
        torch.Tensor[x.size()]: output tensor
    """
    if inplace:
        return torch.log(F.relu_(x).mul_(beta).add_(1), out=x)
    else:
        return torch.log(1 + beta * F.relu(x))
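# Illustrative usage sketch (values are arbitrary): out-of-place by default,
# while inplace=True writes log(1 + beta * relu(x)) back into `x`.
# >>> x = torch.randn(2, 8)
# >>> nl_relu(x)                          # new tensor, x untouched
# >>> nl_relu(x, beta=2., inplace=True)   # result stored in x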
def focal_loss(x, target, weight=None, ignore_index=-100, reduction='mean', gamma=2):
    """Implements the focal loss from
    `"Focal Loss for Dense Object Detection" <https://arxiv.org/pdf/1708.02002.pdf>`_

    Args:
        x (torch.Tensor[N, K, ...]): input tensor
        target (torch.Tensor[N, ...]): hard target tensor
        weight (torch.Tensor[K], optional): manual rescaling of each class
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): reduction method
        gamma (float, optional): gamma parameter of focal loss

    Returns:
        torch.Tensor: loss reduced with `reduction` method
    """
    # log(P[class]) = log_softmax(score)[class]
    logpt = F.log_softmax(x, dim=1)
    # Compute pt and logpt only for target classes (the remaining will have a 0 coefficient)
    logpt = logpt.transpose(1, 0).flatten(1).gather(0, target.view(1, -1)).squeeze()
    # Ignore index (set loss contribution to 0)
    if ignore_index >= 0:
        logpt[target.view(-1) == ignore_index] = 0
    # Get P(class)
    pt = logpt.exp()
    # Weight
    if weight is not None:
        # Cast the rescaling weights to the input type
        if weight.type() != x.data.type():
            weight = weight.type_as(x.data)
        at = weight.gather(0, target.data.view(-1))
        logpt *= at
    # Loss
    loss = -1 * (1 - pt) ** gamma * logpt
    # Loss reduction
    if reduction == 'sum':
        loss = loss.sum()
    elif reduction == 'mean':
        # Ignore contribution to the loss if target is `ignore_index`
        if ignore_index >= 0:
            loss = loss[target.view(-1) != ignore_index]
        loss = loss.mean()
    else:
        # If no reduction, reshape tensor like target
        loss = loss.view(*target.shape)
    return loss
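# Illustrative usage sketch (assumed 5-class, segmentation-style shapes):
# >>> logits = torch.randn(4, 5, 16, 16, requires_grad=True)   # N, K, H, W
# >>> target = torch.randint(0, 5, (4, 16, 16))                # N, H, W
# >>> loss = focal_loss(logits, target, gamma=2)               # scalar with reduction='mean'
# >>> loss.backward()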
def concat_downsample2d(x, scale_factor):
    """Implements a loss-less downsampling operation described in
    `"YOLO9000: Better, Faster, Stronger" <https://pjreddie.com/media/files/papers/YOLO9000.pdf>`_
    by stacking adjacent information on the channel dimension.

    Args:
        x (torch.Tensor[N, C, H, W]): input tensor
        scale_factor (int): spatial scaling factor

    Returns:
        torch.Tensor[N, C * scale_factor ** 2, H / scale_factor, W / scale_factor]: downsampled tensor
    """
    b, c, h, w = x.shape
    if (h % scale_factor != 0) or (w % scale_factor != 0):
        raise AssertionError("Spatial size of input tensor must be a multiple of `scale_factor`")
    new_h, new_w = h // scale_factor, w // scale_factor
    # N * C * H * W --> N * C * (H/scale_factor) * scale_factor * (W/scale_factor) * scale_factor
    out = x.view(b, c, new_h, scale_factor, new_w, scale_factor)
    # Move extra axes to last position to flatten them with channel dimension
    out = out.permute(0, 2, 4, 1, 3, 5).flatten(3)
    # Reorder all axes
    out = out.permute(0, 3, 1, 2)
    return out
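# Quick sanity check (illustrative): a 2x downsample stacks each 2x2 neighbourhood
# on the channel axis, so channels grow by scale_factor ** 2.
# >>> x = torch.arange(32.).view(1, 2, 4, 4)
# >>> concat_downsample2d(x, 2).shape
# torch.Size([1, 8, 2, 2])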
def multilabel_cross_entropy(x, target, weight=None, ignore_index=-100, reduction='mean'):
    """Implements the cross entropy loss for multi-label targets

    Args:
        x (torch.Tensor[N, K, ...]): input tensor
        target (torch.Tensor[N, K, ...]): target tensor
        weight (torch.Tensor[K], optional): manual rescaling of each class
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): reduction method

    Returns:
        torch.Tensor: loss reduced with `reduction` method
    """
    # log(P[class]) = log_softmax(score)[class]
    logpt = F.log_softmax(x, dim=1)
    # Ignore index (set loss contribution to 0)
    if ignore_index >= 0 and ignore_index < x.shape[1]:
        logpt[:, ignore_index] = 0
    # Weight
    if weight is not None:
        # Cast the rescaling weights to the input type
        if weight.type() != x.data.type():
            weight = weight.type_as(x.data)
        logpt *= weight.view(1, -1)
    # CE Loss
    loss = - target * logpt
    # Loss reduction
    if reduction == 'sum':
        loss = loss.sum()
    else:
        loss = loss.sum(dim=1)
        if reduction == 'mean':
            loss = loss.mean()
    return loss
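# Illustrative usage sketch with soft targets (assumed 4-class problem, rows sum to 1):
# >>> logits = torch.randn(8, 4)
# >>> soft_target = torch.softmax(torch.randn(8, 4), dim=1)
# >>> multilabel_cross_entropy(logits, soft_target)  # scalar with default 'mean' reduction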
def ls_cross_entropy(x, target, weight=None, ignore_index=-100, reduction='mean', eps=0.1):
    """Implements the label smoothing cross entropy loss from
    `"Attention Is All You Need" <https://arxiv.org/pdf/1706.03762.pdf>`_

    Args:
        x (torch.Tensor[N, K, ...]): input tensor
        target (torch.Tensor[N, ...]): target tensor
        weight (torch.Tensor[K], optional): manual rescaling of each class
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): reduction method
        eps (float, optional): smoothing factor

    Returns:
        torch.Tensor: loss reduced with `reduction` method
    """
    if eps == 0:
        return F.cross_entropy(x, target, weight, ignore_index=ignore_index, reduction=reduction)
    # log(P[class]) = log_softmax(score)[class]
    logpt = F.log_softmax(x, dim=1)
    # Ignore index (set loss contribution to 0)
    if ignore_index >= 0 and ignore_index < x.shape[1]:
        logpt[:, ignore_index] = 0
    # Weight
    if weight is not None:
        # Cast the rescaling weights to the input type
        if weight.type() != x.data.type():
            weight = weight.type_as(x.data)
        logpt *= weight.view(1, -1)
    # Loss reduction
    if reduction == 'sum':
        loss = -logpt.sum()
    else:
        loss = -logpt.sum(dim=1)
        if reduction == 'mean':
            loss = loss.mean()
    # Smooth the labels: blend a uniform term with the standard NLL term
    return eps / x.shape[1] * loss + (1 - eps) * F.nll_loss(logpt, target, weight,
                                                            ignore_index=ignore_index, reduction=reduction)
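# Illustrative usage sketch (hard integer targets, assumed 4-class problem):
# >>> logits = torch.randn(8, 4)
# >>> target = torch.randint(0, 4, (8,))
# >>> ls_cross_entropy(logits, target, eps=0.1)  # scalar blending the uniform and NLL terms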