Skip to content

Evaluation metrics

Apart from qualitative visual comparison, it is important to have a refined evaluation metric for class activation maps. This submodule is dedicated to the evaluation of CAM methods.

ClassificationMetric

ClassificationMetric(cam_extractor: _CAM, logits_fn: Callable[[Tensor], Tensor] | None = None)

Implements Average Drop and Increase in Confidence from "Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks".

The raw aggregated metric is computed as follows:

\[ \forall N, H, W \in \mathbb{N}, \forall X \in \mathbb{R}^{N \times 3 \times H \times W}, \forall m \in \mathcal{M}, \forall c \in \mathcal{C}, \\ AvgDrop_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N f_{m, c}(X_i) \\ IncrConf_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N g_{m, c}(X_i) \]

where \(\mathcal{C}\) is the set of class activation generators, \(\mathcal{M}\) is the set of classification models, with the function \(f_{m, c}\) defined as:

\[ \forall x \in \mathbb{R}^{3 \times H \times W}, f_{m, c}(x) = \frac{\max(0, m(x) - m(E_{m, c}(x) \cdot x))}{m(x)} \]

where \(E_{m, c}(x)\) is the class activation map of \(m\) for input \(x\) with method \(c\), resized to (H, W),

and with the function \(g_{m, c}\) defined as:

\[ \forall x \in \mathbb{R}^{3 \times H \times W},\quad g_{m, c}(x) = \begin{cases} 1 & \text{if } m(x) < m(E_{m, c}(x) \cdot x) \\ 0 & \text{otherwise} \end{cases} \]
Example
from functools import partial
from torchcam.metrics import ClassificationMetric
metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1))
metric.update(input_tensor)
metric.summary()
Source code in torchcam/metrics.py
def __init__(
    self,
    cam_extractor: _CAM,
    logits_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
) -> None:
    """Store the metric's collaborators and initialise its accumulators.

    Args:
        cam_extractor: CAM extractor whose activation maps are evaluated
        logits_fn: optional tensor-to-tensor callable (e.g. ``partial(torch.softmax, dim=-1)``),
            kept as-is for later use
    """
    self.logits_fn = logits_fn
    self.cam_extractor = cam_extractor
    # Create the accumulators so they exist before the first update
    self.reset()

torchcam.metrics.ClassificationMetric.reset

reset() -> None

Reset the state of the metric.

Source code in torchcam/metrics.py
def reset(self) -> None:
    """Clear every accumulator so the metric starts again from scratch."""
    # Running sums of the per-sample drop and of the confidence increases
    self.drop = self.increase = 0.0
    # Counters: samples taken into account, and samples skipped because of NaNs
    self.total = self.nan_count = 0

torchcam.metrics.ClassificationMetric.update

update(input_tensor: Tensor, class_idx: int | None = None) -> None

Update the state of the metric with new predictions.

PARAMETER DESCRIPTION
input_tensor

preprocessed input tensor for the model

TYPE: Tensor

class_idx

class index to focus on (default: index of the top predicted class for each sample)

TYPE: int | None DEFAULT: None

Source code in torchcam/metrics.py
def update(
    self,
    input_tensor: torch.Tensor,
    class_idx: int | None = None,
) -> None:
    """Update the state of the metric with new predictions.

    The probabilities of the target class are computed on the raw input and on the
    input multiplied by its (resized) class activation map. The relative drop and the
    number of confidence increases are then added to the running accumulators.
    Samples whose activation map contains NaNs are skipped and counted in ``nan_count``.

    The hooks of the CAM extractor are disabled while the masked input is forwarded and
    are always re-enabled afterwards, even if that forward pass raises.

    Args:
        input_tensor: preprocessed input tensor for the model
        class_idx: class index to focus on (default: index of the top predicted class for each sample)
    """
    self.cam_extractor.model.eval()
    probs = self._get_probs(input_tensor)
    # Take the top preds for the cam
    if isinstance(class_idx, int):
        cams = self.cam_extractor(class_idx, probs)
        cam = self.cam_extractor.fuse_cams(cams)
        probs = probs[:, class_idx]
    else:
        preds = probs.argmax(dim=-1)
        cams = self.cam_extractor(preds.cpu().numpy().tolist(), probs)
        cam = self.cam_extractor.fuse_cams(cams)
        probs = probs.gather(1, preds.unsqueeze(1)).squeeze(1)

    self.cam_extractor.disable_hooks()
    try:
        # Safeguard: skip NaNs
        discard = torch.isnan(cam).reshape(input_tensor.shape[0], -1).any(dim=-1)
        keep = ~discard
        cam = cam[keep, ...]
        probs = probs[keep]
        # Same predicate as the branch above, so that `preds` is filtered whenever it is used below
        if not isinstance(class_idx, int):
            preds = preds[keep]
        input_tensor = input_tensor[keep]
        num_kept = cam.shape[0]

        if num_kept > 0:
            # Resize the CAM
            cam = torch.nn.functional.interpolate(cam.unsqueeze(1), input_tensor.shape[-2:], mode="bilinear")
            # Create the explanation map & get the new probs
            with torch.inference_mode():
                masked_probs = self._get_probs(cam * input_tensor)
            masked_probs = (
                masked_probs[:, class_idx]
                if isinstance(class_idx, int)
                else masked_probs.gather(1, preds.unsqueeze(1)).squeeze(1)
            )
            # Drop (avoid division by zero)
            drop_sum = torch.relu(probs - masked_probs).div(probs + 1e-7).sum().item()
            # Increase
            increase_sum = (probs < masked_probs).sum().item()
        else:
            # Every sample was discarded: nothing to forward, so skip running the model on an empty batch
            drop_sum, increase_sum = 0.0, 0
    finally:
        # Never leave the extractor with its hooks disabled, whatever happened above
        self.cam_extractor.enable_hooks()

    self.drop += drop_sum
    self.increase += increase_sum
    self.total += num_kept
    self.nan_count += discard.sum().item()

torchcam.metrics.ClassificationMetric.summary

summary() -> dict[str, float]

Computes the aggregated metrics.

RETURNS DESCRIPTION
dict[str, float]

a dictionary with the average drop and the increase in confidence

RAISES DESCRIPTION
AssertionError

if the metric has not been updated

Source code in torchcam/metrics.py
def summary(self) -> dict[str, float]:
    """Turn the accumulated sums into per-sample averages.

    Returns:
        a dictionary with the average drop and the increase in confidence

    Raises:
        AssertionError: if the metric has not been updated
    """
    num_samples = self.total
    # Without any sample, the averages below are undefined
    if num_samples == 0:
        raise AssertionError("you need to update the metric before getting the summary")

    avg_drop = self.drop / num_samples
    conf_increase = self.increase / num_samples
    return {"avg_drop": avg_drop, "conf_increase": conf_increase}