Model Interpretation for Pretrained Deep Learning Models

This notebook demonstrates how to apply model interpretability algorithms to pretrained deep learning models (ResNet, VGG) using a handpicked image, and visualizes the attributions for each pixel by overlaying them on the image.

The interpretation algorithms that we use in this notebook are Integrated Gradients (with and without a noise tunnel), GradientShap, Occlusion, and LRP. A noise tunnel smooths the attributions by averaging them over several copies of the input perturbed with Gaussian noise.

Note: Before running this tutorial, please install the torchvision, Pillow (PIL), and matplotlib packages.

In [2]:
import torch
import torch.nn.functional as F

from PIL import Image

import os
import json
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

import torchvision
from torchvision import models
from torchvision import transforms

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import LRP
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz
from captum.attr._utils.lrp_rules import EpsilonRule, GammaRule, Alpha1_Beta0_Rule

1- Loading the model and the dataset

Loads a pretrained ResNet-18 model and sets it to eval mode

In [4]:
model = models.resnet18(pretrained=True)
model = model.eval()

Downloads the list of classes/labels for the ImageNet dataset and reads them into memory

In [8]:
!wget -P $HOME/.torch/models https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
In [9]:
labels_path = os.getenv("HOME") + '/.torch/models/imagenet_class_index.json'
with open(labels_path) as json_data:
    idx_to_labels = json.load(json_data)

Defines transforms and a normalization function for the image. It also loads an image from the img/resnet/ folder that will be used for interpretation purposes.

In [12]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

transform_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

img = Image.open('img/resnet/swan-3299528_1280.jpg')

transformed_img = transform(img)

input = transform_normalize(transformed_img)
input = input.unsqueeze(0)

Predicts the class of the input image

In [13]:
output = model(input)
output = F.softmax(output, dim=1)
prediction_score, pred_label_idx = torch.topk(output, 1)

pred_label_idx.squeeze_()
predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')
Predicted: goose ( 0.4569333493709564 )
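
As an optional sanity check, we can also inspect the model's top 5 predictions. This is an illustrative snippet, not part of the original pipeline; it reuses the softmax output and label mapping from the cells above.

# Illustrative: list the five highest-scoring ImageNet classes.
top5_scores, top5_idx = torch.topk(output, 5)
for score, idx in zip(top5_scores[0], top5_idx[0]):
    print(idx_to_labels[str(idx.item())][1], score.item())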

2- Gradient-based attribution

Let's compute attributions using Integrated Gradients and visualize them on the image. Integrated Gradients computes the integral of the gradients of the model's output for the predicted class pred_label_idx with respect to the input image pixels, along a straight-line path from a black baseline image to our input image.
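
For reference, the Integrated Gradients attribution for input feature i, with baseline x' (here, the black image), is defined as

$$\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\big(x' + \alpha (x - x')\big)}{\partial x_i}\, d\alpha$$

where F is the model's output for the target class. Captum approximates this integral numerically, using n_steps points along the path.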

In [14]:
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')

integrated_gradients = IntegratedGradients(model)
attributions_ig = integrated_gradients.attribute(input, target=pred_label_idx, n_steps=200)
Predicted: goose ( 0.4569333493709564 )

Let's visualize the image and corresponding attributions by overlaying the latter on the image.

In [15]:
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)

_ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
                             np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                             method='heat_map',
                             cmap=default_cmap,
                             show_colorbar=True,
                             sign='positive',
                             outlier_perc=1)

Let us compute attributions using Integrated Gradients and smooth them across multiple images generated by a noise tunnel. The noise tunnel adds Gaussian noise with a standard deviation of one to the input 10 times (nt_samples=10). Ultimately, the noise tunnel smooths the attributions across the nt_samples noisy samples using the smoothgrad_sq technique, which takes the mean of the squared attributions across the nt_samples samples.
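
Conceptually, smoothgrad_sq does something like the following sketch (illustrative only; Captum's actual implementation is vectorized and handles details such as noise scaling internally):

# Sketch of smoothgrad_sq: mean of squared attributions over noisy inputs.
noisy_attrs = []
for _ in range(10):  # nt_samples
    noisy_input = input + torch.randn_like(input)  # Gaussian noise, std 1
    attr = integrated_gradients.attribute(noisy_input, target=pred_label_idx)
    noisy_attrs.append(attr ** 2)  # square each sample's attributions
manual_smoothgrad_sq = torch.stack(noisy_attrs).mean(dim=0)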

In [16]:
noise_tunnel = NoiseTunnel(integrated_gradients)

attributions_ig_nt = noise_tunnel.attribute(input, nt_samples=10, nt_type='smoothgrad_sq', target=pred_label_idx)
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      cmap=default_cmap,
                                      show_colorbar=True)

Finally, let us use GradientShap, a linear explanation model that uses a distribution of reference samples (in this case two images) to explain predictions of the model. It computes the expectation of gradients at points chosen randomly along the path between the input and a baseline, where the baseline is drawn randomly from the given baseline distribution.
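
In formula form, GradientShap approximates SHAP values as expected gradients: with baselines x' drawn from the baseline distribution D and a random interpolation coefficient α ∼ U(0, 1),

$$\phi_i(x) \approx \mathbb{E}_{x' \sim D,\; \alpha \sim U(0,1)}\left[(x_i - x'_i)\, \frac{\partial F\big(x' + \alpha (x - x')\big)}{\partial x_i}\right]$$

In addition, Captum adds small Gaussian noise to each sampled input, controlled by the stdevs argument below.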

In [17]:
torch.manual_seed(0)
np.random.seed(0)

gradient_shap = GradientShap(model)

# Defining baseline distribution of images
rand_img_dist = torch.cat([input * 0, input * 1])

attributions_gs = gradient_shap.attribute(input,
                                          n_samples=50,
                                          stdevs=0.0001,
                                          baselines=rand_img_dist,
                                          target=pred_label_idx)
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "absolute_value"],
                                      cmap=default_cmap,
                                      show_colorbar=True)

3- Occlusion-based attribution

Now let us try a different approach to attribution. We can estimate which areas of the image are critical for the classifier's decision by occluding them and quantifying how the decision changes.

We run a sliding window of size 15x15 (defined via sliding_window_shapes) with a stride of 8 along both image dimensions (defined via strides). At each location, we occlude the image with a baseline value of 0, which corresponds to a gray patch (defined via baselines).

Note: this computation might take more than one minute to complete, as the model is evaluated at every position of the sliding window.
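
The idea can be illustrated with a single occlusion by hand (a minimal sketch, not Captum's implementation; the patch location 100:115 is an arbitrary choice):

# Occlude one 15x15 patch and measure the change in the target-class score.
with torch.no_grad():
    occluded = input.clone()
    occluded[:, :, 100:115, 100:115] = 0  # baseline value of 0
    score_orig = F.softmax(model(input), dim=1)[0, pred_label_idx]
    score_occl = F.softmax(model(occluded), dim=1)[0, pred_label_idx]
    print('score drop from occluding the patch:', (score_orig - score_occl).item())

Occlusion repeats this for every window position and attributes the resulting score change to the occluded pixels.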

In [18]:
occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(input,
                                       strides=(3, 8, 8),
                                       target=pred_label_idx,
                                       sliding_window_shapes=(3, 15, 15),
                                       baselines=0)

Let us visualize the attribution, focusing on the areas with positive attribution (those that are critical for the classifier's decision):

In [19]:
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2)

The upper part of the goose, especially the beak, seems to be the most critical for the model to predict this class.

We can verify this further by occluding the image using a larger sliding window:

In [20]:
occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(input,
                                       strides=(3, 50, 50),
                                       target=pred_label_idx,
                                       sliding_window_shapes=(3, 60, 60),
                                       baselines=0)

_ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2)

4- LRP-based attribution

Now let's try a different approach called Layer-Wise Relevance Propagation (LRP). It propagates the model's output score backward through all layers of the network, sequentially, to determine which neurons contributed to the output. The score is decomposed into relevance values layer by layer, and the decomposition is defined by propagation rules that may vary from layer to layer.
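
For example, the Epsilon-Rule (the default used below) redistributes the relevance R_k of a neuron k to its inputs j in proportion to their contributions a_j w_{jk}, with a small ε stabilizing the denominator. Shown here in its basic form; implementations differ in details such as the sign of the stabilizer:

$$R_j = \sum_k \frac{a_j w_{jk}}{\epsilon + \sum_{j'} a_{j'} w_{j'k}}\, R_k$$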

Initially, we apply a direct implementation of LRP attribution. The default Epsilon-Rule is used for each layer.

Note: We use the VGG16 model instead here since the default rules for LRP are not fine-tuned for ResNet currently.

In [22]:
model = models.vgg16(pretrained=True)
model.eval()
lrp = LRP(model)

attributions_lrp = lrp.attribute(input,
                                 target=pred_label_idx)

Let us visualize the attribution, focusing on the areas with positive attribution (those that are critical for the classifier's decision):

In [24]:
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_lrp.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2)

Now let's play around with changing the propagation rules for the various layers. This is a crucial step to get expressive heatmaps. Captum currently has the following propagation rules implemented: LRP-Epsilon, LRP-0, LRP-Gamma, LRP-Alpha-Beta, and the Identity-Rule.

In the next steps, we list all the layers of VGG16 and assign a rule to each one.

Note: Recommendations on how to set the rules can be found in Towards best practice in explaining neural network decisions with LRP.

In [25]:
# VGG16's "features" block contributes layer indices 0-30; the
# "classifier" block contributes indices 31-37.
layers = list(model._modules["features"]) + list(model._modules["classifier"])
num_layers = len(layers)

for idx_layer in range(1, num_layers):
    if idx_layer <= 16:
        # lower convolutional layers: LRP-Gamma
        setattr(layers[idx_layer], "rule", GammaRule())
    elif 17 <= idx_layer <= 30:
        # upper convolutional layers: LRP-Epsilon
        setattr(layers[idx_layer], "rule", EpsilonRule())
    elif idx_layer >= 31:
        # fully connected classifier layers: LRP-0 (epsilon=0)
        setattr(layers[idx_layer], "rule", EpsilonRule(epsilon=0))

lrp = LRP(model)
attributions_lrp = lrp.attribute(input,
                                 target=pred_label_idx)

Let us visualize the new attribution. As we can see in the generated output image, the heatmap shows clearly positive attributions for the beak of the swan.

In [26]:
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_lrp.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2)