# Source code for holocron.models.detection.yolov4

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops.boxes import box_iou, nms
from torchvision.ops.misc import FrozenBatchNorm2d
from typing import Dict, Any, Optional, Tuple, List, Union, Callable

from ..utils import conv_sequence, load_pretrained_params
from ..darknetv4 import DarknetBodyV4, default_cfgs as dark_cfgs
from holocron.ops.boxes import ciou_loss
from holocron.nn import Mish, DropBlock2d, SPP, SAM
from holocron.nn.init import init_module


# Names exported by `from <module> import *`
__all__ = [
    'YOLOv4',
    'yolov4',
    'SPP',
    'PAN',
    'Neck',
]


# Per-architecture configuration: the backbone entry is taken from the darknet configs,
# and no checkpoint URL is set for the full detector ('url' is None).
default_cfgs = {
    'yolov4': {
        'arch': 'YOLOv4',
        'backbone': dark_cfgs['cspdarknet53'],
        'url': None,
    },
}


class PAN(nn.Module):
    """PAN layer from `"Path Aggregation Network for Instance Segmentation" <https://arxiv.org/pdf/1803.01534.pdf>`_.

    Both inputs go through their own 1x1 convolution sequence (``in_channels -> in_channels // 2``).
    The first one is then upsampled by a factor of 2 (nearest neighbour), concatenated channel-wise
    after the second one, and the result is fed to a stack of five alternating 1x1 / 3x3 convolution
    sequences whose last stage targets ``in_channels // 2`` channels.

    Args:
        in_channels (int): input channels
        act_layer (torch.nn.Module, optional): activation layer to be used
        norm_layer (callable, optional): normalization layer
        drop_layer (callable, optional): regularization layer
        conv_layer (callable, optional): convolutional layer
    """
    def __init__(
        self,
        in_channels: int,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        half_channels = in_channels // 2

        def _pointwise(chan_in: int, chan_out: int):
            # 1x1 convolution sequence sharing the layer factories of this module
            return conv_sequence(chan_in, chan_out, act_layer, norm_layer, drop_layer, conv_layer,
                                 kernel_size=1, bias=False)

        def _spatial(chan_in: int, chan_out: int):
            # 3x3 convolution sequence (padding keeps the spatial size)
            return conv_sequence(chan_in, chan_out, act_layer, norm_layer, drop_layer, conv_layer,
                                 kernel_size=3, padding=1, bias=False)

        # Branch applied to the tensor that gets upsampled
        self.conv1 = nn.Sequential(*_pointwise(in_channels, half_channels))
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        # Branch applied to the lateral tensor
        self.conv2 = nn.Sequential(*_pointwise(in_channels, half_channels))

        # Fusion stack applied to the concatenation of both branches
        fusion = []
        fusion.extend(_pointwise(in_channels, half_channels))
        fusion.extend(_spatial(half_channels, in_channels))
        fusion.extend(_pointwise(in_channels, half_channels))
        fusion.extend(_spatial(half_channels, in_channels))
        fusion.extend(_pointwise(in_channels, half_channels))
        self.convs = nn.Sequential(*fusion)

    def forward(self, x: Tensor, up: Tensor) -> Tensor:
        """Fuse `x` (upsampled by 2 after channel reduction) with the lateral tensor `up`."""
        reduced = self.conv1(x)
        lateral = self.conv2(up)
        upsampled = self.up(reduced)

        # Lateral features come first along the channel axis
        merged = torch.cat([lateral, upsampled], dim=1)

        return self.convs(merged)


class Neck(nn.Module):
    """Neck of the detector: a convolutional stack with an SPP layer followed by two PAN layers.

    Args:
        in_planes (list<int>): channel counts; `in_planes[0]` sizes the SPP stack, which is applied
            to `feats[2]`, while `in_planes[1]` and `in_planes[2]` size the two PAN layers
        act_layer (torch.nn.Module, optional): activation layer to be used
        norm_layer (callable, optional): normalization layer
        drop_layer (callable, optional): regularization layer
        conv_layer (callable, optional): convolutional layer
    """
    def __init__(
        self,
        in_planes: List[int],
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()

        wide = in_planes[0]
        narrow = in_planes[0] // 2

        def _pointwise(chan_in: int, chan_out: int):
            # 1x1 convolution sequence
            return conv_sequence(chan_in, chan_out, act_layer, norm_layer, drop_layer, conv_layer,
                                 kernel_size=1, bias=False)

        def _spatial(chan_in: int, chan_out: int):
            # 3x3 convolution sequence (padding keeps the spatial size)
            return conv_sequence(chan_in, chan_out, act_layer, norm_layer, drop_layer, conv_layer,
                                 kernel_size=3, padding=1, bias=False)

        stages = []
        stages.extend(_pointwise(wide, narrow))
        stages.extend(_spatial(narrow, wide))
        stages.extend(_pointwise(wide, narrow))
        stages.append(SPP([5, 9, 13]))
        # The convolution after the SPP layer is sized for 4 * in_planes[0] // 2 input channels
        stages.extend(_pointwise(4 * in_planes[0] // 2, narrow))
        stages.extend(_spatial(narrow, wide))
        stages.extend(_pointwise(wide, narrow))
        self.fpn = nn.Sequential(*stages)

        self.pan1 = PAN(in_planes[1], act_layer, norm_layer, drop_layer, conv_layer)
        self.pan2 = PAN(in_planes[2], act_layer, norm_layer, drop_layer, conv_layer)
        init_module(self, 'leaky_relu')

    def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        """Return the three neck outputs, ordered from the last PAN output back to the SPP stack output."""
        # SPP stack on the last feature map
        top = self.fpn(feats[2])

        # Each PAN layer upsamples its first argument and merges it with a backbone feature map
        mid = self.pan1(top, feats[1])
        bottom = self.pan2(mid, feats[0])

        return bottom, mid, top


class YoloLayer(nn.Module):
    """Scale-specific part of YoloHead

    Decodes the raw output of one prediction branch into boxes, objectness and class scores, then
    either computes the losses (training mode) or filters the predictions into detections (eval mode).

    Args:
        anchors (torch.Tensor[num_anchors, 2]): anchor (width, height) pairs, registered as a buffer
        num_classes (int, optional): number of classes
        scale_xy (float, optional): scaling applied to the sigmoid of the predicted center offsets
        iou_thresh (float, optional): stored on the module; not read by any method of this class
        lambda_noobj (float, optional): weight of the no-object loss
        lambda_coords (float, optional): weight of the box regression loss
        rpn_nms_thresh (float, optional): IoU threshold used by NMS in eval mode
        box_score_thresh (float, optional): minimum score of a detection in eval mode
    """
    def __init__(
        self,
        anchors: Tensor,
        num_classes: int = 80,
        scale_xy: float = 1.,
        iou_thresh: float = 0.213,
        lambda_noobj: float = 0.5,
        lambda_coords: float = 5.,
        rpn_nms_thresh: float = 0.7,
        box_score_thresh: float = 0.05
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.register_buffer('anchors', anchors)

        self.rpn_nms_thresh = rpn_nms_thresh
        self.box_score_thresh = box_score_thresh
        self.lambda_noobj = lambda_noobj
        self.lambda_coords = lambda_coords

        # cf. https://github.com/AlexeyAB/darknet/blob/master/cfg/yolov4.cfg#L1150
        self.scale_xy = scale_xy
        # cf. https://github.com/AlexeyAB/darknet/blob/master/cfg/yolov4.cfg#L1151
        self.iou_thresh = iou_thresh

    def extra_repr(self) -> str:
        return f"num_classes={self.num_classes}, scale_xy={self.scale_xy}"

    def _format_outputs(self, output: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Decode the raw branch output.

        Args:
            output (torch.Tensor[B, num_anchors * (5 + num_classes), H, W]): raw predictions

        Returns:
            tuple: boxes (B x H x W x num_anchors x 4, as left, top, right, bottom relative to the grid
            size), objectness (B x H x W x num_anchors) and class scores
            (B x H x W x num_anchors x num_classes)
        """
        b, _, h, w = output.shape

        self.anchors: Tensor
        # B x (num_anchors * (5 + num_classes)) x H x W --> B x H x W x num_anchors x (5 + num_classes)
        output = output.reshape(b, len(self.anchors), 5 + self.num_classes, h, w).permute(0, 3, 4, 1, 2)

        # Box center: cell offsets broadcast over the batch and anchor dimensions
        c_x = torch.arange(w, dtype=torch.float32, device=output.device).view(1, 1, -1, 1)
        c_y = torch.arange(h, dtype=torch.float32, device=output.device).view(1, -1, 1, 1)

        # With scale_xy > 1, the offset range is widened symmetrically around the cell center
        b_xy = self.scale_xy * torch.sigmoid(output[..., :2]) - 0.5 * (self.scale_xy - 1)
        b_xy[..., 0].add_(c_x)
        b_xy[..., 1].add_(c_y)
        # Normalize by the grid size
        b_xy[..., 0].div_(w)
        b_xy[..., 1].div_(h)

        # Box dimension
        b_wh = torch.exp(output[..., 2:4]) * self.anchors.view(1, 1, 1, -1, 2)

        top_left = b_xy - 0.5 * b_wh
        bot_right = top_left + b_wh
        boxes = torch.cat((top_left, bot_right), dim=-1)

        # Objectness
        b_o = torch.sigmoid(output[..., 4])
        # Classification scores
        b_scores = torch.sigmoid(output[..., 5:])

        return boxes, b_o, b_scores

    @staticmethod
    def post_process(
        boxes: Tensor,
        b_o: Tensor,
        b_scores: Tensor,
        rpn_nms_thresh: float = 0.7,
        box_score_thresh: float = 0.05
    ) -> List[Dict[str, Tensor]]:
        """Turn decoded predictions into per-image detections.

        Predictions are kept when their objectness is at least 0.5 and their score (best class score
        multiplied by the objectness) is at least `box_score_thresh`; the survivors go through NMS.
        The input tensors are not modified.

        Args:
            boxes (torch.Tensor[B, H, W, num_anchors, 4]): predicted boxes
            b_o (torch.Tensor[B, H, W, num_anchors]): objectness scores
            b_scores (torch.Tensor[B, H, W, num_anchors, num_classes]): classification scores
            rpn_nms_thresh (float, optional): IoU threshold for NMS
            box_score_thresh (float, optional): minimum score of a detection

        Returns:
            list<dict>: one dict per image with keys `boxes`, `scores` and `labels`
        """

        # Out-of-place clamp so that the caller's tensor is left untouched
        boxes = boxes.clamp(0, 1)
        detections = []
        for idx in range(b_o.shape[0]):

            # Defaults returned when nothing passes the objectness filter
            coords = torch.zeros((0, 4), dtype=torch.float32, device=b_o.device)
            scores = torch.zeros(0, dtype=torch.float32, device=b_o.device)
            labels = torch.zeros(0, dtype=torch.long, device=b_o.device)

            # Objectness filter
            is_obj = b_o[idx] >= 0.5
            if torch.any(is_obj):
                coords = boxes[idx, is_obj]
                scores, labels = b_scores[idx, is_obj].max(dim=-1)
                # Multiply by the objectness
                scores = scores * b_o[idx, is_obj]

                # Confidence threshold
                is_confident = scores >= box_score_thresh
                coords = coords[is_confident]
                labels = labels[is_confident]
                scores = scores[is_confident]
                # NMS
                kept_idxs = nms(coords, scores, iou_threshold=rpn_nms_thresh)
                coords = coords[kept_idxs]
                scores = scores[kept_idxs]
                labels = labels[kept_idxs]

            detections.append(dict(boxes=coords, scores=scores, labels=labels))

        return detections

    def _build_targets(
        self,
        pred_boxes: Tensor,
        b_o: Tensor,
        b_scores: Tensor,
        target: List[Dict[str, Tensor]]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Build the training targets and the object / no-object masks.

        Each ground truth box is assigned to the grid cell containing its center and to the anchor
        whose shape has the highest IoU with the box shape.

        Args:
            pred_boxes (torch.Tensor[B, H, W, num_anchors, 4]): predicted boxes
            b_o (torch.Tensor[B, H, W, num_anchors]): objectness scores
            b_scores (torch.Tensor[B, H, W, num_anchors, num_classes]): classification scores
                (not read here)
            target (list<dict>): one dict per image with keys `boxes` and `labels`

        Returns:
            tuple: objectness target, classification target, object mask and no-object mask
        """

        b, h, w, num_anchors = b_o.shape

        # Target formatting
        target_o = torch.zeros((b, h, w, num_anchors), device=b_o.device)
        target_scores = torch.zeros((b, h, w, num_anchors, self.num_classes), device=b_o.device)
        obj_mask = torch.zeros((b, h, w, num_anchors), dtype=torch.bool, device=b_o.device)
        noobj_mask = torch.ones((b, h, w, num_anchors), dtype=torch.bool, device=b_o.device)

        gt_boxes = [t['boxes'] for t in target]
        gt_labels = [t['labels'] for t in target]

        # GT coords (left, top, right, bottom) --> index of the grid cell holding the box center
        _boxes = torch.cat(gt_boxes, dim=0)
        gt_centers = 0.5 * (_boxes[:, :2] + _boxes[:, 2:])
        gt_centers[:, 0] *= w
        gt_centers[:, 1] *= h
        gt_centers = gt_centers.to(dtype=torch.long)
        # A center on or past the border would otherwise index outside of the grid
        # (or silently wrap around for negative values)
        gt_centers[:, 0].clamp_(0, w - 1)
        gt_centers[:, 1].clamp_(0, h - 1)

        # Batch index of each GT box
        target_selection = torch.tensor(
            [img_idx for img_idx, img_boxes in enumerate(gt_boxes) for _ in range(img_boxes.shape[0])],
            dtype=torch.long, device=b_o.device
        )
        if target_selection.shape[0] > 0:

            # Anchors IoU: shapes are compared as boxes centered on the origin
            gt_wh = _boxes[:, 2:] - _boxes[:, :2]
            anchor_idxs = box_iou(torch.cat((-gt_wh, gt_wh), dim=-1),
                                  torch.cat((-self.anchors, self.anchors), dim=-1)).argmax(dim=1)

            # Assign boxes
            obj_mask[target_selection, gt_centers[:, 1], gt_centers[:, 0], anchor_idxs] = True
            noobj_mask[target_selection, gt_centers[:, 1], gt_centers[:, 0], anchor_idxs] = False
            # B * cells * predictors * info
            for idx in range(b):
                if gt_boxes[idx].shape[0] > 0:
                    # IoU with cells that enclose the GT centers
                    gt_ious, gt_idxs = box_iou(pred_boxes[idx, obj_mask[idx]], gt_boxes[idx]).max(dim=1)
                    # Objectness target
                    # NOTE(review): gt_ious is derived from pred_boxes without detaching, so this target
                    # may carry gradient back to the box predictions - confirm this is intended
                    target_o[idx, obj_mask[idx]] = gt_ious
                    # Classification target
                    target_scores[idx, obj_mask[idx], gt_labels[idx][gt_idxs]] = 1.

        return target_o, target_scores, obj_mask, noobj_mask

    def _compute_losses(
        self,
        pred_boxes: Tensor,
        b_o: Tensor,
        b_scores: Tensor,
        target: List[Dict[str, Tensor]],
        ignore_high_iou: bool = False
    ) -> Dict[str, Tensor]:
        """Compute the loss terms of this scale.

        Args:
            pred_boxes (torch.Tensor[B, H, W, num_anchors, 4]): predicted boxes
            b_o (torch.Tensor[B, H, W, num_anchors]): objectness scores
            b_scores (torch.Tensor[B, H, W, num_anchors, num_classes]): classification scores
            target (list<dict>): one dict per image with keys `boxes` and `labels`
            ignore_high_iou (bool, optional): currently unused

        Returns:
            dict: loss terms under the keys `obj_loss`, `noobj_loss`, `bbox_loss` and `clf_loss`
        """

        target_o, target_scores, obj_mask, noobj_mask = self._build_targets(pred_boxes, b_o, b_scores, target)

        # Bbox regression
        bbox_loss = torch.zeros(1, device=b_o.device)
        for idx, _target in enumerate(target):
            if _target['boxes'].shape[0] > 0 and torch.any(obj_mask[idx]):
                # presumably ciou_loss returns a (predictions x GT boxes) matrix: each selected
                # prediction contributes its lowest loss over the GT boxes - verify against holocron.ops
                bbox_loss += ciou_loss(pred_boxes[idx, obj_mask[idx]], _target['boxes']).min(dim=1).values.sum()

        return dict(obj_loss=F.mse_loss(b_o[obj_mask], target_o[obj_mask], reduction='sum'),
                    noobj_loss=self.lambda_noobj * b_o[noobj_mask].pow(2).sum(),
                    bbox_loss=self.lambda_coords * bbox_loss,
                    clf_loss=F.binary_cross_entropy(b_scores[obj_mask], target_scores[obj_mask], reduction='sum'))

    def forward(
        self,
        output: Tensor,
        target: Optional[List[Dict[str, Tensor]]] = None
    ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """Decode the raw predictions of one scale and return either the loss dictionary in training mode
        or the list of detections in eval mode.

        Args:
            output (torch.Tensor[N, num_anchors * (5 + num_classes), H, W]): raw output of the
                prediction branch
            target (list<dict>, optional): each dict must have two keys `boxes` of type torch.Tensor[-1, 4]
                and `labels` of type torch.Tensor[-1]

        Returns:
            dict or list<dict>: loss dictionary in training mode, list of detections in eval mode

        Raises:
            ValueError: if `target` is missing in training mode
        """

        if self.training and target is None:
            raise ValueError("`target` needs to be specified in training mode")

        pred_boxes, b_o, b_scores = self._format_outputs(output)

        if self.training:
            return self._compute_losses(pred_boxes, b_o, b_scores, target)  # type: ignore[arg-type]
        else:
            # cf. https://github.com/Tianxiaomo/pytorch-YOLOv4/blob/master/tool/yolo_layer.py#L117
            return self.post_process(pred_boxes, b_o, b_scores, self.rpn_nms_thresh, self.box_score_thresh)


class Yolov4Head(nn.Module):
    """Detection head made of three prediction branches, each followed by its own `YoloLayer`.

    The channel counts are fixed: the three input feature maps are expected to provide 128, 256 and
    512 channels respectively (after concatenation with the downsampled previous branch, the second
    and third stacks receive 512 and 1024 channels).

    Args:
        num_classes (int, optional): number of classes
        anchors (torch.Tensor[3, num_anchors, 2], optional): anchors of each of the three scales;
            the defaults are pixel sizes divided by 608
        act_layer (torch.nn.Module, optional): activation layer to be used
        norm_layer (callable, optional): normalization layer
        drop_layer (callable, optional): regularization layer
        conv_layer (callable, optional): convolutional layer
    """
    def __init__(
        self,
        num_classes: int = 80,
        anchors: Optional[Tensor] = None,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:

        # cf. https://github.com/AlexeyAB/darknet/blob/master/cfg/yolov4.cfg#L1143
        if anchors is None:
            anchors = torch.tensor([[[12, 16], [19, 36], [40, 28]],
                                    [[36, 75], [76, 55], [72, 146]],
                                    [[142, 110], [192, 243], [459, 401]]], dtype=torch.float32) / 608
        elif not isinstance(anchors, torch.Tensor):
            anchors = torch.tensor(anchors, dtype=torch.float32)

        num_scales = anchors.shape[0]
        if num_scales != 3:
            raise AssertionError(f"The number of anchors is expected to be 3. received: {anchors.shape[0]}")

        super().__init__()

        # Each scale predicts 3 anchors with (4 box coords + objectness + class scores) each
        pred_channels = (5 + num_classes) * 3

        def _pointwise(chan_in: int, chan_out: int):
            # 1x1 convolution sequence with regularization
            return conv_sequence(chan_in, chan_out, act_layer, norm_layer, drop_layer, conv_layer,
                                 kernel_size=1, bias=False)

        def _spatial(chan_in: int, chan_out: int, drop: Optional[Callable[..., nn.Module]] = drop_layer):
            # 3x3 convolution sequence (padding keeps the spatial size)
            return conv_sequence(chan_in, chan_out, act_layer, norm_layer, drop, conv_layer,
                                 kernel_size=3, padding=1, bias=False)

        def _downsample(chan_in: int, chan_out: int):
            # 3x3 convolution sequence with stride 2
            return conv_sequence(chan_in, chan_out, act_layer, norm_layer, drop_layer, conv_layer,
                                 kernel_size=3, padding=1, stride=2, bias=False)

        def _predictor(chan_in: int):
            # Final 1x1 convolution: no activation, no normalization, no regularization, with bias
            return conv_sequence(chan_in, pred_channels, None, None, None, conv_layer,
                                 kernel_size=1, bias=True)

        # First scale
        self.head1 = nn.Sequential(*_spatial(128, 256, None), *_predictor(256))

        self.yolo1 = YoloLayer(anchors[0], num_classes=num_classes, scale_xy=1.2)

        # Second scale
        self.pre_head2 = nn.Sequential(*_downsample(128, 256))
        self.head2_1 = nn.Sequential(
            *_pointwise(512, 256),
            *_spatial(256, 512),
            *_pointwise(512, 256),
            *_spatial(256, 512),
            *_pointwise(512, 256))
        self.head2_2 = nn.Sequential(*_spatial(256, 512, None), *_predictor(512))

        self.yolo2 = YoloLayer(anchors[1], num_classes=num_classes, scale_xy=1.1)

        # Third scale
        self.pre_head3 = nn.Sequential(*_downsample(256, 512))
        self.head3 = nn.Sequential(
            *_pointwise(1024, 512),
            *_spatial(512, 1024),
            *_pointwise(1024, 512),
            *_spatial(512, 1024),
            *_pointwise(1024, 512),
            *_spatial(512, 1024),
            *_predictor(1024))

        self.yolo3 = YoloLayer(anchors[2], num_classes=num_classes, scale_xy=1.05)
        init_module(self, 'leaky_relu')
        # Zero init of the last layer of each prediction branch
        for branch in (self.head1, self.head2_2, self.head3):
            branch[-1].weight.data.zero_()
            branch[-1].bias.data.zero_()

    def forward(
        self,
        feats: List[Tensor],
        target: Optional[List[Dict[str, Tensor]]] = None
    ) -> Union[List[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Run the three prediction branches and merge their results.

        Args:
            feats (list<torch.Tensor>): the three feature maps fed to the branches
            target (list<dict>, optional): targets forwarded to each `YoloLayer`

        Returns:
            dict or list<dict>: in training mode, the loss dictionary summed over the three scales;
            otherwise, one dict per image with the detections of the three scales concatenated
        """
        # First scale
        raw1 = self.head1(feats[0])

        # Second scale: downsampled first feature map merged with the second one
        mid = torch.cat([self.pre_head2(feats[0]), feats[1]], dim=1)
        mid = self.head2_1(mid)
        raw2 = self.head2_2(mid)

        # Third scale: downsampled second-scale features merged with the third feature map
        deep = torch.cat([self.pre_head3(mid), feats[2]], dim=1)
        raw3 = self.head3(deep)

        # YOLO output
        res1 = self.yolo1(raw1, target)
        res2 = self.yolo2(raw2, target)
        res3 = self.yolo3(raw3, target)

        if self.training:
            # Sum each loss term over the scales
            return {name: res1[name] + res2[name] + res3[name] for name in res1}

        # Concatenate the detections of all scales for each image
        merged = []
        for det1, det2, det3 in zip(res1, res2, res3):
            merged.append({field: torch.cat((det1[field], det2[field], det3[field]), dim=0)
                           for field in ('boxes', 'scores', 'labels')})
        return merged


class YOLOv4(nn.Module):
    """YOLOv4 detector: a `DarknetBodyV4` backbone, a `Neck` and a `Yolov4Head`.

    Args:
        layout (list<tuple<int, int>>): backbone layout, forwarded to `DarknetBodyV4`
        num_classes (int, optional): number of classes
        in_channels (int, optional): number of channels of the input tensor
        stem_channels (int, optional): forwarded to `DarknetBodyV4`
        anchors (torch.Tensor, optional): anchors forwarded to the head
        act_layer (torch.nn.Module, optional): activation layer of the neck and the head
            (defaults to `LeakyReLU(0.1)`); the backbone always uses `Mish`
        norm_layer (callable, optional): normalization layer (defaults to `BatchNorm2d`)
        drop_layer (callable, optional): regularization layer
        conv_layer (callable, optional): convolutional layer
        backbone_norm_layer (callable, optional): normalization layer of the backbone
            (defaults to `norm_layer`)
    """
    def __init__(
        self,
        layout: List[Tuple[int, int]],
        num_classes: int = 80,
        in_channels: int = 3,
        stem_channels: int = 32,
        anchors: Optional[Tensor] = None,
        act_layer: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        drop_layer: Optional[Callable[..., nn.Module]] = None,
        conv_layer: Optional[Callable[..., nn.Module]] = None,
        backbone_norm_layer: Optional[Callable[[int], nn.Module]] = None
    ) -> None:
        super().__init__()

        # Resolve the default layers
        act_layer = nn.LeakyReLU(0.1, inplace=True) if act_layer is None else act_layer
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        backbone_norm_layer = norm_layer if backbone_norm_layer is None else backbone_norm_layer

        # backbone
        # NOTE(review): the meaning of the positional `3` is defined by DarknetBodyV4 - confirm there
        self.backbone = DarknetBodyV4(layout, in_channels, stem_channels, 3, Mish(),
                                      backbone_norm_layer, drop_layer, conv_layer)
        # neck
        self.neck = Neck([1024, 512, 256], act_layer, norm_layer, drop_layer, conv_layer)
        # head
        self.head = Yolov4Head(num_classes, anchors, act_layer, norm_layer, drop_layer, conv_layer)

        # NOTE(review): Neck and Yolov4Head already call init_module in their constructors; if
        # init_module re-draws convolution weights, the call on the head below overrides the zero
        # initialization of its last layers - confirm this is intended
        for submodule in (self.neck, self.head):
            init_module(submodule, 'leaky_relu')

    def forward(
        self,
        x: Tensor,
        target: Optional[List[Dict[str, Tensor]]] = None
    ) -> Union[List[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Run the backbone, the neck and the head.

        Args:
            x (torch.Tensor or sequence of torch.Tensor): input images; a non-tensor input is stacked
                along a new first dimension
            target (list<dict>, optional): targets forwarded to the head

        Returns:
            dict or list<dict>: the output of the head
        """

        # Stack a sequence of image tensors into a single batch
        batch = x if isinstance(x, torch.Tensor) else torch.stack(x, dim=0)

        backbone_feats = self.backbone(batch)

        feat_a, feat_b, feat_c = self.neck(backbone_feats)

        return self.head((feat_a, feat_b, feat_c), target)


def _yolo(arch: str, pretrained: bool, progress: bool, pretrained_backbone: bool, **kwargs: Any) -> YOLOv4:
    """Build a YOLOv4 model and optionally load pretrained parameters.

    Args:
        arch (str): key of the architecture in `default_cfgs`
        pretrained (bool): whether to load the parameters of the full model; when True, the backbone
            parameters are not loaded separately
        progress (bool): forwarded to `load_pretrained_params`
        pretrained_backbone (bool): whether to load the backbone parameters (ignored if `pretrained`)
        **kwargs: keyword arguments of the `YOLOv4` constructor

    Returns:
        YOLOv4: the model
    """

    arch_cfg = default_cfgs[arch]

    # Build the model
    model = YOLOv4(arch_cfg['backbone']['layout'], **kwargs)  # type: ignore[index]

    if pretrained:
        # Load pretrained parameters of the whole model
        load_pretrained_params(model, arch_cfg['url'], progress)  # type: ignore[arg-type]
    elif pretrained_backbone:
        # Load backbone pretrained parameters only
        load_pretrained_params(model.backbone, arch_cfg['backbone']['url'], progress,  # type: ignore[index]
                               key_replacement=('features.', ''), key_filter='features.')

    return model


def yolov4(pretrained: bool = False, progress: bool = True, pretrained_backbone: bool = True, **kwargs: Any) -> YOLOv4:
    """YOLOv4 model from
    `"YOLOv4: Optimal Speed and Accuracy of Object Detection" <https://arxiv.org/pdf/2004.10934.pdf>`_.

    YOLOv4 is an improvement on YOLOv3 that includes many changes including: the usage of
    `DropBlock <https://arxiv.org/pdf/1810.12890.pdf>`_ regularization,
    `Mish <https://arxiv.org/pdf/1908.08681.pdf>`_ activation,
    `CSP <https://arxiv.org/pdf/2004.10934.pdf>`_ and `SAM <https://arxiv.org/pdf/1807.06521.pdf>`_ in the
    backbone, `SPP <https://arxiv.org/pdf/1406.4729.pdf>`_ and
    `PAN <https://arxiv.org/pdf/1803.01534.pdf>`_ in the neck.

    Args:
        pretrained (bool, optional): If True, returns a model pre-trained on ImageNet
        progress (bool, optional): If True, displays a progress bar of the download to stderr
        pretrained_backbone (bool, optional): If True, backbone parameters will have been pretrained
            on Imagenette
        **kwargs: keyword arguments of the `YOLOv4` constructor

    Returns:
        torch.nn.Module: detection module
    """

    # When a pretrained backbone is requested, its normalization layers are replaced by
    # FrozenBatchNorm2d (this overrides any `backbone_norm_layer` passed by the caller)
    if pretrained_backbone:
        kwargs['backbone_norm_layer'] = FrozenBatchNorm2d

    return _yolo('yolov4', pretrained, progress, pretrained_backbone, **kwargs)  # type: ignore[return-value]