Insights

Batch

class captum.insights.Batch(inputs, labels, additional_args=None)

Constructs batch of inputs to be attributed and visualized.

Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Batch of inputs for a model. These may be either a Tensor or tuple of tensors. Each tensor must correspond to a feature for AttributionVisualizer, and the corresponding input transform function of the feature is applied to each input tensor prior to passing it to the model. It is assumed that the first dimension of each input tensor corresponds to the number of examples (batch size) and is aligned for all input tensors.

  • labels (Tensor) – Tensor containing correct labels for input examples. This must be a 1D tensor with length matching the first dimension of each input tensor.

  • additional_args (tuple, optional) – Additional arguments required by the forward function, beyond the inputs, for which attributions should not be computed. This may be either a single argument (a Tensor or any non-tuple type) or a tuple of several arguments, which may include tensors or arbitrary Python types. These arguments are passed to forward_func in order, following the arguments in inputs. For a tensor, the first dimension must correspond to the number of examples. Default: None
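
The shape requirements above can be checked before a batch is built. The following is a minimal sketch using plain PyTorch; check_batch_shapes is a hypothetical helper written for illustration and is not part of Captum.

```python
import torch


def check_batch_shapes(inputs, labels, additional_args=None):
    """Hypothetical helper: verify the shape contract described for Batch."""
    # Accept either a single tensor or a tuple of tensors.
    tensors = inputs if isinstance(inputs, tuple) else (inputs,)
    batch_size = tensors[0].shape[0]
    # Every input tensor must share the same first (batch) dimension.
    if any(t.shape[0] != batch_size for t in tensors):
        raise ValueError("input tensors disagree on batch size")
    # Labels must be 1D with one entry per example.
    if labels.dim() != 1 or labels.shape[0] != batch_size:
        raise ValueError("labels must be 1D with length equal to batch size")
    # Tensor-valued additional arguments must also be batch-aligned.
    if additional_args is not None:
        extras = (
            additional_args
            if isinstance(additional_args, tuple)
            else (additional_args,)
        )
        for arg in extras:
            if torch.is_tensor(arg) and arg.shape[0] != batch_size:
                raise ValueError("tensor additional_args must be batch-aligned")
    return batch_size


# Illustrative data: four RGB images of size 32x32 and one label per image.
images = torch.rand(4, 3, 32, 32)
labels = torch.tensor([0, 2, 1, 2])
batch_size = check_batch_shapes(images, labels)
```

Tensors that pass this check could then be wrapped as Batch(inputs=images, labels=labels).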

AttributionVisualizer

class captum.insights.AttributionVisualizer(models, classes, features, dataset, score_func=None, use_label_for_attr=True)
Parameters:
  • models (torch.nn.Module or list[torch.nn.Module]) – One or more PyTorch modules (models) for attribution visualization.

  • classes (list[str]) – List of strings corresponding to the names of classes for classification.

  • features (list[BaseFeature]) – List of BaseFeatures, which correspond to input arguments to the model. Each feature object defines relevant transformations for converting to model input, constructing baselines, and visualizing. The length of the features list should exactly match the number of (tensor) arguments expected by the given model. For instance, an image classifier should only provide a single BaseFeature, while a multimodal classifier may provide a list of features, each corresponding to a different tensor input and potentially different modalities.

  • dataset (Iterable of Batch) – Defines the dataset to visualize attributions for. This must be an iterable of Batch objects, each of which may contain multiple input examples.

  • score_func (Callable, optional) – This function is applied to the model output to obtain the score for each class. For instance, this function could be the softmax or final non-linearity of the network, applied to the model output. The indices of the second dimension of the output should correspond to the class names provided. If None, the model outputs are taken directly and assumed to correspond to the class scores. Default: None

  • use_label_for_attr (bool, optional) – If True, the class index is passed to the relevant attribution method. This is necessary in most cases where there is an output neuron corresponding to each class. When the model output is a scalar and the class index (e.g. positive, negative) is inferred from the output value, this argument should be False. Default: True
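
As an illustration of score_func, the sketch below applies a softmax over the class dimension of raw model outputs. The class names and logits are made up for the example, and softmax_scores is a hypothetical name.

```python
import torch


def softmax_scores(model_output):
    """Convert raw logits of shape (N, num_classes) into per-class scores."""
    # dim=1 is the class dimension, so each row of scores sums to 1.
    return torch.softmax(model_output, dim=1)


# Illustrative class names; their order must match the output's second dimension.
classes = ["cat", "dog", "bird"]

# Made-up logits for two examples.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
scores = softmax_scores(logits)

# Map the highest-scoring index of each example back to its class name.
predicted = [classes[i] for i in scores.argmax(dim=1).tolist()]
```

A function of this shape could be supplied as score_func=softmax_scores when the model itself returns unnormalized logits.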

Features

BaseFeature

class captum.insights.features.BaseFeature(name, baseline_transforms, input_transforms, visualization_transform)

All Feature classes extend this class to implement custom visualizations in Insights.

It requires child classes to implement the visualization_type and visualize methods.

Parameters:
  • name (str) – The label of the specific feature. For example, an ImageFeature’s name can be “Photo”.

  • baseline_transforms (list, Callable, optional) – Optional list of callables (e.g. functions) to be called on the input tensor to construct multiple baselines. Currently only one baseline is supported. See IntegratedGradients for more information about baselines.

  • input_transforms (list, Callable, optional) – Optional list of callables (e.g. functions) called on the input tensor sequentially to convert it into the format expected by the model.

  • visualization_transform (Callable, optional) – Optional callable (e.g. function) applied as a postprocessing step to the original input data (before input_transforms) to convert it to a format understood by the frontend visualizer, as specified in captum/captum/insights/frontend/App.js.
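
Because input_transforms are called sequentially, their order matters. The sketch below shows the assumed semantics with a hypothetical helper, apply_in_order; it illustrates the documented behaviour and is not Captum's internal implementation.

```python
import torch


def apply_in_order(transforms, tensor):
    """Hypothetical helper: apply a callable or list of callables in order."""
    if transforms is None:
        return tensor
    # Treat a single callable as a one-element list.
    if callable(transforms):
        transforms = [transforms]
    for transform in transforms:
        tensor = transform(tensor)
    return tensor


def scale_to_unit(x):
    # Map raw pixel values in [0, 255] to [0, 1].
    return x / 255.0


def center(x):
    # Shift values in [0, 1] to [-0.5, 0.5].
    return x - 0.5


raw = torch.tensor([[0.0, 255.0]])
# Scaling runs first, then centering.
model_input = apply_in_order([scale_to_unit, center], raw)
```

Reversing the list would subtract 0.5 before dividing by 255 and produce a different model input.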

GeneralFeature

class captum.insights.features.GeneralFeature(name, categories)

GeneralFeature is used to visualize features that have no more specific feature type in Insights. It can be used for dense or sparse features.

Currently, general features are supported only for 2-D tensors of shape (N, C), where N is the number of samples and C is the number of categories.

Parameters:
  • name (str) – The label of the specific feature. For example, an ImageFeature’s name can be “Photo”.

  • categories (list[str]) – Category labels for the general feature. The order and size should match the second dimension of the data tensor parameter in visualize.
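
A minimal sketch of the (N, C) layout and its alignment with categories follows. The category names and values are invented for illustration.

```python
import torch

# Illustrative category labels; one per column of the data tensor.
categories = ["age", "income", "tenure"]

# Data of shape (N, C): two samples, three categories.
data = torch.tensor([[0.2, 0.5, 0.3], [0.9, 0.1, 0.0]])

# The tensor must be 2-D and its second dimension must match the labels.
assert data.dim() == 2
assert data.shape[1] == len(categories)

# Pair each category label with its value for the first sample.
first_sample = dict(zip(categories, data[0].tolist()))
```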

TextFeature

class captum.insights.features.TextFeature(name, baseline_transforms, input_transforms, visualization_transform)

TextFeature is used to visualize text (e.g. sentences) in Insights. It expects the visualization transform to convert the input data to raw text (e.g. token indices to strings).

Parameters:
  • name (str) – The label of the specific feature. For example, an ImageFeature’s name can be “Photo”.

  • baseline_transforms (list, Callable, optional) – Optional list of callables (e.g. functions) to be called on the input tensor to construct multiple baselines. Currently only one baseline is supported. See IntegratedGradients for more information about baselines. For text features, a common baseline is a tensor of indices corresponding to PAD with the same size as the input tensor. See TokenReferenceBase for more information.

  • input_transforms (list, Callable, optional) – A transform or list of transforms to be applied to the input. For text, a common transform is to convert the tokenized input tensor into an interpretable embedding. See InterpretableEmbeddingBase and configure_interpretable_embedding_layer() for more information.

  • visualization_transform (Callable, optional) – Optional callable (e.g. function) applied as a postprocessing step of the original input data (before input_transforms) to convert it to a suitable format for visualization. For text features, a common function is to convert the token indices to their corresponding (sub)words.
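
The PAD baseline and the index-to-token conversion described above might look like the following sketch. The vocabulary, the padding index, and both function names are assumptions made for the example, and the visualization transform is assumed to receive a 1D sequence of token indices for a single example.

```python
import torch

# Hypothetical toy vocabulary; index 0 is assumed to be the padding token.
itos = ["<pad>", "the", "movie", "was", "great"]
PAD_IDX = 0


def pad_baseline(token_ids):
    """Baseline transform: same shape and dtype as the input, filled with PAD."""
    return torch.full_like(token_ids, PAD_IDX)


def ids_to_tokens(token_ids):
    """Visualization transform: map token indices to their strings."""
    return [itos[int(i)] for i in token_ids]


token_ids = torch.tensor([1, 2, 3, 4, 0])
baseline = pad_baseline(token_ids)
tokens = ids_to_tokens(token_ids)
```

Functions like these could be passed as baseline_transforms=[pad_baseline] and visualization_transform=ids_to_tokens.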

ImageFeature

class captum.insights.features.ImageFeature(name, baseline_transforms, input_transforms, visualization_transform=None)

ImageFeature is used to visualize image features in Insights. It expects an image in NCHW format. If C is 1, the image is assumed to be greyscale; if C is 3, it is expected to be in RGB format.

Parameters:
  • name (str) – The label of the specific feature. For example, an ImageFeature’s name can be “Photo”.

  • baseline_transforms (list, Callable, optional) – Optional list of callables (e.g. functions) to be called on the input tensor to construct multiple baselines. Currently only one baseline is supported. See IntegratedGradients for more information about baselines.

  • input_transforms (list, Callable, optional) – A transform or list of transforms to be applied to the input. For images, normalization is often applied here.

  • visualization_transform (Callable, optional) – Optional callable (e.g. function) applied as a postprocessing step to the original input data (before input_transforms) to convert it to a format suitable for visualization.
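
For images, a normalization input transform and an all-zero (black) baseline are common choices. The sketch below assumes NCHW RGB input; the channel statistics and function names are hypothetical.

```python
import torch

# Hypothetical per-channel statistics, shaped to broadcast over NCHW input.
MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
STD = torch.tensor([0.25, 0.25, 0.25]).view(1, 3, 1, 1)


def normalize(images):
    """Input transform: standardize each channel of an NCHW batch."""
    return (images - MEAN) / STD


def black_baseline(images):
    """Baseline transform: an all-zero image with the input's shape."""
    return torch.zeros_like(images)


# Illustrative batch: two RGB images of size 8x8 with values in [0, 1).
images = torch.rand(2, 3, 8, 8)
normalized = normalize(images)
baseline = black_baseline(images)
```

These could be supplied as ImageFeature("Photo", baseline_transforms=[black_baseline], input_transforms=[normalize]).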