Skip to content

Utilities

overlay_mask

overlay_mask(img: Image, mask: Image, colormap: Colormap | str = 'jet', alpha: float = 0.7) -> Image

Overlay a colormapped mask on a background image.

Example
from PIL import Image
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask
img = ...
cam = ...
overlay = overlay_mask(img, cam)
PARAMETER DESCRIPTION
img

background image

TYPE: Image

mask

single-channel (grayscale) mask to be overlaid

TYPE: Image

colormap

colormap to be applied on the mask

TYPE: Colormap | str DEFAULT: 'jet'

alpha

opacity of the background image in the blend: the output is alpha * image + (1 - alpha) * colormapped mask

TYPE: float DEFAULT: 0.7

RETURNS DESCRIPTION
Image

overlaid image

RAISES DESCRIPTION
TypeError

when the arguments have invalid types

ValueError

when alpha is not a float in [0, 1), or when the image is neither grayscale nor RGB

Source code in torchcam/utils.py
def overlay_mask(img: Image, mask: Image, colormap: Colormap | str = "jet", alpha: float = 0.7) -> Image:
    """Overlay a colormapped mask on a background image.

    Example:
        ```python
        from PIL import Image
        import matplotlib.pyplot as plt
        from torchcam.utils import overlay_mask
        img = ...
        cam = ...
        overlay = overlay_mask(img, cam)
        ```

    Args:
        img: background image, either single-band (grayscale) or 3-band (RGB)
        mask: single-band mask to be overlaid. Float masks (e.g. mode "F") are
            expected to hold values in [0, 1]. Boolean and integer masks (e.g.
            mode "1" or "L") are rescaled to [0, 1] by the maximum value of
            their dtype before the colormap is applied.
        colormap: colormap (or its registered name) to be applied on the mask
        alpha: opacity of the background image in the blend; the output is
            ``alpha * img + (1 - alpha) * colormapped_mask``

    Returns:
        overlaid image

    Raises:
        TypeError: when img or mask is not a PIL image
        ValueError: when alpha is not a float in [0, 1), when img is neither
            grayscale nor RGB, or when mask has more than one band
    """
    if not isinstance(img, Image) or not isinstance(mask, Image):
        raise TypeError("img and mask arguments need to be PIL.Image")

    if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
        raise ValueError("alpha argument is expected to be of type float between 0 and 1")

    num_bands = len(img.getbands())
    if num_bands not in {1, 3}:
        raise ValueError("img argument needs to be a grayscale or RGB image")

    # A multi-band mask would yield a 4-D colormap output and fail later with an
    # opaque broadcasting error, so reject it explicitly here.
    if len(mask.getbands()) != 1:
        raise ValueError("mask argument needs to be a single-channel image")

    cmap = get_cmap(colormap)
    # Resize mask to the background resolution
    heatmap = np.asarray(mask.resize(img.size, resample=Resampling.BICUBIC))
    # Colormaps treat integer inputs as LUT indices rather than normalized
    # intensities, and squaring a uint8 array wraps around modulo 256. Map
    # bool/integer masks to [0, 1] floats first; float masks are left untouched.
    if heatmap.dtype == np.bool_:
        heatmap = heatmap.astype(np.float32)
    elif np.issubdtype(heatmap.dtype, np.integer):
        dtype_max = np.iinfo(heatmap.dtype).max
        heatmap = heatmap.astype(np.float32) / dtype_max
    # Squaring emphasises the high-activation regions before colormapping
    overlay = (255 * cmap(heatmap**2)[:, :, :3]).astype(np.uint8)
    # Broadcast grayscale backgrounds to 3 channels so they blend with the RGB overlay
    bg_img = np.asarray(img) if num_bands == 3 else np.asarray(img)[..., np.newaxis].repeat(3, axis=-1)
    return fromarray((alpha * bg_img + (1 - alpha) * overlay).astype(np.uint8))