Concept-based Interpretability

TCAV

class captum.concept.TCAV(model, layers, model_id='default_model_id', classifier=None, layer_attr_method=None, attribute_to_layer_input=False, save_path='./cav/', **classifier_kwargs)

This class implements the ConceptInterpreter abstract class using an approach called Testing with Concept Activation Vectors (TCAV), as described in the paper: https://arxiv.org/abs/1711.11279

TCAV scores for a given layer, a list of concepts and a set of input examples are computed as the dot product between the layer sensitivities of the model's predictions for those input examples and the Concept Activation Vectors (CAVs) in that same layer.

CAVs are defined as vectors that are orthogonal to the classification boundary hyperplane that separates the given concepts from each other in a given layer. For a given layer, CAVs are computed by training a classifier that takes the layer activation vectors of a set of concept examples as inputs and the concept ids as the corresponding labels. The trained weights of that classifier represent the CAVs.

CAVs are represented as a learned weight matrix with dimensionality C x F, where C is the number of concepts used for the classification and F is the number of input features of the classifier. Concept ids are used as labels for concept examples during training.
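To make the geometry concrete, here is a minimal pure-Python sketch (not Captum code) with two concepts and two features. It assumes a hypothetical linear decision boundary w · x + b = 0 with made-up values for w and b, checks that w is orthogonal to a direction lying in that boundary, and expands the single weight row of a binary classifier into a C x F = 2 x 2 matrix with one row per concept.

```python
# Hypothetical linear boundary learned on 2-D layer activations:
# w . x + b = 0 separates concept 0 (negative side) from concept 1 (positive side).
w = [2.0, 1.0]
b = -1.0


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * c for a, c in zip(u, v))


# Two points that lie on the boundary (w . p + b == 0).
p1 = [0.5, 0.0]
p2 = [0.0, 1.0]
assert dot(w, p1) + b == 0 and dot(w, p2) + b == 0

# Any direction inside the boundary is orthogonal to w.
direction = [p2[0] - p1[0], p2[1] - p1[1]]
print(dot(w, direction))  # 0.0

# A binary classifier yields one weight row; using [-w, w] gives one CAV row
# per concept, i.e. a C x F matrix with C = 2 concepts and F = 2 features.
cav_matrix = [[-x for x in w], list(w)]
print(cav_matrix)  # [[-2.0, -1.0], [2.0, 1.0]]
```

The [-w, w] expansion mirrors the two-concept case handled in the CustomClassifier example further down in this document.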

Any layer attribution algorithm can be used to compute the layer sensitivities of a model prediction, for example the gradients of an output prediction w.r.t. the outputs of the layer. The CAVs and the sensitivities (SENS) are used to compute the TCAV score:

  1. TCAV = CAV • SENS, a dot product between those two vectors

The final TCAV score for each concept is computed by aggregating its per-example TCAV values over all input examples, based on either their sign or their magnitude:

  1. sign_count_score = | TCAV > 0 | / | TCAV |

  2. magnitude_score = SUM(ABS(TCAV * (TCAV > 0))) / SUM(ABS(TCAV))
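The scoring steps above can be sketched in plain Python. This is an illustrative sketch, not Captum's implementation: the CAV matrix rows and the per-example sensitivity vectors (SENS) are made-up numbers, and each concept's CAV row is scored independently.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * c for a, c in zip(u, v))


def tcav_scores(cav, sens):
    """Return (sign_count_score, magnitude_score) for one concept's CAV.

    cav:  F-dimensional CAV of a single concept.
    sens: list of F-dimensional layer sensitivities, one per input example.
    """
    tcav = [dot(cav, s) for s in sens]  # TCAV = CAV . SENS, per example
    sign_count = sum(1 for t in tcav if t > 0) / len(tcav)
    total = sum(abs(t) for t in tcav)
    positive = sum(abs(t) for t in tcav if t > 0)
    magnitude = positive / total if total > 0 else 0.0
    return sign_count, magnitude


# Hypothetical 2 x 2 CAV matrix (one row per concept) and four examples.
cav_matrix = [[-2.0, -1.0], [2.0, 1.0]]
sens = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.5, 0.5]]

scores = [tcav_scores(row, sens) for row in cav_matrix]
for concept_id, (sign_count, magnitude) in enumerate(scores):
    print(concept_id, sign_count, round(magnitude, 4))
# 0 0.25 0.3077
# 1 0.75 0.6923
```

Because the two CAV rows are negatives of each other, every example with a nonzero TCAV value counts toward exactly one concept, so the two sign_count scores sum to 1, as do the two magnitude scores.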

Parameters:
  • model (Module) – An instance of a PyTorch model that is used to compute layer activations and attributions.

  • layers (str or list[str]) – A layer name or a list of layer names that are used for computing concept activation vectors (CAVs) and layer attributions.

  • model_id (str, optional) – A unique identifier for the PyTorch model passed as the first argument to the constructor of the TCAV class. It is used to store and load activations for the given model and its associated layers.

  • classifier (Classifier, optional) – A custom classifier class, such as one wrapping scikit-learn's linear_model, that is trained on the activation vectors extracted for a layer per concept. It also provides access to the trained weights of the model and the list of prediction classes.

  • layer_attr_method (LayerAttribution, optional) –

    An instance of a layer attribution algorithm used to compute model prediction sensitivity scores.

    Default: None. If layer_attr_method is None, it defaults to layer gradients, computed with the LayerGradientXActivation layer attribution algorithm.

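  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attributions with respect to the layer input (True) or the layer output (False). Default: False

  • save_path (str, optional) – The path for storing CAVs and Activation Vectors (AVs).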

  • classifier_kwargs (Any, optional) – Additional arguments, such as test_split_ratio, that are passed to the concept classifier.

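Examples::
>>>
>>> # TCAV use example:
>>> from captum.concept import TCAV, Concept
>>>
>>> # Define the concepts. striped_data_iter and random_data_iter are
>>> # DataLoaders over concept examples, assumed to be defined elsewhere.
>>> stripes = Concept(0, "stripes", striped_data_iter)
>>> random = Concept(1, "random", random_data_iter)
>>>
>>> # imagenet is assumed to be a PyTorch model that contains layers named
>>> # 'inception4c' and 'inception4d' (e.g. torchvision's GoogLeNet).
>>> mytcav = TCAV(model=imagenet,
>>>               layers=['inception4c', 'inception4d'])
>>>
>>> # inputs is a batch of input tensors; target=0 selects output class 0.
>>> scores = mytcav.interpret(inputs, [[stripes, random]], target=0)
>>>
For more thorough examples, please check out the TCAV tutorial and test cases.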
compute_cavs(experimental_sets, force_train=False, processes=None)

This method computes CAVs for the given experimental_sets and the layers specified in the self.layers instance variable. Internally, for each experimental set, it trains a classifier and creates an instance of the CAV class using the weights of the trained classifier.

It can also compute the CAVs in parallel using Python's multiprocessing API, with the number of processes specified by the processes argument.

Parameters:
  • experimental_sets (list[list[Concept]]) – A list of lists of concept instances for which the CAVs will be computed.

  • force_train (bool, optional) – A flag that indicates whether to train the CAVs regardless of whether they are saved or not. Default: False

  • processes (int, optional) – The number of processes to be created when running in multi-processing mode. If processes > 0 then CAV computation will be performed in parallel using multi-processing, otherwise it will be performed sequentially in a single process. Default: None

Returns:

A nested dictionary that maps each concept-set key (the concept ids joined with '-', e.g. '0-1') to a dictionary of layer names and CAV objects.

If CAVs for the concept-set and layer pairs are present in the data storage, they are loaded into memory; otherwise they are computed through training and stored in the data storage, whose location can be configured with the save_path input argument.

Return type:

cavs (dict)
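The load-or-train behavior described above can be sketched as follows. This is a simplified, single-process illustration under stated assumptions: concept sets are given as lists of integer ids, an in-memory dict stands in for the data storage under save_path, and fake_train_cav is a hypothetical stand-in for classifier training. The '-'-joined key mirrors the keys in the interpret results example (e.g. '0-1').

```python
def concepts_key(concept_ids):
    """Join concept ids into a key such as '0-1'."""
    return "-".join(str(i) for i in concept_ids)


def compute_cavs_sketch(experimental_sets, layers, storage, train_cav,
                        force_train=False):
    """Return {concepts_key: {layer: cav}}, training only when needed."""
    cavs = {}
    for concept_ids in experimental_sets:
        key = concepts_key(concept_ids)
        cavs[key] = {}
        for layer in layers:
            if not force_train and (key, layer) in storage:
                # Load a previously stored CAV.
                cavs[key][layer] = storage[(key, layer)]
            else:
                # Train a new CAV and store it for later reuse.
                cav = train_cav(concept_ids, layer)
                storage[(key, layer)] = cav
                cavs[key][layer] = cav
    return cavs


# Hypothetical trainer that records every call.
calls = []


def fake_train_cav(concept_ids, layer):
    calls.append((tuple(concept_ids), layer))
    return f"cav[{concepts_key(concept_ids)}@{layer}]"


storage = {}
layers = ["inception4c", "inception4d"]
first = compute_cavs_sketch([[0, 1], [0, 2]], layers, storage, fake_train_cav)
second = compute_cavs_sketch([[0, 1]], layers, storage, fake_train_cav)
print(len(calls))  # 4: the second call loads from storage instead of training
print(second["0-1"]["inception4d"])  # cav[0-1@inception4d]
```

Passing force_train=True in this sketch skips the storage lookup and retrains every requested pair.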

generate_activation(layers, concept)

Computes layer activations for the specified concept and layer(s).

Parameters:
  • layers (str or list[str]) – A list of layer names or a layer name that is used to compute layer activations for the specific concept.

  • concept (Concept) – A single Concept object that provides access to concept examples using a data iterator.

Return type:

None

generate_activations(concept_layers)

Computes layer activations for the concepts and layers specified in concept_layers dictionary.

Parameters:

concept_layers (dict[Concept, list[str]]) – Dictionary that maps Concept objects to the list of layer names for which activations are generated. Ex.: concept_layers = {striped_concept: ['inception4c', 'inception4d']}, where striped_concept is a Concept instance.

Return type:

None

generate_all_activations()

Computes layer activations for all concepts and layers that are defined in self.layers and self.concepts instance variables.

Return type:

None

interpret(inputs, experimental_sets, target=None, additional_forward_args=None, processes=None, **kwargs)

This method computes magnitude- and sign-based TCAV scores for each experimental set in the experimental_sets list. TCAV scores are computed as the dot product between the layer attribution scores for specific predictions and the CAV vectors.

Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Inputs for which predictions are performed and attributions are computed. If model takes a single tensor as input, a single input tensor should be provided. If model takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples (aka batch size), and if multiple input tensors are provided, the examples must be aligned appropriately.

  • experimental_sets (list[list[Concept]]) – A list of lists of Concept instances.

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which attributions are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

  • additional_forward_args (Any, optional) – Extra arguments that are passed to model when computing the attributions for inputs w.r.t. layer output. Default: None

  • processes (int, optional) – The number of processes to be created. If processes is larger than one, CAV computations are performed in parallel using that many processes. Otherwise, CAV computations are performed sequentially. Default: None

  • **kwargs (Any, optional) – Additional keyword arguments passed to the layer attribution algorithm's attribute method, for example n_steps in the case of integrated gradients. Default: None
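For the common 2D case (batch x classes), the target rules above amount to picking one output column per example. The sketch below illustrates that selection in plain Python, with nested lists standing in for output tensors; select_target is a hypothetical helper, not a Captum function, and it covers only the None, integer and per-example list cases.

```python
def select_target(outputs, target=None):
    """Pick one value per example from a batch of model outputs.

    outputs: list of rows, one per example (batch x classes).
    target:  None (outputs already hold one scalar per example),
             an int applied to all examples, or a list with one
             int per example.
    """
    if target is None:
        return list(outputs)
    if isinstance(target, int):
        return [row[target] for row in outputs]
    if len(target) != len(outputs):
        raise ValueError("target length must match the number of examples")
    return [row[t] for row, t in zip(outputs, target)]


outputs = [[0.1, 0.9], [0.7, 0.3], [0.4, 0.6]]
print(select_target(outputs, 0))          # [0.1, 0.7, 0.4]
print(select_target(outputs, [1, 0, 1]))  # [0.9, 0.7, 0.6]
```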

Returns:

A dictionary of sign- and magnitude-based TCAV scores for each concept set per layer. The order of TCAV scores in the resulting tensor for each experimental set follows the order in which concepts are passed in the experimental_sets input argument.

Return type:

results (dict)

results example::
>>> #
>>> # scores =
>>> # {'0-1':
>>> #     {'inception4c':
>>> #         {'sign_count': tensor([0.5800, 0.4200]),
>>> #          'magnitude': tensor([0.6613, 0.3387])},
>>> #      'inception4d':
>>> #         {'sign_count': tensor([0.6200, 0.3800]),
>>> #          'magnitude': tensor([0.7707, 0.2293])}},
>>> #  '0-2':
>>> #     {'inception4c':
>>> #         {'sign_count': tensor([0.6200, 0.3800]),
>>> #          'magnitude': tensor([0.6806, 0.3194])},
>>> #      'inception4d':
>>> #         {'sign_count': tensor([0.6400, 0.3600]),
>>> #          'magnitude': tensor([0.6563, 0.3437])}}}
>>> #
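Because the results are nested dictionaries keyed first by concept set and then by layer, they are straightforward to post-process. The sketch below reuses the numbers from the example above, with plain lists standing in for the tensors, and finds the layer where the first concept of each set has the highest magnitude score; this comparison is only an illustration, not part of the TCAV API.

```python
scores = {
    "0-1": {
        "inception4c": {"sign_count": [0.5800, 0.4200], "magnitude": [0.6613, 0.3387]},
        "inception4d": {"sign_count": [0.6200, 0.3800], "magnitude": [0.7707, 0.2293]},
    },
    "0-2": {
        "inception4c": {"sign_count": [0.6200, 0.3800], "magnitude": [0.6806, 0.3194]},
        "inception4d": {"sign_count": [0.6400, 0.3600], "magnitude": [0.6563, 0.3437]},
    },
}

for concepts_key, per_layer in scores.items():
    # Index 0 is the first concept in the set (order follows experimental_sets).
    best_layer = max(per_layer, key=lambda layer: per_layer[layer]["magnitude"][0])
    print(concepts_key, best_layer, per_layer[best_layer]["magnitude"][0])
# 0-1 inception4d 0.7707
# 0-2 inception4c 0.6806
```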
load_cavs(concepts)

This function loads CAVs as a dictionary of concept ids and layers. CAVs are stored in a directory located under the self.save_path path, in .pkl files with the format: <self.save_path>/<concept_ids>-<layer_name>.pkl. Ex.: "/cavs/0-1-2-inception4c.pkl", where 0, 1 and 2 are concept ids.

It returns a list of layers and a dictionary of concept-layers mappings for the concepts and layers that require CAV computation through training. This happens when the CAVs have not already been pre-computed for a given list of concepts and layer.
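A minimal sketch of the naming convention above and of how the layers that still need training can be found, under simplifying assumptions: a set of existing file paths stands in for the files under save_path, concepts are given by integer id, and cav_path and missing_layers are hypothetical helpers rather than Captum functions. It covers only the returned list of layers, not the concept-layers mapping.

```python
import os


def cav_path(save_path, concept_ids, layer):
    """Build <save_path>/<concept_ids>-<layer_name>.pkl."""
    key = "-".join(str(i) for i in concept_ids)
    return os.path.join(save_path, f"{key}-{layer}.pkl")


def missing_layers(save_path, concept_ids, layers, existing_files):
    """Return the layers whose CAV file is not present yet."""
    return [layer for layer in layers
            if cav_path(save_path, concept_ids, layer) not in existing_files]


existing = {cav_path("/cavs", [0, 1, 2], "inception4c")}
print(cav_path("/cavs", [0, 1, 2], "inception4c"))  # /cavs/0-1-2-inception4c.pkl on POSIX
print(missing_layers("/cavs", [0, 1, 2], ["inception4c", "inception4d"], existing))
# ['inception4d']
```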

Parameters:

concepts (list[Concept]) – A list of Concept objects for which we want to load the CAV.

Returns:

layers (list[str]): A list of layers for which some CAVs still need to be computed.

concept_layers (dict[Concept, list[str]]): A dictionary of concept-layers mappings for which CAV computation through training is needed.

Return type:

tuple(layers, concept_layers)

ConceptInterpreter

class captum.concept.ConceptInterpreter(model)

An abstract class that exposes an abstract interpret method that has to be implemented by a specific algorithm for concept-based model interpretability.

Parameters:

model (torch.nn.Module) – An instance of a PyTorch model.

interpret: Callable

An abstract interpret method that performs concept-based model interpretability and returns the interpretation results in the form of tensors, dictionaries or other data structures.

Parameters:

inputs (Tensor or tuple[Tensor, ...]) – Inputs for which concept-based interpretation scores are computed. It can be provided as a single tensor or a tuple of multiple tensors. If multiple input tensors are provided, the batch size (the first dimension of the tensors) must be aligned across all tensors.

Concept

class captum.concept.Concept(id, name, data_iter)

Concepts are human-friendly abstract representations that can be numerically encoded into torch tensors. They can be illustrated as images, text or any other form of representation. In the case of images, for example, a "stripes" concept can be represented through a number of example images showing stripes in various contexts. In natural language processing, a "happy" concept, for instance, can be illustrated through a number of adjectives and words that convey happiness.

Parameters:
  • id (int) – The unique identifier of the concept.

  • name (str) – A unique name of the concept.

  • data_iter (DataLoader) – A PyTorch DataLoader object that combines a dataset and a sampler and provides an iterable over the given dataset. data_iter provides only the input batches; concept ids can be used as labels if necessary. For more information, please check: https://pytorch.org/docs/stable/data.html
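Since data_iter yields only input batches, the labels for CAV training have to come from somewhere else, and the concept id is the natural choice. Here is a plain-Python sketch of that pairing. It assumes each concept is given as an (id, iterable of batches) pair and each batch is a list of feature vectors; label_by_concept_id is a hypothetical helper, not part of Captum.

```python
def label_by_concept_id(concepts):
    """Flatten concept batches into (features, labels) using concept ids.

    concepts: iterable of (concept_id, batches) pairs, where batches is an
              iterable of lists of feature vectors.
    """
    features, labels = [], []
    for concept_id, batches in concepts:
        for batch in batches:
            features.extend(batch)
            labels.extend([concept_id] * len(batch))
    return features, labels


stripes_batches = [[[1.0, 2.0], [1.5, 2.5]], [[2.0, 1.0]]]
random_batches = [[[-1.0, 0.0]]]
x, y = label_by_concept_id([(0, stripes_batches), (1, random_batches)])
print(y)  # [0, 0, 0, 1]
```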

Example:

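>>> # Creates a Concept object named "striped", with a data_iter
>>> # object to iterate over all files in "./concepts/striped".
>>> # dataset_to_dataloader and get_tensor_from_filename are helper
>>> # functions assumed to be defined elsewhere (see the TCAV tutorial).
>>> import os
>>> from captum.concept import Concept
>>>
>>> concept_name = "striped"
>>> concept_path = os.path.join("./concepts", concept_name) + "/"
>>> concept_iter = dataset_to_dataloader(
>>>     get_tensor_from_filename, concepts_path=concept_path)
>>> concept_object = Concept(
>>>     id=0, name=concept_name, data_iter=concept_iter)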

Classifier

class captum.concept.Classifier

An abstract class definition of any classifier that can train a model and provide access to the trained weights of that model.

More specifically, the classifier can, for instance, be trained on the activations of a particular layer. The example below wraps a scikit-learn linear classifier in CustomClassifier, which extends the Classifier abstract class.

Example:

>>> import torch
>>> from sklearn import linear_model
>>> from sklearn.model_selection import train_test_split
>>> from captum.concept import Classifier
>>>
>>> class CustomClassifier(Classifier):
>>>
>>>     def __init__(self):
>>>         self.lm = linear_model.SGDClassifier(alpha=0.01, max_iter=1000,
>>>                                              tol=1e-3)
>>>
>>>     def train_and_eval(self, dataloader, **kwargs):
>>>         # Gather all (activation, concept id) batches from the dataloader.
>>>         inputs, labels = [], []
>>>         for x, y in dataloader:
>>>             inputs.append(x)
>>>             labels.append(y)
>>>         inputs, labels = torch.cat(inputs), torch.cat(labels)
>>>
>>>         x_train, x_test, y_train, y_test = train_test_split(inputs, labels)
>>>         self.lm.fit(x_train.detach().numpy(), y_train.detach().numpy())
>>>
>>>         preds = torch.tensor(self.lm.predict(x_test.detach().numpy()))
>>>         return {'accs': (preds == y_test).float().mean()}
>>>
>>>     def weights(self):
>>>         coef = torch.tensor(self.lm.coef_)
>>>         if coef.shape[0] == 1:
>>>             # With two concepts, sklearn learns a single weight row.
>>>             # We split it into one row per concept: [-w, w].
>>>             return torch.cat([-coef, coef])
>>>         return coef
>>>
>>>     def classes(self):
>>>         return self.lm.classes_
abstract classes()

This function returns the list of all classes that are used by the classifier to train the model in the train_and_eval method. The order of the returned classes must match the row order of the weights matrix returned by the weights method.

Returns:

The list of classes used by the classifier to train the model in the train_and_eval method.

Return type:

classes (list)

abstract train_and_eval(dataloader, **kwargs)

This method is responsible for training a classifier using the data provided through the dataloader input argument. Depending on the specific implementation, it may or may not return statistics about model training and evaluation.

Parameters:
  • dataloader (DataLoader) – A dataloader that enables batch-wise access to the inputs and corresponding labels. It allows iterating over the dataset by loading the batches lazily.

  • kwargs (dict) – Named arguments that are used for training and evaluating concept classifier. Default: None

Returns:

A dictionary of statistics about the performance of the model, for example the accuracy of the model on the test and/or train dataset(s). The user may return None or an empty dictionary if they choose not to report any performance statistics.

Return type:

stats (dict)

abstract weights()

This function returns a C x F tensor of weights, where C is the number of classes and F is the number of features.

Returns:

A torch Tensor with the weights resulting from the model training.

Return type:

weights (Tensor)
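To see the three abstract methods working together without extra dependencies, here is a plain-Python sketch of a binary perceptron that follows the same contract: train_and_eval consumes (inputs, labels) batches and returns a statistics dictionary, weights returns one row per class (C x F), and classes lists the class ids in the same order as those rows. The perceptron, the PerceptronConceptClassifier name and the toy data are illustrative choices, and a plain list of batches stands in for a DataLoader. The class deliberately does not subclass captum.concept.Classifier; a real implementation would subclass it and return torch tensors instead of lists.

```python
class PerceptronConceptClassifier:
    """Binary perceptron exposing train_and_eval / weights / classes."""

    def __init__(self, epochs=10):
        self.epochs = epochs
        self.w = None
        self.b = 0.0
        self._classes = []

    def train_and_eval(self, dataloader, **kwargs):
        # Collect all examples from the (inputs, labels) batches.
        xs, ys = [], []
        for inputs, labels in dataloader:
            xs.extend(inputs)
            ys.extend(labels)
        self._classes = sorted(set(ys))
        assert len(self._classes) == 2, "this sketch handles two concepts only"
        # Map the second class to +1 and the first class to -1.
        signs = [1 if y == self._classes[1] else -1 for y in ys]
        self.w = [0.0] * len(xs[0])
        self.b = 0.0
        for _ in range(self.epochs):
            for x, s in zip(xs, signs):
                if s * self._score(x) <= 0:  # misclassified: perceptron update
                    self.w = [wi + s * xi for wi, xi in zip(self.w, x)]
                    self.b += s
        preds = [self._predict(x) for x in xs]
        accs = sum(p == y for p, y in zip(preds, ys)) / len(ys)
        return {"accs": accs}  # training accuracy only, no held-out split

    def _score(self, x):
        return sum(wi * xi for wi, xi in zip(self.w, x)) + self.b

    def _predict(self, x):
        return self._classes[1] if self._score(x) > 0 else self._classes[0]

    def weights(self):
        # Row i belongs to classes()[i]: -w for the first class, w for the second.
        return [[-wi for wi in self.w], list(self.w)]

    def classes(self):
        return list(self._classes)


batches = [([[-2.0, -1.0], [-1.0, -2.0]], [0, 0]),
           ([[2.0, 1.0], [1.0, 2.0]], [1, 1])]
clf = PerceptronConceptClassifier()
stats = clf.train_and_eval(batches)
print(stats)          # {'accs': 1.0}
print(clf.classes())  # [0, 1]
print(clf.weights())  # [[-2.0, -1.0], [2.0, 1.0]]
```

Returning [-w, w] keeps the row order aligned with classes(), which is the ordering requirement stated for the classes method.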