Multi-Modal
Supports interpretability of models across modalities including vision, text, and more.
Built on PyTorch
Supports most types of PyTorch models and can be used with minimal modification to the original neural network.
Extensible
Open source, generic library for interpretability research. Easily implement and benchmark new algorithms.
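For a sense of what implementing a new algorithm can look like, below is a minimal sketch of an input-times-gradient attribution written against plain PyTorch. The function name and structure are illustrative only (Captum ships its own version of this method as captum.attr.InputXGradient); a real contribution would plug into Captum's attribution base classes.

import torch

def input_x_gradient(model, inputs, target):
    # prototype of a custom attribution: elementwise input * gradient
    inputs = inputs.clone().requires_grad_(True)
    # select the target output column and reduce to a scalar for autograd
    output = model(inputs)[:, target].sum()
    grads = torch.autograd.grad(output, inputs)[0]
    return inputs.detach() * grads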
Captum can be installed via conda or pip:

conda install captum -c pytorch
pip install captum
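To confirm the installation, import the package and print its version (recent Captum releases expose __version__; treat the attribute as an assumption if you are on an older release):

import captum

print(captum.__version__)  # e.g. '0.7.0'; exact value depends on your install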
import numpy as np
import torch
import torch.nn as nn

from captum.attr import IntegratedGradients

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases deterministically
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1, 3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1, 2))

    def forward(self, input):
        return self.lin2(self.relu(self.lin1(input)))

model = ToyModel()
model.eval()

# fix the seeds so the random input (and hence the attributions) are reproducible
torch.manual_seed(123)
np.random.seed(123)

input = torch.rand(2, 3)
baseline = torch.zeros(2, 3)

# attribute the output for target class 0 to the input features,
# integrating gradients along the path from baseline to input
ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
print('IG Attributions:', attributions)
print('Convergence Delta:', delta)
Output:

IG Attributions: tensor([[-0.5922, -1.5497, -1.0067],
                         [ 0.0000, -0.2219, -5.1991]])
Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
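The convergence delta reported above is the gap between the total attribution and the actual change in the model's output, since Integrated Gradients satisfies the completeness property: per example, the attributions should sum to F(input) - F(baseline) for the chosen target. Continuing from the snippet above, the check below reproduces that property by hand; a delta on the order of 1e-07 means the two quantities agree almost exactly.

# completeness check: per-example sum of attributions should equal
# F(input) - F(baseline) for target 0, up to the reported delta
with torch.no_grad():
    expected = model(input)[:, 0] - model(baseline)[:, 0]
print('Sum of attributions:', attributions.sum(dim=1))
print('F(input) - F(baseline):', expected)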