torchcam.methods#

Class activation map#

The class activation map gives you the importance of each region of a feature map on a model’s output. More specifically, a class activation map is relative to:

  • the layer at which it is computed (e.g. the N-th layer of your model)

  • the model’s classification output (e.g. the raw logits of the model)

  • the class index to focus on

With TorchCAM, the target layer is selected when you create your CAM extractor. To compute a map, you then pass the model logits and a class index to the extractor.

Activation-based methods#

Methods related to activation-based class activation maps.

class torchcam.methods.CAM(model: Module, target_layer: Module | str | List[Module | str] | None = None, fc_layer: Module | str | None = None, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “Learning Deep Features for Discriminative Localization”.

The Class Activation Map (CAM) is defined for image classification models that have global pooling at the end of the visual feature extraction block. The localization map is computed as follows:

\[L^{(c)}_{CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)\]

where \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), and \(w_k^{(c)}\) is the weight corresponding to class \(c\) for unit \(k\) in the fully connected layer.
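As a minimal NumPy sketch of the formula above (not TorchCAM's implementation), the map for one class is the ReLU of the weighted sum of the target-layer feature maps. The function name `cam_map` and the toy arrays are hypothetical and only serve the illustration.

```python
import numpy as np


def cam_map(activations: np.ndarray, fc_weights: np.ndarray, class_idx: int) -> np.ndarray:
    """ReLU of the class-weighted sum of feature maps.

    activations: (K, H, W) array holding A_k(x, y)
    fc_weights:  (num_classes, K) array holding w_k^{(c)}
    """
    weights = fc_weights[class_idx]  # (K,) weights for the chosen class
    # Contract the channel axis: sum_k w_k * A_k(x, y) -> (H, W)
    weighted_sum = np.tensordot(weights, activations, axes=1)
    return np.maximum(weighted_sum, 0.0)  # ReLU


# Toy values (hypothetical): 2 channels, 2x2 feature maps, 2 classes
activations = np.array([[[1.0, 2.0], [3.0, 4.0]],
                        [[1.0, 1.0], [1.0, 1.0]]])
fc_weights = np.array([[1.0, -2.0],
                       [0.5, 0.5]])
cam_class0 = cam_map(activations, fc_weights, class_idx=0)
```

With these toy values, class 0 yields `[[0, 0], [1, 2]]`: the negative weight on the second channel pushes the top row to zero or below, and the ReLU clips it at zero.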

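>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import CAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = CAM(model, 'layer4', 'fc')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> with torch.no_grad(): out = model(input_tensor)
>>> cam(class_idx=100)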
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • fc_layer – either the fully connected layer itself or its name

  • input_shape – shape of the expected input tensor excluding the batch dimension

class torchcam.methods.ScoreCAM(model: Module, target_layer: Module | str | List[Module | str] | None = None, batch_size: int = 32, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks”.

The localization map is computed as follows:

\[L^{(c)}_{Score-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)\]

with the coefficient \(w_k^{(c)}\) being defined as:

\[w_k^{(c)} = softmax\Big(Y^{(c)}(M_k) - Y^{(c)}(X_b)\Big)_k\]

where \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), \(Y^{(c)}(X)\) is the model output score for class \(c\) before softmax for input \(X\), \(X_b\) is a baseline image, and \(M_k\) is defined as follows:

\[M_k = \Bigg(\frac{U(A_k) - \min\limits_m U(A_m)}{\max\limits_m U(A_m) - \min\limits_m U(A_m)}\Bigg) \odot X_b\]

where \(\odot\) refers to the element-wise multiplication and \(U\) is the upsampling operation.
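The weight computation above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not TorchCAM's code: the maps are assumed to be already upsampled to the input resolution, each map is normalized over its own spatial positions (one reading of the min/max in the formula), and `score_fn` is a hypothetical stand-in for the class score \(Y^{(c)}\). The small `eps` is an addition to avoid dividing by zero on constant maps.

```python
from typing import Callable

import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()


def score_cam_weights(
    upsampled: np.ndarray,          # (K, H, W): U(A_k), assumed already upsampled
    baseline: np.ndarray,           # (H, W): baseline image X_b
    score_fn: Callable[[np.ndarray], float],  # hypothetical stand-in for Y^{(c)}
    eps: float = 1e-8,
) -> np.ndarray:
    # Normalize each map to [0, 1] over its spatial positions (assumption)
    lo = upsampled.min(axis=(1, 2), keepdims=True)
    hi = upsampled.max(axis=(1, 2), keepdims=True)
    masks = (upsampled - lo) / (hi - lo + eps)
    masked_inputs = masks * baseline  # M_k, broadcast over the K maps
    base_score = score_fn(baseline)
    diffs = np.array([score_fn(m) - base_score for m in masked_inputs])
    return softmax(diffs)  # one weight per channel k


def top_row_score(x: np.ndarray) -> float:
    """Toy class score (hypothetical): sum of the top row of the input."""
    return float(x[0, :].sum())


upsampled = np.array([[[2.0, 2.0], [0.0, 0.0]],   # map 0 highlights the top row
                      [[0.0, 0.0], [2.0, 2.0]]])  # map 1 highlights the bottom row
baseline = np.ones((2, 2))
weights = score_cam_weights(upsampled, baseline, top_row_score)
```

In this toy setup the map that keeps the top row visible receives the larger weight, since masking with it preserves the class score.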

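>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import ScoreCAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = ScoreCAM(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> with torch.no_grad(): out = model(input_tensor)
>>> cam(class_idx=100)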
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • batch_size – batch size used to forward masked inputs

  • input_shape – shape of the expected input tensor excluding the batch dimension

class torchcam.methods.SSCAM(model: Module, target_layer: Module | str | List[Module | str] | None = None, batch_size: int = 32, num_samples: int = 35, std: float = 2.0, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “SS-CAM: Smoothed Score-CAM for Sharper Visual Feature Localization”.

The localization map is computed as follows:

\[L^{(c)}_{SS-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)\]

with the coefficient \(w_k^{(c)}\) being defined as:

\[w_k^{(c)} = softmax\Big(\frac{1}{N} \sum\limits_{i=1}^N (Y^{(c)}(\hat{M_k}) - Y^{(c)}(X_b))\Big)_k\]

where \(N\) is the number of samples used to smooth the weights, \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), \(Y^{(c)}(X)\) is the model output score for class \(c\) before softmax for input \(X\), \(X_b\) is a baseline image, and \(\hat{M_k}\) is defined as follows:

\[\hat{M_k} = \Bigg(\frac{U(A_k) - \min\limits_m U(A_m)}{\max\limits_m U(A_m) - \min\limits_m U(A_m)} + \delta\Bigg) \odot X_b\]

where \(\odot\) refers to the element-wise multiplication, \(U\) is the upsampling operation, and \(\delta \sim \mathcal{N}(0, \sigma^2)\) is random noise that follows a zero-mean Gaussian distribution with a standard deviation of \(\sigma\).
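A rough NumPy sketch of the smoothing step, under stated assumptions: `masks` holds the already normalized and upsampled activations, the noise \(\delta\) is drawn independently per position (the formula does not say whether it is per position or per map), and `score_fn` is a hypothetical stand-in for \(Y^{(c)}\). It is not TorchCAM's implementation.

```python
from typing import Callable

import numpy as np


def ss_cam_weights(
    masks: np.ndarray,              # (K, H, W): normalized, upsampled activations
    baseline: np.ndarray,           # (H, W): baseline image X_b
    score_fn: Callable[[np.ndarray], float],  # hypothetical stand-in for Y^{(c)}
    num_samples: int = 35,
    std: float = 2.0,
    seed: int = 0,
) -> np.ndarray:
    rng = np.random.default_rng(seed)
    base_score = score_fn(baseline)
    mean_diffs = np.zeros(len(masks))
    for _ in range(num_samples):
        # delta ~ N(0, std^2), drawn per position (assumption)
        noise = rng.normal(0.0, std, size=masks.shape)
        noisy_inputs = (masks + noise) * baseline  # \hat{M_k}
        mean_diffs += np.array([score_fn(m) - base_score for m in noisy_inputs])
    mean_diffs /= num_samples  # average over the N noisy samples
    e = np.exp(mean_diffs - mean_diffs.max())  # stable softmax over channels
    return e / e.sum()


def top_row_score(x: np.ndarray) -> float:
    """Toy class score (hypothetical): sum of the top row of the input."""
    return float(x[0, :].sum())


masks = np.array([[[1.0, 1.0], [0.0, 0.0]],
                  [[0.0, 0.0], [1.0, 1.0]]])
baseline = np.ones((2, 2))
weights = ss_cam_weights(masks, baseline, top_row_score)
```

Setting `std=0.0` removes the noise, in which case the averaged score differences reduce to the Score-CAM ones.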

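>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import SSCAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = SSCAM(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> with torch.no_grad(): out = model(input_tensor)
>>> cam(class_idx=100)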
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • batch_size – batch size used to forward masked inputs

  • num_samples – number of noisy samples used for weight computation

  • std – standard deviation of the noise added to the normalized activation

  • input_shape – shape of the expected input tensor excluding the batch dimension

class torchcam.methods.ISCAM(model: Module, target_layer: Module | str | List[Module | str] | None = None, batch_size: int = 32, num_samples: int = 10, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “IS-CAM: Integrated Score-CAM for axiomatic-based explanations”.

The localization map is computed as follows:

\[L^{(c)}_{IS-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)\]

with the coefficient \(w_k^{(c)}\) being defined as:

\[w_k^{(c)} = softmax\Bigg(\frac{1}{N} \sum\limits_{i=1}^N \Big(Y^{(c)}(M_i) - Y^{(c)}(X_b)\Big)\Bigg)_k\]

where \(N\) is the number of samples used to smooth the weights, \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), \(Y^{(c)}(X)\) is the model output score for class \(c\) before softmax for input \(X\), \(X_b\) is a baseline image, and \(M_i\) is defined as follows:

\[M_i = \sum\limits_{j=0}^{i-1} \frac{j}{N} \frac{U(A_k) - \min\limits_m U(A_m)}{\max\limits_m U(A_m) - \min\limits_m U(A_m)} \odot X_b\]

where \(\odot\) refers to the element-wise multiplication and \(U\) is the upsampling operation.

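>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import ISCAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = ISCAM(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> with torch.no_grad(): out = model(input_tensor)
>>> cam(class_idx=100)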
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • batch_size – batch size used to forward masked inputs

  • num_samples – number of samples (\(N\) in the formula above) used for weight computation

  • input_shape – shape of the expected input tensor excluding the batch dimension

Gradient-based methods#

Methods related to gradient-based class activation maps.

class torchcam.methods.GradCAM(model: Module, target_layer: Module | str | List[Module | str] | None = None, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization”.

The localization map is computed as follows:

\[L^{(c)}_{Grad-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)\]

with the coefficient \(w_k^{(c)}\) being defined as:

\[w_k^{(c)} = \frac{1}{H \cdot W} \sum\limits_{i=1}^H \sum\limits_{j=1}^W \frac{\partial Y^{(c)}}{\partial A_k(i, j)}\]

where \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), and \(Y^{(c)}\) is the model output score for class \(c\) before softmax.
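To make the two formulas concrete, here is a small NumPy sketch. It assumes the gradients \(\partial Y^{(c)} / \partial A_k(i, j)\) are already available as an array (in practice they would come from backpropagation); the function name `grad_cam` and the toy values are hypothetical, and this is not TorchCAM's implementation.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """activations, gradients: (K, H, W) arrays for A_k and dY^{(c)}/dA_k."""
    # w_k = spatial average of the gradient of the class score w.r.t. A_k
    weights = gradients.mean(axis=(1, 2))  # (K,)
    weighted_sum = np.tensordot(weights, activations, axes=1)  # (H, W)
    return np.maximum(weighted_sum, 0.0)  # ReLU


# Toy values (hypothetical)
activations = np.array([[[1.0, 2.0], [3.0, 4.0]],
                        [[1.0, 1.0], [1.0, 1.0]]])
gradients = np.array([[[0.0, 2.0], [2.0, 0.0]],      # spatial mean = 1
                      [[-1.0, -1.0], [-3.0, -3.0]]])  # spatial mean = -2
cam = grad_cam(activations, gradients)
```

The only difference from the CAM formula is where the channel weights come from: averaged gradients instead of fully connected layer weights.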

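>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import GradCAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = GradCAM(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> scores = model(input_tensor)
>>> cam(class_idx=100, scores=scores)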
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • input_shape – shape of the expected input tensor excluding the batch dimension

class torchcam.methods.GradCAMpp(model: Module, target_layer: Module | str | List[Module | str] | None = None, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks”.

The localization map is computed as follows:

\[L^{(c)}_{Grad-CAM++}(x, y) = \sum\limits_k w_k^{(c)} A_k(x, y)\]

with the coefficient \(w_k^{(c)}\) being defined as:

\[w_k^{(c)} = \sum\limits_{i=1}^H \sum\limits_{j=1}^W \alpha_k^{(c)}(i, j) \cdot ReLU\Big(\frac{\partial Y^{(c)}}{\partial A_k(i, j)}\Big)\]

where \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), \(Y^{(c)}\) is the model output score for class \(c\) before softmax, and \(\alpha_k^{(c)}(i, j)\) is defined as:

\[\alpha_k^{(c)}(i, j) = \frac{1}{\sum\limits_{i, j} \frac{\partial Y^{(c)}}{\partial A_k(i, j)}} = \frac{\frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2}}{2 \cdot \frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2} + \sum\limits_{a,b} A_k (a,b) \cdot \frac{\partial^3 Y^{(c)}}{(\partial A_k(i,j))^3}}\]

if \(\frac{\partial Y^{(c)}}{\partial A_k(i, j)} = 1\) else \(0\).

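>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import GradCAMpp
>>> model = resnet18(pretrained=True).eval()
>>> cam = GradCAMpp(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> scores = model(input_tensor)
>>> cam(class_idx=100, scores=scores)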
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • input_shape – shape of the expected input tensor excluding the batch dimension

class torchcam.methods.SmoothGradCAMpp(model: Module, target_layer: Module | str | List[Module | str] | None = None, num_samples: int = 4, std: float = 0.3, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “Smooth Grad-CAM++: An Enhanced Inference Level Visualization Technique for Deep Convolutional Neural Network Models” with a personal correction to the paper (alpha coefficient numerator).

The localization map is computed as follows:

\[L^{(c)}_{Smooth Grad-CAM++}(x, y) = \sum\limits_k w_k^{(c)} A_k(x, y)\]

with the coefficient \(w_k^{(c)}\) being defined as:

\[w_k^{(c)} = \sum\limits_{i=1}^H \sum\limits_{j=1}^W \alpha_k^{(c)}(i, j) \cdot ReLU\Big(\frac{\partial Y^{(c)}}{\partial A_k(i, j)}\Big)\]

where \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), \(Y^{(c)}\) is the model output score for class \(c\) before softmax, and \(\alpha_k^{(c)}(i, j)\) is defined as:

\[\alpha_k^{(c)}(i, j) = \frac{\frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2}}{2 \cdot \frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2} + \sum\limits_{a,b} A_k (a,b) \cdot \frac{\partial^3 Y^{(c)}}{(\partial A_k(i,j))^3}} = \frac{\frac{1}{n} \sum\limits_{m=1}^n D^{(c, 2)}_k(i, j)}{ \frac{2}{n} \sum\limits_{m=1}^n D^{(c, 2)}_k(i, j) + \sum\limits_{a,b} A_k (a,b) \cdot \frac{1}{n} \sum\limits_{m=1}^n D^{(c, 3)}_k(i, j)}\]

if \(\frac{\partial Y^{(c)}}{\partial A_k(i, j)} = 1\) else \(0\). Here \(D^{(c, p)}_k(i, j)\) refers to the p-th order partial derivative of the score of class \(c\) with respect to the activation of node \(k\) at position \((i, j)\), and \(n\) is the number of samples used to get the gradient estimate.

Please note the difference in the numerator of \(\alpha_k^{(c)}(i, j)\), which is actually \(\frac{1}{n} \sum\limits_{m=1}^n D^{(c, 1)}_k(i,j)\) in the paper.

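>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import SmoothGradCAMpp
>>> model = resnet18(pretrained=True).eval()
>>> cam = SmoothGradCAMpp(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> scores = model(input_tensor)
>>> cam(class_idx=100)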
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • num_samples – number of samples to use for smoothing

  • std – standard deviation of the noise

  • input_shape – shape of the expected input tensor excluding the batch dimension

class torchcam.methods.XGradCAM(model: Module, target_layer: Module | str | List[Module | str] | None = None, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “Axiom-based Grad-CAM: Towards Accurate Visualization and Explanation of CNNs”.

The localization map is computed as follows:

\[L^{(c)}_{XGrad-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)\]

with the coefficient \(w_k^{(c)}\) being defined as:

\[w_k^{(c)} = \sum\limits_{i=1}^H \sum\limits_{j=1}^W \Big( \frac{\partial Y^{(c)}}{\partial A_k(i, j)} \cdot \frac{A_k(i, j)}{\sum\limits_{m=1}^H \sum\limits_{n=1}^W A_k(m, n)} \Big)\]

where \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), and \(Y^{(c)}\) is the model output score for class \(c\) before softmax.
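The weighting above can be sketched in NumPy as follows, assuming the gradients are already available as an array. The names `xgrad_cam_weights` and `xgrad_cam` are hypothetical, the `eps` term is an addition to avoid dividing by zero when a map sums to zero, and this is not TorchCAM's implementation.

```python
import numpy as np


def xgrad_cam_weights(activations: np.ndarray, gradients: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """activations, gradients: (K, H, W) arrays for A_k and dY^{(c)}/dA_k."""
    # Normalize each activation map by its spatial sum
    norm = activations.sum(axis=(1, 2), keepdims=True)
    # w_k = sum_{i,j} grad_k(i, j) * A_k(i, j) / sum_{m,n} A_k(m, n)
    return (gradients * activations / (norm + eps)).sum(axis=(1, 2))


def xgrad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    weights = xgrad_cam_weights(activations, gradients)  # (K,)
    return np.maximum(np.tensordot(weights, activations, axes=1), 0.0)


# Toy values (hypothetical)
activations = np.array([[[1.0, 1.0], [1.0, 1.0]],
                        [[1.0, 3.0], [0.0, 0.0]]])
gradients = np.array([[[4.0, 0.0], [0.0, 0.0]],
                      [[2.0, 2.0], [5.0, 5.0]]])
weights = xgrad_cam_weights(activations, gradients)
cam = xgrad_cam(activations, gradients)
```

Compared with the plain spatial average used by Grad-CAM, gradients at positions with zero activation (the bottom row of the second map here) do not contribute to the weight.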

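>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import XGradCAM
>>> model = resnet18(pretrained=True).eval()
>>> cam = XGradCAM(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> scores = model(input_tensor)
>>> cam(class_idx=100, scores=scores)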
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • input_shape – shape of the expected input tensor excluding the batch dimension

class torchcam.methods.LayerCAM(model: Module, target_layer: Module | str | List[Module | str] | None = None, input_shape: Tuple[int, ...] = (3, 224, 224), **kwargs: Any)[source]#

Implements a class activation map extractor as described in “LayerCAM: Exploring Hierarchical Class Activation Maps for Localization”.

The localization map is computed as follows:

\[L^{(c)}_{Layer-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)}(x, y) \cdot A_k(x, y)\Big)\]

with the coefficient \(w_k^{(c)}(x, y)\) being defined as:

\[w_k^{(c)}(x, y) = ReLU\Big(\frac{\partial Y^{(c)}}{\partial A_k(x, y)}\Big)\]

where \(A_k(x, y)\) is the activation of node \(k\) in the target layer of the model at position \((x, y)\), and \(Y^{(c)}\) is the model output score for class \(c\) before softmax.
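A minimal NumPy sketch of these two formulas, assuming the gradients are already available as an array. Unlike Grad-CAM, the weight is kept per spatial position rather than averaged per channel. The function name `layer_cam` and the toy values are hypothetical, and this is not TorchCAM's implementation.

```python
import numpy as np


def layer_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """activations, gradients: (K, H, W) arrays for A_k and dY^{(c)}/dA_k."""
    # Position-wise weights: keep only positive gradients
    weights = np.maximum(gradients, 0.0)  # (K, H, W)
    # Weight each activation element-wise, sum over channels, then ReLU
    return np.maximum((weights * activations).sum(axis=0), 0.0)


# Toy values (hypothetical)
activations = np.array([[[1.0, 2.0], [3.0, 4.0]],
                        [[1.0, 1.0], [1.0, 1.0]]])
gradients = np.array([[[1.0, -1.0], [0.0, 2.0]],
                      [[-1.0, 3.0], [1.0, -2.0]]])
cam = layer_cam(activations, gradients)
```

With these values the result is `[[1, 3], [1, 8]]`: each position only accumulates the channels whose gradient is positive there.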

>>> import torch
>>> from torchvision.models import resnet18
>>> from torchcam.methods import LayerCAM
>>> model = resnet18(pretrained=True).eval()
>>> extractor = LayerCAM(model, 'layer4')
>>> input_tensor = torch.rand(1, 3, 224, 224)  # placeholder input; use a preprocessed image batch in practice
>>> scores = model(input_tensor)
>>> cams = extractor(class_idx=100, scores=scores)
>>> fused_cam = extractor.fuse_cams(cams)
Parameters:
  • model – input model

  • target_layer – either the target layer itself or its name, or a list of those

  • input_shape – shape of the expected input tensor excluding the batch dimension

classmethod fuse_cams(cams: List[Tensor], target_shape: Tuple[int, int] | None = None) → Tensor#

Fuse class activation maps from different layers.

Parameters:
  • cams – the list of activation maps (for the same input)

  • target_shape – expected spatial shape of the fused activation map (defaults to the largest spatial shape among the input maps)

Returns:

fused class activation map

Return type:

torch.Tensor