
TorchCAM: class activation explorer

TorchCAM provides a minimal yet flexible way to explore the spatial importance of features for your PyTorch model's outputs. Check out the live demo on Hugging Face Spaces 🤗

CAM visualization (source: image from woopets; activation maps created with a pretrained ResNet-18)

This project is meant for:

  • exploration: easily assess the influence of spatial features on your model's outputs
  • 👩‍🔬 research: quickly implement your own ideas for new CAM methods
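
Most CAM methods share one recipe: weight each channel of a convolutional layer's activations, sum the weighted maps, and keep the positive part. Methods mostly differ in how the weights are chosen. The framework-free sketch below uses plain Python lists in place of tensors, and all of its function names are hypothetical, not part of TorchCAM's API. It shows the recipe with two weighting schemes: fixed per-channel weights (as CAM reads them from the final fully connected layer) and Grad-CAM's spatially averaged gradients. The final ReLU follows Grad-CAM's convention.

```python
from typing import List

FeatureMap = List[List[float]]  # one 2D map indexed [row][col]


def class_activation_map(activations: List[FeatureMap], weights: List[float]) -> FeatureMap:
    """Hypothetical helper: ReLU(sum_k w_k * A_k) over K feature maps."""
    height, width = len(activations[0]), len(activations[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for w_k, a_k in zip(weights, activations):
        for i in range(height):
            for j in range(width):
                cam[i][j] += w_k * a_k[i][j]
    # Keep only locations whose evidence pushes the class score up
    return [[max(v, 0.0) for v in row] for row in cam]


def gradcam_weights(gradients: List[FeatureMap]) -> List[float]:
    """Hypothetical helper: Grad-CAM weight of channel k = mean of dScore/dA_k over the map."""
    return [sum(sum(row) for row in g) / (len(g) * len(g[0])) for g in gradients]


def minmax_normalize(cam: FeatureMap, eps: float = 1e-7) -> FeatureMap:
    """Rescale a map to [0, 1]; eps guards against a constant map."""
    lo = min(min(row) for row in cam)
    hi = max(max(row) for row in cam)
    return [[(v - lo) / (hi - lo + eps) for v in row] for row in cam]


if __name__ == "__main__":
    acts = [[[1.0, 0.0], [0.0, 2.0]], [[0.0, 3.0], [1.0, 0.0]]]
    # CAM-style: weights taken as given (e.g. one row of a classifier)
    print(class_activation_map(acts, [0.5, -1.0]))  # [[0.5, 0.0], [0.0, 1.0]]
    # Grad-CAM-style: weights derived from gradients of the class score
    grads = [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, -4.0]]]
    print(class_activation_map(acts, gradcam_weights(grads)))  # [[1.0, 0.0], [0.0, 2.0]]
```

In this picture a new method is mostly a new weight function. TorchCAM's actual extension points may be structured differently, so check its documentation before building on them.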

Installation

Create and activate a virtual environment and then install TorchCAM:

pip install torchcam

Check out the installation guide for more options.

Quick start

Get an image and a model:

from torchvision.io import ImageReadMode, decode_image
from torchvision.models import get_model, get_model_weights

weights = get_model_weights("resnet18").DEFAULT
model = get_model("resnet18", weights=weights).eval()
preprocess = weights.transforms()

img_path = "path/to/your/image.jpg"

# Force 3 channels so greyscale or RGBA files match the model's RGB input
img = decode_image(img_path, mode=ImageReadMode.RGB)
input_tensor = preprocess(img)
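
For the ResNet-18 weights used here, `weights.transforms()` resizes the image (shorter side to 256 pixels), center-crops it to 224×224, rescales pixel values to [0, 1] and normalizes each channel with the ImageNet mean and standard deviation. The minimal sketch below shows only that last per-pixel step; `normalize_pixel` is a hypothetical helper, not a torchvision function. The normalized values are not displayable colours, which is why the overlay is drawn on the original `img` rather than on `input_tensor`.

```python
from typing import Tuple

# ImageNet statistics used by torchvision's classification weights (RGB order)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Tuple[int, int, int]) -> Tuple[float, ...]:
    """Hypothetical helper: map one uint8 RGB pixel to normalized floats."""
    # Rescale 0..255 to 0..1, then subtract the channel mean and divide by its std
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


if __name__ == "__main__":
    print(normalize_pixel((0, 0, 0)))        # black maps to negative values
    print(normalize_pixel((255, 255, 255)))  # white maps to positive values
```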

Compute the class activation map:

from torchcam.methods import LayerCAM

with LayerCAM(model) as cam_extractor:
    # LayerCAM needs gradients, so don't run this forward pass under torch.no_grad()
    out = model(input_tensor.unsqueeze(0))
    # Retrieve the CAM by passing the class index and the model output
    activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)
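
LayerCAM differs from Grad-CAM in that it keeps a separate weight for every spatial location: the positive part of the gradient of the class score with respect to each activation. The framework-free sketch below implements that formula, assuming the target layer's activations and gradients have already been captured as nested lists. `layercam` is a hypothetical name; TorchCAM captures these tensors with hooks on the target layer, and its implementation details may differ.

```python
from typing import List

FeatureMap = List[List[float]]  # one 2D map indexed [row][col]


def layercam(activations: List[FeatureMap], gradients: List[FeatureMap]) -> FeatureMap:
    """Hypothetical helper: ReLU(sum_k ReLU(dScore/dA_k) * A_k), location by location."""
    height, width = len(activations[0]), len(activations[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for a_k, g_k in zip(activations, gradients):
        for i in range(height):
            for j in range(width):
                # Each location is weighted by its own positive gradient
                cam[i][j] += max(g_k[i][j], 0.0) * a_k[i][j]
    # Final ReLU: keep only evidence in favour of the class
    return [[max(v, 0.0) for v in row] for row in cam]


if __name__ == "__main__":
    acts = [[[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]]
    grads = [[[0.5, -1.0], [0.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]]]
    print(layercam(acts, grads))  # [[1.5, 1.0], [1.0, 9.0]]
```

Because the weight is not averaged over the whole map as in Grad-CAM, LayerCAM is designed to yield usable maps even from earlier, higher-resolution layers.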

Display it:

import matplotlib.pyplot as plt
from torchvision.transforms.v2.functional import to_pil_image
from torchcam.utils import overlay_mask

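# activation_map holds one CAM per target layer: take the first and drop its batch dimension,
# then let overlay_mask resize it to the image and blend the two
result = overlay_mask(to_pil_image(img), to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=0.5)
plt.imshow(result)
plt.axis('off')
plt.tight_layout()
plt.show()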
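
A CAM is usually much coarser than the input: ResNet-18's last convolutional block outputs 7×7 maps for a 224×224 image, so the map must be upsampled before it can be blended with the picture. The sketch below shows both steps on single-channel values under simplifying assumptions: nearest-neighbour upsampling (TorchCAM's `overlay_mask` may use a smoother resampling filter), no colormap, and hypothetical `upsample_nearest` and `blend` helpers in which `alpha` weights the original image. With `alpha=0.5`, as in the call above, image and heatmap contribute equally either way.

```python
from typing import List

Grid = List[List[float]]


def upsample_nearest(cam: Grid, out_h: int, out_w: int) -> Grid:
    """Hypothetical helper: stretch a small map to out_h x out_w by repeating source cells."""
    in_h, in_w = len(cam), len(cam[0])
    return [
        [cam[i * in_h // out_h][j * in_w // out_w] for j in range(out_w)]
        for i in range(out_h)
    ]


def blend(image_px: float, heat_px: float, alpha: float) -> float:
    """Hypothetical helper: alpha weights the image pixel, (1 - alpha) the heatmap pixel."""
    return alpha * image_px + (1.0 - alpha) * heat_px


if __name__ == "__main__":
    print(upsample_nearest([[0.0, 1.0]], 2, 4))  # [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]
    print(blend(200.0, 100.0, alpha=0.5))  # 150.0
```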


CAM zoo

Activation-based methods

Gradient-based methods