Source code for torchcam.utils
# Copyright (C) 2020-2024, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
from typing import cast
import numpy as np
from matplotlib import colormaps as cm
from PIL.Image import Image, Resampling, fromarray
[docs]
def overlay_mask(img: Image, mask: Image, colormap: str = "jet", alpha: float = 0.7) -> Image:
    """Overlay a colormapped mask on a background image.

    >>> from PIL import Image
    >>> import matplotlib.pyplot as plt
    >>> from torchcam.utils import overlay_mask
    >>> img = ...
    >>> cam = ...
    >>> overlay = overlay_mask(img, cam)

    Args:
        img: background image; it is converted to RGB before blending
        mask: single-channel mask to be overlaid. It is resized to the size of ``img``.
            Float masks are passed to the colormap as they are (colormaps expect values in [0, 1]);
            8-bit and binary masks are rescaled to [0, 1] first.
        colormap: name of the matplotlib colormap to be applied on the mask
        alpha: weight of the background image in the blend; the colormapped mask is weighted
            by ``1 - alpha``

    Returns:
        overlaid image, in RGB, with the same size as ``img``

    Raises:
        TypeError: when img or mask is not a PIL image
        ValueError: when alpha is not a float in the range [0, 1)
    """
    if not isinstance(img, Image) or not isinstance(mask, Image):
        raise TypeError("img and mask arguments need to be PIL.Image")

    if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
        raise ValueError("alpha argument is expected to be of type float between 0 and 1")

    cmap = cm.get_cmap(colormap)

    # Resize the mask to the background size
    values = np.asarray(mask.resize(img.size, resample=Resampling.BICUBIC))
    # Colormaps interpret integer input as raw lookup-table indices rather than intensities,
    # and squaring a uint8 array wraps around modulo 256. Move integer/binary masks to
    # floats in [0, 1] so that they go through the same path as float masks.
    if values.dtype == np.uint8:
        values = values.astype(np.float32) / 255
    elif values.dtype == np.bool_:
        values = values.astype(np.float32)

    # Apply the colormap on the squared mask and drop the alpha channel of the RGBA output
    heatmap = (255 * cmap(values**2)[:, :, :3]).astype(np.uint8)

    # The heatmap is (H, W, 3): the background needs three channels as well to be blended with it
    background = np.asarray(img if img.mode == "RGB" else img.convert("RGB"))

    # Overlay the image with the mask
    return fromarray((alpha * background + (1 - alpha) * heatmap).astype(np.uint8))