This tutorial demonstrates how to apply the TracInCP algorithm from the Captum library to find influential training examples. TracInCP calculates the influence score of a given training example on a given test example, which, roughly speaking, represents how much higher the loss for the given test example would be if the given training example were removed from the training dataset and the model re-trained. This functionality can be leveraged towards multiple use cases; below, we use it to identify the training examples most responsible for given test predictions (proponents and opponents).
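To make this quantity concrete, here is a minimal sketch of the sum TracInCP approximates, following the TracIn paper. This is NOT Captum's implementation; loss_at, z_train, and z_test are hypothetical placeholders (a function computing the loss on a single example, and a training / test example, respectively).

import torch

# A minimal sketch, assuming the TracIn formulation: the influence of training
# example z_train on test example z_test is the sum, over saved checkpoints,
# of the learning rate times the dot product of the parameter gradients of
# the loss at z_train and at z_test.
def tracin_influence_sketch(model, checkpoint_paths, checkpoints_load_func, loss_at, z_train, z_test):
    score = 0.0
    for path in checkpoint_paths:
        lr = checkpoints_load_func(model, path)  # loads weights, returns learning rate
        params = [p for p in model.parameters() if p.requires_grad]
        grads_train = torch.autograd.grad(loss_at(model, z_train), params)
        grads_test = torch.autograd.grad(loss_at(model, z_test), params)
        score += lr * sum((g1 * g2).sum().item() for g1, g2 in zip(grads_train, grads_test))
    return score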
TracInCP can be used with any trained PyTorch model for which several model checkpoints are available.
Note: before running this tutorial, please install the packages it imports beyond Captum and PyTorch; in particular torchvision, matplotlib, scikit-learn, and the annoy package used by TracInCPFastRandProj's nearest-neighbors data structure.
Currently, Captum offers 3 implementations, all of which expose the same API. More specifically, they define an influence
method, which can be used in 2 different modes:

- "influence score" mode: computes the influence score of every training example on each test example in a batch.
- "top-k most influential" mode: returns, for each test example, the k training examples with the most positive (proponents) or most negative (opponents) influence scores.
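For orientation, the two modes look roughly like this (a sketch only; tracin, features, and labels are placeholders, and the real calls appear later in this tutorial):

# "influence score" mode: k is not specified; returns the influence score of
# every training example on each test example in the batch
#   scores = tracin.influence((features, labels))
#
# "top-k most influential" mode: k is specified; returns the indices and
# influence scores of the k strongest proponents (or opponents) per test example
#   indices, influence_scores = tracin.influence((features, labels), k=10, proponents=True)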
The 3 different implementations are defined in the following classes:

- TracInCP: considers gradients in all specified layers when computing influence scores. Specifying many layers will slow the execution of all modes.
- TracInCPFast: in Appendix F of the TracIn paper, the authors show that if only gradients in the last fully-connected layer are considered when computing influence scores, the computation can be done more quickly than by naively applying backprop, using a computational trick. TracInCPFast computes influence scores, considering only the last fully-connected layer, using that trick. TracInCPFast is useful if you want to reduce time and memory usage relative to TracInCP.
- TracInCPFastRandProj: the previous two classes are not meant for "interactive" use, because each call to influence in "influence score" mode or "top-k most influential" mode takes time proportional to the training dataset size. On the other hand, TracInCPFastRandProj enables "interactive" use, i.e. constant-time calls to influence for those two modes. The price we pay is that in TracInCPFastRandProj.__init__, pre-processing is done to store embeddings related to each training example in a nearest-neighbors data structure. This pre-processing takes both time and memory proportional to the training dataset size. Furthermore, random projections can be applied to reduce memory usage, at the cost of the influence scores used in those two modes being only approximately correct. Like TracInCPFast, this class only considers gradients in the last fully-connected layer, and is useful if you want to reduce time and memory usage relative to TracInCP.

%matplotlib inline
%load_ext autoreload
%autoreload 2
import datetime
import glob
import os
import pickle
import warnings
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from captum.influence import TracInCP, TracInCPFast, TracInCPFastRandProj
from sklearn.metrics import auc, roc_curve
from torch.utils.data import DataLoader, Dataset, Subset
warnings.filterwarnings("ignore")
First, we will illustrate the ability of TracInCP to identify influential examples, i.e. use the influence method in "top-k most influential" mode. To do this, we need 3 components:

- A model, net.
- The Dataset used to train net. For this we will use correct_dataset, the original CIFAR-10 training split.
- A Dataset from which to select test examples, test_dataset. For this, we will use the original CIFAR-10 validation split. This will also be useful for monitoring training.

Train net with correct_dataset

We will train and save checkpoints.

Define net

We will use a relatively simple model from the following tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
We first define the architecture of net:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.relu4 = nn.ReLU()

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x
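As a quick optional sanity check on the architecture (not part of the original tutorial), a batch of CIFAR-10-sized inputs should map to one logit per class:

dummy_batch = torch.randn(2, 3, 32, 32)  # two fake 3x32x32 images
assert Net()(dummy_batch).shape == (2, 10)  # one logit per CIFAR-10 class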
In the cell below, we initialize net.
net = Net()
Because both datasets consist of images, we will first define the normalize and inverse_normalize transforms, which convert an image to a model input and a model input back to an image, respectively.
normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
inverse_normalize = transforms.Compose([
transforms.Normalize(mean = [0., 0., 0.], std = [1/0.5, 1/0.5, 1/0.5]),
transforms.Normalize(mean = [-0.5, -0.5, -0.5], std = [1., 1., 1.]),
])
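A quick optional check (not in the original tutorial) that inverse_normalize really inverts the Normalize step above, up to floating point error:

img = torch.rand(3, 32, 32)
normalized = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
assert torch.allclose(inverse_normalize(normalized), img, atol=1e-6)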
Define correct_dataset

correct_dataset_path = "data/cifar_10"
correct_dataset = torchvision.datasets.CIFAR10(root=correct_dataset_path, train=True, download=True, transform=normalize)
Files already downloaded and verified
Define test_dataset

This will be the same as correct_dataset, so it shares the same path and transform. The only difference is that it uses the validation split.
test_dataset = torchvision.datasets.CIFAR10(root=correct_dataset_path, train=False, download=True, transform=normalize)
Files already downloaded and verified
We will obtain checkpoints by training net for 26 epochs on correct_dataset. In general, there should be at least 5 checkpoints; they can be evenly spaced throughout training or, better yet, taken at epochs where the loss decreased a lot.
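Concretely, with the schedule used below (26 epochs, saving every 5th epoch), the saved checkpoints are evenly spaced:

saved_epochs = [epoch for epoch in range(26) if epoch % 5 == 0]
print(saved_epochs)  # [0, 5, 10, 15, 20, 25]: six evenly spaced checkpoints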
We first define a training function, adapted from the above tutorial (with checkpoint saving and validation-accuracy reporting added):
def train(net, num_epochs, train_dataloader, test_dataloader, checkpoints_dir, save_every):
    start_time = datetime.datetime.now()
    if not os.path.exists(checkpoints_dir):
        os.makedirs(checkpoints_dir)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        epoch_loss = 0.0
        running_loss = 0.0
        for i, data in enumerate(train_dataloader):
            # get the inputs
            inputs, labels = data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if (i + 1) % 100 == 0:  # print every 100 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 100))
                epoch_loss += running_loss
                running_loss = 0.0
        if epoch % save_every == 0:
            checkpoint_name = "-".join(["checkpoint", str(epoch) + ".pt"])
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": epoch_loss,
                },
                os.path.join(checkpoints_dir, checkpoint_name),
            )
        # calculate validation accuracy
        correct = 0
        total = 0
        # since we're not training, we don't need to calculate the gradients for our outputs
        with torch.no_grad():
            for data in test_dataloader:
                images, labels = data
                # calculate outputs by running images through the network
                outputs = net(images)
                # the class with the highest energy is what we choose as prediction
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print("Accuracy of the network on test set at epoch %d: %d %%" % (epoch, 100 * correct / total))
    total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
    print("Finished training in %.2f minutes" % total_minutes)
We then define the folder to save checkpoints in. We will need this folder later to run TracInCP algorithms.
correct_dataset_checkpoints_dir = os.path.join("checkpoints", "cifar_10_correct_dataset")
Finally, we train the model, wrapping correct_dataset and test_dataset in DataLoaders, and saving a checkpoint every 5 epochs.

For this tutorial, we have saved the checkpoints from this training on AWS S3, and you can just download those checkpoints instead of doing time-intensive training. If you want to do training yourself, please set the do_training flag in the next cell to True.
num_epochs = 26
do_training = False  # change to `True` if you want to do training

if do_training:
    train(
        net,
        num_epochs,
        DataLoader(correct_dataset, batch_size=128, shuffle=True),
        DataLoader(test_dataset, batch_size=128, shuffle=True),
        correct_dataset_checkpoints_dir,
        save_every=5,
    )
elif not os.path.exists(correct_dataset_checkpoints_dir):
    # this should download the zipped folder of checkpoints from the S3 bucket,
    # then unzip the folder to produce checkpoints in the folder `checkpoints/cifar_10_correct_dataset`.
    # this is done only if checkpoints do not already exist in the folder.
    # if the below commands do not work, please manually download and unzip the folder to produce checkpoints in that folder.
    os.makedirs(correct_dataset_checkpoints_dir)
    !wget https://pytorch.s3.amazonaws.com/models/captum/influence-tutorials/cifar_10_correct_dataset.zip -O checkpoints/cifar_10_correct_dataset.zip
    !unzip -o checkpoints/cifar_10_correct_dataset.zip -d checkpoints
We define the list of checkpoints, correct_dataset_checkpoint_paths, to be all checkpoints from training.
correct_dataset_checkpoint_paths = glob.glob(os.path.join(correct_dataset_checkpoints_dir, "*.pt"))
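Optionally, we can inspect what was found. Note that glob returns paths in arbitrary order, which is fine here, since TracInCP sums contributions over checkpoints:

print(sorted(correct_dataset_checkpoint_paths))  # expect six files, checkpoint-0.pt through checkpoint-25.pt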
We also define a function that loads a given checkpoint into a given model. This will be useful immediately, as well as for use in all TracInCP implementations. When used in TracInCP implementations, this function should return the learning rate at the checkpoint. However, if that learning rate is not available, it is safe to simply return 1, as we do, because it turns out TracInCP implementations are not sensitive to that learning rate.
def checkpoints_load_func(net, path):
    weights = torch.load(path)
    net.load_state_dict(weights["model_state_dict"])
    return 1.
We first load net with the last checkpoint, so that the predictions we make in the next cell will be for the trained model. We save this last checkpoint's path as correct_dataset_final_checkpoint, because we will re-use this checkpoint later on.
correct_dataset_final_checkpoint = os.path.join(correct_dataset_checkpoints_dir, "-".join(['checkpoint', str(num_epochs - 1) + '.pt']))
checkpoints_load_func(net, correct_dataset_final_checkpoint)
1.0
Now, we define test_examples_features, the features for a batch of test examples for which to identify influential examples, and also store their true as well as predicted labels.
test_examples_indices = [0,1,2,3]
test_examples_features = torch.stack([test_dataset[i][0] for i in test_examples_indices])
test_examples_predicted_probs, test_examples_predicted_labels = torch.max(F.softmax(net(test_examples_features), dim=1), dim=1)
test_examples_true_labels = torch.Tensor([test_dataset[i][1] for i in test_examples_indices]).long()
Recall from above that there are several implementations of the TracInCP algorithm. In particular, TracInCP is more time and memory intensive than TracInCPFast and TracInCPFastRandProj. For this tutorial, to save time, we will only use TracInCPFast and TracInCPFastRandProj.

To choose between TracInCPFast and TracInCPFastRandProj, recall that TracInCPFastRandProj is suitable for "interactive" use, when multiple calls to the influence method in "influence score" and "top-k most influential" mode will be made. In return for the "interactive" use capability, TracInCPFastRandProj requires an initial pre-processing, which can be both time and memory intensive. On the other hand, TracInCPFast does not support "interactive" use, but avoids the initial pre-processing.
Define the TracInCPFast instance

We will first illustrate the use of TracInCPFast, to avoid the initial pre-processing (since we will only call the influence method once, we would not be taking advantage of the "interactive" use capability).

To fully define the TracInCPFast implementation, several more parameters need to be defined:

- final_fc_layer: a reference to, or the name of, the last fully-connected layer, whose gradients will be used to calculate influence scores. This must be the last layer.
- loss_fn: the loss function used in training.
- batch_size: the batch size of training data used for calculating influence scores. It does not affect the actual influence scores computed, but can affect computational efficiency. In particular, the fewer batches needed to iterate through the training data, the faster influence is in all modes. This is because influence loads model checkpoints once for each batch. So batch_size should be set large, but not too large (or else the batches will not fit in memory).
- vectorize: whether to use an experimental feature that accelerates Jacobian computation. Only available in PyTorch version > 1.6.

We are now ready to create the TracInCPFast instance:
tracin_cp_fast = TracInCPFast(
    model=net,
    final_fc_layer=net.fc3,  # the last fully-connected layer; note `list(net.children())[-1]` would be `relu4` here
    train_dataset=correct_dataset,
    checkpoints=correct_dataset_checkpoint_paths,
    checkpoints_load_func=checkpoints_load_func,
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    batch_size=2048,
    vectorize=False,
)
Compute proponents / opponents using TracInCPFast

Now, we call the influence method of tracin_cp_fast to compute the influential examples for the test examples represented by test_examples_features and test_examples_true_labels. We need to specify whether we want proponents or opponents via the proponents boolean argument, and how many influential examples to return per test example via the k argument. Note that k must be specified; otherwise, "influence score" mode will be run. This call should take < 2 minutes.
Note that we pass the test examples as a single tuple. This is because for all implementations, when we pass a single batch, batch, to the influence method, we assume that batch[-1] contains the labels for the batch, and model(*(batch[0:-1])) produces the predictions for the batch, so that batch[0:-1] contains the features for the batch. This convention was introduced in a recent API change.
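To spell out this convention with the variables defined above (an illustration only):

batch = (test_examples_features, test_examples_true_labels)
labels = batch[-1]             # the last element holds the labels
features = batch[0:-1]         # everything before it holds the features
predictions = net(*features)   # i.e. model(*(batch[0:-1])) produces predictions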
This call returns a namedtuple with ordered elements (indices, influence_scores). indices is a 2D tensor of shape (test_batch_size, k), where test_batch_size is the number of test examples in the test batch; indices[i][j] is the index in the training dataset of the j-th strongest proponent / opponent of test example i. influence_scores has the same shape, but stores the influence scores of the proponents / opponents for each test example in sorted order. For example, if proponents is True, influence_scores[i][j] is the influence score of the training example with the j-th most positive influence score on test example i.
k = 10
start_time = datetime.datetime.now()
proponents_indices, proponents_influence_scores = tracin_cp_fast.influence(
    (test_examples_features, test_examples_true_labels), k=k, proponents=True
)
opponents_indices, opponents_influence_scores = tracin_cp_fast.influence(
    (test_examples_features, test_examples_true_labels), k=k, proponents=False
)
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print(
    "Computed proponents / opponents over a dataset of %d examples in %.2f minutes"
    % (len(correct_dataset), total_minutes)
)
Computed proponents / opponents over a dataset of 50000 examples in 1.11 minutes
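As an optional check of the sorted-order convention described above (assuming opponent scores come back most-negative first):

# proponent scores should be non-increasing, opponent scores non-decreasing
assert (proponents_influence_scores[:, :-1] >= proponents_influence_scores[:, 1:]).all()
assert (opponents_influence_scores[:, :-1] <= opponents_influence_scores[:, 1:]).all()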
In order to display results, we define a few helpers: a function that displays a test example, a function that displays a set of training examples, a helper transform going from a tensor in the datasets to a tensor suitable for the matplotlib imshow function, and a mapping from numerical label (e.g. 4) to class (e.g. "cat").
label_to_class = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
imshow_transform = lambda tensor_in_dataset: inverse_normalize(tensor_in_dataset.squeeze()).permute(1, 2, 0)
def display_test_example(example, true_label, predicted_label, predicted_prob, label_to_class):
    fig, ax = plt.subplots()
    print('true_class:', label_to_class[true_label])
    print('predicted_class:', label_to_class[predicted_label])
    print('predicted_prob', predicted_prob)
    ax.imshow(torch.clip(imshow_transform(example), 0, 1))
    plt.show()
def display_training_examples(examples, true_labels, label_to_class, figsize=(10, 4)):
    fig = plt.figure(figsize=figsize)
    num_examples = len(examples)
    for i in range(num_examples):
        ax = fig.add_subplot(1, num_examples, i + 1)
        ax.imshow(torch.clip(imshow_transform(examples[i]), 0, 1))
        ax.set_title(label_to_class[true_labels[i]])
    plt.show()
    return fig
def display_proponents_and_opponents(test_examples_batch, proponents_indices, opponents_indices, test_examples_true_labels, test_examples_predicted_labels, test_examples_predicted_probs):
    for (
        test_example,
        test_example_proponents,
        test_example_opponents,
        test_example_true_label,
        test_example_predicted_label,
        test_example_predicted_prob,
    ) in zip(
        test_examples_batch,
        proponents_indices,
        opponents_indices,
        test_examples_true_labels,
        test_examples_predicted_labels,
        test_examples_predicted_probs,
    ):
        print("test example:")
        display_test_example(
            test_example,
            test_example_true_label,
            test_example_predicted_label,
            test_example_predicted_prob,
            label_to_class,
        )
        print("proponents:")
        test_example_proponents_tensors, test_example_proponents_labels = zip(
            *[correct_dataset[i] for i in test_example_proponents]
        )
        display_training_examples(
            test_example_proponents_tensors, test_example_proponents_labels, label_to_class, figsize=(20, 8)
        )
        print("opponents:")
        test_example_opponents_tensors, test_example_opponents_labels = zip(
            *[correct_dataset[i] for i in test_example_opponents]
        )
        display_training_examples(
            test_example_opponents_tensors, test_example_opponents_labels, label_to_class, figsize=(20, 8)
        )
We can display, for each test example, its proponents and opponents:

display_proponents_and_opponents(
    test_examples_features,
    proponents_indices,
    opponents_indices,
    test_examples_true_labels,
    test_examples_predicted_labels,
    test_examples_predicted_probs,
)
test example: true_class: cat predicted_class: cat predicted_prob tensor(0.4126, grad_fn=<UnbindBackward0>)
proponents: [image grid of the top-10 proponents]
opponents: [image grid of the top-10 opponents]
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.5685, grad_fn=<UnbindBackward0>)
proponents: [image grid of the top-10 proponents]
opponents: [image grid of the top-10 opponents]
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.3574, grad_fn=<UnbindBackward0>)
proponents: [image grid of the top-10 proponents]
opponents: [image grid of the top-10 opponents]
test example: true_class: plane predicted_class: ship predicted_prob tensor(0.6398, grad_fn=<UnbindBackward0>)
proponents: [image grid of the top-10 proponents]
opponents: [image grid of the top-10 opponents]
We see that the results make intuitive sense. For example, the proponents of a test example that is a cat are all cats, labelled as cats. On the other hand, the opponents are all animals that look somewhat like cats, but are labelled as other animals (e.g. dogs). Thus the presence of these opponents drives the prediction on the test example away from cat.
Define the TracInCPFastRandProj instance

We also define and use a TracInCPFastRandProj instance to show its pros and cons.

Note that __init__ has 2 new arguments, due to the fact that TracInCPFastRandProj stores embeddings related to each example in the training dataset in a nearest-neighbors data structure:

- nearest_neighbors: the nearest-neighbors class used internally to find proponents / opponents quickly (proponents / opponents of a test example are those training examples whose embeddings of a certain kind are similar / dissimilar to those of the test example; see the TracIn paper for more details). Currently, only a single nearest-neighbors class is offered: AnnoyNearestNeighbors, which wraps the Annoy library. This class has a single argument, num_trees, the number of trees to use. Increasing this number gives a more accurate computation of nearest neighbors, but requires a longer setup time to create the trees, as well as more memory.
- projection_dim: the embeddings may be too high-dimensional and require too much memory. Random projections can be used to reduce the dimension of those embeddings. This argument specifies that dimension (it corresponds to the d variable on page 15 of the Appendix of the TracIn paper). In more detail, the embedding is the concatenation of several "checkpoint-embeddings", each of which corresponds to a particular checkpoint, so the dimension of the overall embedding is actually projection_dim times the number of checkpoints used.

Note: initialization will take ~10 minutes, so feel free to skip the tutorial parts related to TracInCPFastRandProj.
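As a rough, illustrative back-of-envelope (assuming the embeddings are stored as float32; actual memory use also depends on implementation details), with the projection_dim=100 used below:

num_checkpoints = len(correct_dataset_checkpoint_paths)
embedding_dim = 100 * num_checkpoints  # projection_dim times the number of checkpoints
approx_mb = len(correct_dataset) * embedding_dim * 4 / 1e6  # 4 bytes per float32
print("approximate embedding storage: %.0f MB" % approx_mb)  # ~120 MB for 6 checkpoints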
from captum.influence._utils.nearest_neighbors import AnnoyNearestNeighbors
start_time = datetime.datetime.now()
tracin_cp_fast_rand_proj = TracInCPFastRandProj(
    model=net,
    final_fc_layer=net.fc3,  # the last fully-connected layer, as for `tracin_cp_fast`
    train_dataset=correct_dataset,
    checkpoints=correct_dataset_checkpoint_paths,
    checkpoints_load_func=checkpoints_load_func,
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    batch_size=128,
    nearest_neighbors=AnnoyNearestNeighbors(num_trees=100),
    projection_dim=100,
)
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print(
    "Performed pre-processing of a dataset of %d examples in %.2f minutes"
    % (len(correct_dataset), total_minutes)
)
Performed pre-processing of a dataset of 50000 examples in 4.98 minutes
Compute proponents / opponents using TracInCPFastRandProj

As before, we can compute the proponents / opponents using the influence method of this TracInCPFastRandProj instance. Unlike with the TracInCPFast instance, this computation should be very fast, due to the pre-processing done during initialization.
k = 10
start_time = datetime.datetime.now()
proponents_indices, proponents_influence_scores = tracin_cp_fast_rand_proj.influence(
    (test_examples_features, test_examples_true_labels), k=k, proponents=True
)
opponents_indices, opponents_influence_scores = tracin_cp_fast_rand_proj.influence(
    (test_examples_features, test_examples_true_labels), k=k, proponents=False
)
total_minutes = (datetime.datetime.now() - start_time).total_seconds() / 60.0
print(
    "Computed proponents / opponents over a dataset of %d examples in %.2f minutes"
    % (len(correct_dataset), total_minutes)
)
Computed proponents / opponents over a dataset of 50000 examples in 0.01 minutes
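Since random projections make these scores only approximately correct, an optional way to gauge the approximation (at the cost of re-running the slower exact computation) is to measure the overlap between the exact and approximate top-k proponents:

# compare TracInCPFast (exact) against TracInCPFastRandProj (approximate)
exact = tracin_cp_fast.influence(
    (test_examples_features, test_examples_true_labels), k=k, proponents=True
)
approx = tracin_cp_fast_rand_proj.influence(
    (test_examples_features, test_examples_true_labels), k=k, proponents=True
)
for i in range(len(test_examples_indices)):
    overlap = len(set(exact.indices[i].tolist()) & set(approx.indices[i].tolist()))
    print("test example %d: %d / %d top proponents agree" % (i, overlap, k))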
We can display, for each test example, its proponents and opponents:
display_proponents_and_opponents(
    test_examples_features,
    proponents_indices,
    opponents_indices,
    test_examples_true_labels,
    test_examples_predicted_labels,
    test_examples_predicted_probs,
)
test example: true_class: cat predicted_class: cat predicted_prob tensor(0.4126, grad_fn=<UnbindBackward0>)
proponents: [image grid of the top-10 proponents]
opponents: [image grid of the top-10 opponents]
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.5685, grad_fn=<UnbindBackward0>)
proponents: [image grid of the top-10 proponents]
opponents: [image grid of the top-10 opponents]
test example: true_class: ship predicted_class: ship predicted_prob tensor(0.3574, grad_fn=<UnbindBackward0>)
proponents: [image grid of the top-10 proponents]
opponents: [image grid of the top-10 opponents]