# Copyright (C) 2022-2023, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Callable, Dict, Union, cast

import torch

from .methods.core import _CAM
class ClassificationMetric:
    r"""Implements Average Drop and Increase in Confidence from
    `"Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks."
    <https://arxiv.org/pdf/1710.11063.pdf>`_.

    The raw aggregated metric is computed as follows:

    .. math::
        \forall N, H, W \in \mathbb{N}, \forall X \in \mathbb{R}^{N*3*H*W},
        \forall m \in \mathcal{M}, \forall c \in \mathcal{C}, \\
        AvgDrop_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N f_{m, c}(X_i) \\
        IncrConf_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N g_{m, c}(X_i)

    where :math:`\mathcal{C}` is the set of class activation generators,
    :math:`\mathcal{M}` is the set of classification models,
    with the function :math:`f_{m, c}` defined as:

    .. math::
        \forall x \in \mathbb{R}^{3*H*W},
        f_{m, c}(x) = \frac{\max(0, m(x) - m(E_{m, c}(x) * x))}{m(x)}

    where :math:`E_{m, c}(x)` is the class activation map of :math:`m` for input :math:`x`
    with method :math:`c`, resized to (H, W),

    and with the function :math:`g_{m, c}` defined as:

    .. math::
        \forall x \in \mathbb{R}^{3*H*W},
        g_{m, c}(x) = \left\{
            \begin{array}{ll}
                1 & \mbox{if } m(x) < m(E_{m, c}(x) * x) \\
                0 & \mbox{otherwise.}
            \end{array}
        \right.

    >>> from functools import partial
    >>> from torchcam.metrics import ClassificationMetric
    >>> metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1))
    >>> metric.update(input_tensor)
    >>> metric.summary()

    Args:
        cam_extractor: CAM extractor; its ``model`` attribute is used for the forward passes,
            it is called with ``(class_idx, scores)`` to produce the CAMs, and its
            ``fuse_cams`` method merges them into a single map
        logits_fn: optional function applied to the raw model output (e.g. a softmax);
            when ``None`` the raw output is used as the score
    """

    def __init__(
        self,
        cam_extractor: _CAM,
        logits_fn: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
    ) -> None:
        self.cam_extractor = cam_extractor
        self.logits_fn = logits_fn
        # NOTE(review): `reset` is not visible in this section of the file; presumably it
        # initialises the `drop`, `increase` and `total` accumulators used below -- confirm.
        self.reset()

    def _get_probs(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the extractor's model on ``input_tensor`` and post-process the output.

        Args:
            input_tensor: tensor fed to ``cam_extractor.model``

        Returns:
            the raw model output, or ``logits_fn`` applied to it when ``logits_fn`` is set
        """
        logits = self.cam_extractor.model(input_tensor)
        return cast(torch.Tensor, logits if self.logits_fn is None else self.logits_fn(logits))

    def update(
        self,
        input_tensor: torch.Tensor,
        class_idx: Union[int, None] = None,
    ) -> None:
        """Update the state of the metric with new predictions

        The accumulators are only modified once the whole computation has succeeded, and the
        extractor's ``_hooks_enabled`` flag is set back to ``True`` even if the masked forward
        pass raises.

        Args:
            input_tensor: preprocessed input tensor for the model
            class_idx: class index to focus on (default: index of the top predicted class for each sample)
        """
        self.cam_extractor.model.eval()
        probs = self._get_probs(input_tensor)

        # Select the score to explain: a fixed class, or each sample's top prediction
        if isinstance(class_idx, int):
            cams = self.cam_extractor(class_idx, probs)
            cam = self.cam_extractor.fuse_cams(cams)
            probs = probs[:, class_idx]
        else:
            preds = probs.argmax(dim=-1)
            # Tensor.tolist() already yields plain Python ints; no NumPy round-trip needed
            cams = self.cam_extractor(preds.tolist(), probs)
            cam = self.cam_extractor.fuse_cams(cams)
            probs = probs.gather(1, preds.unsqueeze(1)).squeeze(1)

        # The flag is switched off around the masked forward pass (presumably so that pass is
        # not recorded by the extractor's hooks) and is always switched back on afterwards.
        self.cam_extractor._hooks_enabled = False
        try:
            # Safeguard: replace NaNs
            cam[torch.isnan(cam)] = 0
            # Resize the CAM to the spatial size of the input
            cam = torch.nn.functional.interpolate(cam.unsqueeze(1), input_tensor.shape[-2:], mode="bilinear")
            # Create the explanation map & get the new probs
            with torch.inference_mode():
                masked_probs = self._get_probs(cam * input_tensor)
            masked_probs = (
                masked_probs[:, class_idx]
                if isinstance(class_idx, int)
                else masked_probs.gather(1, preds.unsqueeze(1)).squeeze(1)
            )
            # Drop (avoid division by zero)
            drop = torch.relu(probs - masked_probs).div(probs + 1e-7)
            # Increase
            increase = probs < masked_probs
        finally:
            self.cam_extractor._hooks_enabled = True

        self.drop += drop.sum().item()
        self.increase += increase.sum().item()
        self.total += input_tensor.shape[0]

    def summary(self) -> Dict[str, float]:
        """Computes the aggregated metrics

        Returns:
            a dictionary with the average drop and the increase in confidence

        Raises:
            AssertionError: if no sample has been accumulated yet
        """
        if self.total == 0:
            raise AssertionError("you need to update the metric before getting the summary")

        return {
            "avg_drop": self.drop / self.total,
            "conf_increase": self.increase / self.total,
        }