Implements Average Drop and Increase in Confidence from "Grad-CAM++: Improved Visual Explanations for Deep
Convolutional Networks.".
The raw aggregated metric is computed as follows:
\[
\forall N, H, W \in \mathbb{N}, \forall X \in \mathbb{R}^{N \times 3 \times H \times W},
\forall m \in \mathcal{M}, \forall c \in \mathcal{C}, \\
AvgDrop_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N f_{m, c}(X_i) \\
IncrConf_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N g_{m, c}(X_i)
\]
where \(\mathcal{C}\) is the set of class activation generators,
\(\mathcal{M}\) is the set of classification models,
with the function \(f_{m, c}\) defined as:
\[
\forall x \in \mathbb{R}^{3 \times H \times W},
f_{m, c}(x) = \frac{\max(0, m(x) - m(E_{m, c}(x) \cdot x))}{m(x)}
\]
where \(E_{m, c}(x)\) is the class activation map of \(m\) for input \(x\) with method \(c\),
resized to (H, W),
and with the function \(g_{m, c}\) defined as:
\[
\forall x \in \mathbb{R}^{3 \times H \times W},\quad
g_{m, c}(x) =
\begin{cases}
1 & \text{if } m(x) < m(E_{m, c}(x) \cdot x) \\
0 & \text{otherwise}
\end{cases}
\]
Example
from functools import partial
from torchcam.metrics import ClassificationMetric
metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1))
metric.update(input_tensor)
metric.summary()
Source code in torchcam/metrics.py
def __init__(
    self,
    cam_extractor: _CAM,
    logits_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
) -> None:
    """Keep references to the CAM extractor and the optional logits function, then zero the state.

    Args:
        cam_extractor: CAM extractor, stored as ``self.cam_extractor``
        logits_fn: optional callable stored as ``self.logits_fn``
            (presumably turns raw model outputs into probabilities, e.g. a softmax -- confirm in ``_get_probs``)
    """
    self.cam_extractor, self.logits_fn = cam_extractor, logits_fn
    # Start with empty accumulators
    self.reset()
|
torchcam.metrics.ClassificationMetric.reset
Reset the state of the metric.
Source code in torchcam/metrics.py
def reset(self) -> None:
    """Clear every accumulator so that the metric starts again from scratch."""
    # Running sums of the per-sample drop and of the confidence increases
    self.drop = self.increase = 0.0
    # Number of samples taken into account, and number of samples discarded because of NaNs
    self.total = self.nan_count = 0
|
torchcam.metrics.ClassificationMetric.update
update(input_tensor: Tensor, class_idx: int | None = None) -> None
Update the state of the metric with new predictions.
| PARAMETER |
DESCRIPTION |
input_tensor
|
preprocessed input tensor for the model
TYPE:
Tensor
|
class_idx
|
class index to focus on (default: index of the top predicted class for each sample)
TYPE:
int | None
DEFAULT:
None
|
Source code in torchcam/metrics.py
def update(
    self,
    input_tensor: torch.Tensor,
    class_idx: int | None = None,
) -> None:
    """Update the state of the metric with new predictions.

    The model is switched to eval mode, the class activation maps are computed for the selected class of each
    sample, and the input is multiplied by its (resized) activation map before being fed to the model again.
    Samples whose activation map contains NaNs are left out of the metric and counted in ``self.nan_count``.

    Args:
        input_tensor: preprocessed input tensor for the model
        class_idx: class index to focus on (default: index of the top predicted class for each sample)
    """
    self.cam_extractor.model.eval()
    probs = self._get_probs(input_tensor)

    # Decide once which mode we are in, so that every later branch agrees with this one
    fixed_class = isinstance(class_idx, int)
    preds = None
    if fixed_class:
        cams = self.cam_extractor(class_idx, probs)
        probs = probs[:, class_idx]
    else:
        # Take the top preds for the cam
        preds = probs.argmax(dim=-1)
        cams = self.cam_extractor(preds.tolist(), probs)
        probs = probs.gather(1, preds.unsqueeze(1)).squeeze(1)
    cam = self.cam_extractor.fuse_cams(cams)

    self.cam_extractor.disable_hooks()
    try:
        # Safeguard: skip NaNs
        discard = torch.isnan(cam).reshape(input_tensor.shape[0], -1).any(dim=-1)
        keep = ~discard
        cam = cam[keep]
        probs = probs[keep]
        if not fixed_class:
            preds = preds[keep]
        input_tensor = input_tensor[keep]

        drop_sum, increase_sum = 0.0, 0
        # Nothing to forward when every sample was discarded
        if cam.shape[0] > 0:
            # Resize the CAM
            cam = torch.nn.functional.interpolate(cam.unsqueeze(1), input_tensor.shape[-2:], mode="bilinear")
            # Create the explanation map & get the new probs
            with torch.inference_mode():
                masked_probs = self._get_probs(cam * input_tensor)
            masked_probs = (
                masked_probs[:, class_idx] if fixed_class else masked_probs.gather(1, preds.unsqueeze(1)).squeeze(1)
            )
            # Drop (avoid division by zero)
            drop_sum = torch.relu(probs - masked_probs).div(probs + 1e-7).sum().item()
            # Increase
            increase_sum = (probs < masked_probs).sum().item()
    finally:
        # Always restore the hooks, even if the masked forward pass failed
        self.cam_extractor.enable_hooks()

    self.drop += drop_sum
    self.increase += increase_sum
    self.total += cam.shape[0]
    self.nan_count += discard.sum().item()
|
torchcam.metrics.ClassificationMetric.summary
Computes the aggregated metrics.
| RETURNS |
DESCRIPTION |
dict[str, float]
|
a dictionary with the average drop and the increase in confidence
|
Source code in torchcam/metrics.py
def summary(self) -> dict[str, float]:
    """Computes the aggregated metrics.

    Returns:
        a dictionary with the average drop and the increase in confidence

    Raises:
        AssertionError: if no valid sample has been accumulated, either because the metric has not been
            updated or because every sample seen so far was discarded for having NaNs in its activation map
    """
    if self.total == 0:
        # The metric was updated, but nothing usable came out of it: say so instead of asking for an update
        if self.nan_count > 0:
            raise AssertionError(
                f"all {self.nan_count} samples were discarded because of NaN values in their activation maps"
            )
        raise AssertionError("you need to update the metric before getting the summary")
    return {
        "avg_drop": self.drop / self.total,
        "conf_increase": self.increase / self.total,
    }
|