# Source code for torchcam.utils
# Copyright (C) 2020-2023, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
import numpy as np
from matplotlib import colormaps as cm
from PIL import Image
def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = "jet", alpha: float = 0.7) -> Image.Image:
    """Overlay a colormapped mask on a background image

    >>> from PIL import Image
    >>> import matplotlib.pyplot as plt
    >>> from torchcam.utils import overlay_mask
    >>> img = ...
    >>> cam = ...
    >>> overlay = overlay_mask(img, cam)

    Args:
        img: background image. Modes other than "RGB" (e.g. "L", "RGBA", "P") are converted to RGB
            before blending.
        mask: grayscale mask to be overlaid. Float ("F") masks are expected to hold values in [0, 1];
            8-bit ("L") masks are rescaled from [0, 255] to [0, 1]; multi-band masks are converted
            to "L" first.
        colormap: name of the matplotlib colormap applied to the mask
        alpha: weight of the background image in the blend, in [0, 1); the colormapped mask
            receives weight ``1 - alpha``

    Returns:
        RGB image blending the background image with the colormapped mask

    Raises:
        TypeError: when the arguments have invalid types
        ValueError: when the alpha argument has an incorrect value
    """
    if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
        raise TypeError("img and mask arguments need to be PIL.Image")
    # The chained comparison is also False for NaN, so NaN is rejected instead of corrupting the blend
    if not isinstance(alpha, float) or not (0 <= alpha < 1):
        raise ValueError("alpha argument is expected to be of type float between 0 and 1")

    cmap = cm.get_cmap(colormap)

    # Blend in RGB space: other modes produce arrays whose channel count does not match the
    # (H, W, 3) colormapped overlay below
    if img.mode != "RGB":
        img = img.convert("RGB")
    # The colormap must receive a single value per pixel
    if len(mask.getbands()) > 1:
        mask = mask.convert("L")

    # Resize mask to the background size
    overlay = mask.resize(img.size, resample=Image.BICUBIC)
    mask_arr = np.asarray(overlay)
    if mask_arr.dtype == np.uint8:
        # Squaring a uint8 array would overflow, and matplotlib treats integer inputs as LUT
        # indices rather than normalized intensities: rescale 8-bit values to [0, 1] first
        mask_arr = mask_arr.astype(np.float32) / 255
    # For values in [0, 1], squaring suppresses low values relative to high ones; the colormap
    # returns RGBA, so only the RGB channels are kept
    overlay = (255 * cmap(mask_arr**2)[:, :, :3]).astype(np.uint8)

    # Overlay the image with the mask (weighted sum; result stays within [0, 255])
    overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))

    return overlayed_img