This tutorial shows how to apply TCAV, a concept-based model interpretability algorithm, to a classification task using the GoogleNet model and the ImageNet dataset.
More details about the approach can be found here: https://arxiv.org/pdf/1711.11279.pdf
In order to use TCAV, we need to predefine a list of concepts that we want our predictions to be tested against.
Concepts are human-understandable, high-level abstractions such as visually represented "stripes" in the case of images, or tokens such as "female" in the case of text. Concepts are formatted and represented as input tensors and do not need to be part of the training or test datasets.
Concepts are incorporated into the importance score computations using Concept Activation Vectors (CAVs). To obtain a CAV, we train a linear classifier that learns a decision boundary between different concepts, using the activations of the predefined concepts' examples in a NN layer as inputs. The vector that is orthogonal to the learnt decision boundary and points toward the direction of a concept is the CAV of that concept.
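To make the CAV construction concrete, here is a minimal sketch using scikit-learn. The activations are randomly generated stand-ins, and all names and the 512-dimensional activation size are illustrative assumptions, not Captum's internals:

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

# Hypothetical flattened layer activations: one row per image, for a
# concept set (label 1) and a random set (label 0).
rng = np.random.default_rng(0)
concept_acts = rng.normal(1.0, 1.0, size=(120, 512))
random_acts = rng.normal(0.0, 1.0, size=(120, 512))

X = np.concatenate([concept_acts, random_acts])
y = np.concatenate([np.ones(120), np.zeros(120)])

# A linear classifier separates the two activation sets; its weight vector
# is orthogonal to the learnt decision boundary. Normalized and oriented
# toward the concept class, it plays the role of the CAV.
clf = SGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3, random_state=0)
clf.fit(X, y)
cav = clf.coef_[0] / np.linalg.norm(clf.coef_[0])
print(cav.shape)  # (512,)
```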
TCAV measures the importance of a concept for a prediction based on the directional sensitivity (directional derivatives) of the prediction with respect to that concept in Neural Network (NN) layers. For a given concept and layer, it is obtained by aggregating the dot product between the CAV for that concept in that layer and the gradients of the model predictions w.r.t. that layer's output. The aggregation can be performed based on either the signs or the magnitudes of the directional sensitivities across multiple examples belonging to a certain class. More details about the technique can be found in the paper mentioned above.
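The aggregation step can be sketched as follows, using randomly generated stand-ins for the layer gradients and the CAV (all names and shapes here are illustrative assumptions, not Captum's implementation):

```python
import numpy as np

# `grads`: gradients of the model output for the target class w.r.t. the
# chosen layer's output, one row per input example (hypothetical values).
# `cav`: the concept activation vector for that layer (hypothetical values).
rng = np.random.default_rng(0)
grads = rng.normal(size=(50, 512))
cav = rng.normal(size=512)

# Directional derivative of the prediction along the CAV direction.
sensitivities = grads @ cav

# Sign-count TCAV score: fraction of examples whose prediction moves
# in the concept's direction.
tcav_sign = (sensitivities > 0).mean()

# Magnitude variant: relative contribution of the positive sensitivities.
tcav_magnitude = sensitivities[sensitivities > 0].sum() / np.abs(sensitivities).sum()
print(tcav_sign, tcav_magnitude)
```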
Note: Before running this tutorial, please install the torchvision, numpy, scipy, scikit-learn, Pillow (PIL), and matplotlib packages.
```python
import numpy as np
import os, glob, sys

import matplotlib.pyplot as plt
from PIL import Image
from scipy.stats import ttest_ind

# ..........torch imports............
import torch
import torchvision

from torch.utils.data import IterableDataset, DataLoader
from torchvision import transforms

# .... Captum imports..................
from captum.attr import LayerGradientXActivation, LayerIntegratedGradients

from captum.concept import TCAV
from captum.concept import Concept

from captum.concept._utils.data_iterator import dataset_to_dataloader, CustomIterableDataset
from captum.concept._utils.common import concepts_to_str
```
Let's define the image transformation function.
```python
# Method to normalize an image to ImageNet mean and standard deviation
def transform(img):
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )(img)
```
Now let's define a few helper functions.
```python
def get_tensor_from_filename(filename):
    img = Image.open(filename).convert("RGB")
    return transform(img)


def load_image_tensors(class_name, root_path='data/tcav/image/imagenet/', transform_img=True):
    # Note: the flag is named `transform_img` so that it does not shadow
    # the global `transform` function defined above.
    path = os.path.join(root_path, class_name)
    filenames = glob.glob(path + '/*.jpg')

    tensors = []
    for filename in filenames:
        img = Image.open(filename).convert('RGB')
        tensors.append(transform(img) if transform_img else img)

    return tensors
```
Next, we define a helper function to load predefined concepts. The assemble_concept function reads the concept images from the directory where they reside and constructs a Concept object.
```python
def assemble_concept(name, id, concepts_path="data/tcav/image/concepts/"):
    concept_path = os.path.join(concepts_path, name) + "/"
    dataset = CustomIterableDataset(get_tensor_from_filename, concept_path)
    concept_iter = dataset_to_dataloader(dataset)

    return Concept(id=id, name=name, data_iter=concept_iter)
```
Let's assemble the concepts into Concept instances using the Concept class and the concept images stored under concepts_path. We will use these concepts later in our experiments.
Below we define five concepts. Three of the five are related to image texture and patterns: striped, zigzagged, and dotted. The other two are random concepts. Random concepts contain elements / images that are associated with many different possible concepts, distinct from those defined for the texture concepts.
Note that the concept images should be created and stored under the data/tcav/image/concepts/ folder in advance.
The striped, zigzagged, and dotted concept images can be found in the Broden dataset (https://netdissect.csail.mit.edu/broden1_224), under the images/dtd folder, as also described here: https://github.com/tensorflow/tcav/tree/master/tcav/tcav_examples/image_models/imagenet. There are in total 120 images for each of the three concepts. Please download those concept images and place them under the striped, zigzagged, and dotted folders in the data/tcav/image/concepts/ folder accordingly.
Random concepts are uniformly sampled from the ImageNet dataset. More details on how to download and set up the ImageNet dataset can be found here: https://github.com/tensorflow/tcav/tree/master/tcav/tcav_examples/image_models/imagenet
We will randomly sample 4 different sets of 120 random images from the ImageNet dataset. Note that these images should be distinct from the concept and zebra images, since the testing will be performed on zebra images.
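One possible way to prepare those folders is sketched below. The `imagenet_dir` argument, the `make_random_concept_folders` helper, and the `*.JPEG` extension are assumptions about a local ImageNet copy, not part of the tutorial:

```python
import glob, os, random, shutil

def make_random_concept_folders(imagenet_dir, out_dir="data/tcav/image/concepts/",
                                n_sets=4, n_images=120, seed=0):
    # Collect candidate images from the (hypothetical) local ImageNet dump.
    random.seed(seed)
    all_images = glob.glob(os.path.join(imagenet_dir, "**", "*.JPEG"), recursive=True)

    # Draw one disjoint sample so the n_sets folders do not overlap.
    sample = random.sample(all_images, n_sets * n_images)

    for i in range(n_sets):
        folder = os.path.join(out_dir, "random_%d" % i)
        os.makedirs(folder, exist_ok=True)
        for src in sample[i * n_images:(i + 1) * n_images]:
            shutil.copy(src, folder)
```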
Place the random images under the random_0, random_1, random_2, and random_3 folders in the data/tcav/image/concepts/ folder, similar to the texture concepts above.
```python
concepts_path = "data/tcav/image/concepts/"

stripes_concept = assemble_concept("striped", 0, concepts_path=concepts_path)
zigzagged_concept = assemble_concept("zigzagged", 1, concepts_path=concepts_path)
dotted_concept = assemble_concept("dotted", 2, concepts_path=concepts_path)

random_0_concept = assemble_concept("random_0", 3, concepts_path=concepts_path)
random_1_concept = assemble_concept("random_1", 4, concepts_path=concepts_path)
```
Let's visualize some samples from these concepts.
```python
n_figs = 5
n_concepts = 5

fig, axs = plt.subplots(n_concepts, n_figs + 1, figsize=(25, 4 * n_concepts))

for c, concept in enumerate([stripes_concept, zigzagged_concept, dotted_concept,
                             random_0_concept, random_1_concept]):
    concept_path = os.path.join(concepts_path, concept.name) + "/"
    img_files = glob.glob(concept_path + '*')
    for i, img_file in enumerate(img_files[:n_figs + 1]):
        if os.path.isfile(img_file):
            if i == 0:
                axs[c, i].text(1.0, 0.5, str(concept.name), ha='right', va='center',
                               family='sans-serif', size=24)
            else:
                img = plt.imread(img_file)
                axs[c, i].imshow(img)

            axs[c, i].axis('off')
```