Utilities

Interpretable Input

class captum.attr.InterpretableInput[source]

InterpretableInput is an adapter that lets different kinds of model inputs work with Captum’s attribution methods. Generally, Captum’s attribution methods assume the inputs are numerical PyTorch tensors whose first dimension is the batch size and whose every index in the remaining dimensions is an interpretable feature. But this is not always true in practice. First, the model may take inputs in formats other than tensors that also require attribution. For example, a model with an encapsulated tokenizer can directly take a string as input. Second, what counts as an interpretable feature depends on the actual application and the user’s needs. For example, the interpretable feature of an image tensor can be either each pixel or some segments. For text, users may see the entire string as one interpretable feature or view each word as one. This class provides a place to define what the actual model input is, what the corresponding interpretable format for attribution is, and how to transform between them. It serves as a common interface used in the attribution methods so that Captum understands how to perturb various inputs.

The concept Interpretable Input mainly comes from the following two papers:

“Why Should I Trust You?”: Explaining the Predictions of Any Classifier

A Unified Approach to Interpreting Model Predictions

where this concept is also referred to as the interpretable representation or the simplified input. It can be expressed as a mapping function:

\[x = h_x(x') \]

where \(x\) is the model input, which can be anything that the model consumes; \(x'\) is the interpretable input used in the attribution algorithms (it must be a PyTorch tensor in Captum), which is often a binary tensor indicating the “presence” or “absence” of each feature; and \(h_x\) is the transformation that maps the interpretable input back to the model input. This formulation is designed for perturbation-based attribution methods, but if \(h_x\) is differentiable, it may also be used in gradient-based methods.

InterpretableInput is the abstract class defining the interface. Captum provides the child implementations for some common input formats, like text and sparse features. Users can inherit this class to create other types of customized input.

(We expect to support InterpretableInput in all attribution methods, but it is only allowed in certain attribution classes like LLMAttribution for now.)
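
To illustrate the interface, here is a minimal, hedged sketch of a custom subclass whose interpretable features are the words of a whitespace-separated string. The class name WordListInput and its behavior are illustrative assumptions, not part of Captum, and a real subclass may need extra bookkeeping depending on the attribution method:

>>> import torch
>>> from captum.attr import InterpretableInput
>>>
>>> class WordListInput(InterpretableInput):  # hypothetical subclass
>>>     def __init__(self, text):
>>>         self.words = text.split()
>>>
>>>     def to_tensor(self):
>>>         # one binary "presence" indicator per word
>>>         return torch.ones(1, len(self.words))
>>>
>>>     def to_model_input(self, itp_tensor=None):
>>>         if itp_tensor is None:
>>>             return " ".join(self.words)
>>>         # drop the words marked "absent" in the interpretable tensor
>>>         kept = [w for w, m in zip(self.words, itp_tensor[0]) if m]
>>>         return " ".join(kept)
>>>
>>>     def format_attr(self, itp_attr):
>>>         return itp_attr  # one attribution per word; nothing to scatter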

format_attr(itp_attr)[source]

Format the attribution of the interpretable features if needed. The way of formatting depends on the specific interpretable input type. A common use is when the interpretable features are mask groups of the raw input elements: the attributions of the interpretable features can then be scattered back to the model input shape.

Parameters:

itp_attr (Tensor) – attributions of the interpretable features

Returns:

formatted attribution

Return type:

attr (Tensor)

to_model_input(itp_tensor=None)[source]

Get the (perturbed) input in the format required by the model based on the given (perturbed) interpretable representation.

Parameters:

itp_tensor (Tensor, optional) – tensor of the interpretable representation of this input. If it is None, the interpretable representation is assumed unperturbed and the original model input is returned. Default: None

Returns:

model input passed to the forward function

Return type:

model_input (Any)

to_tensor()[source]

Return the interpretable representation of this input as a tensor

Returns:

interpretable tensor

Return type:

itp_tensor (Tensor)

class captum.attr.TextTemplateInput(template, values, baselines=None, mask=None)[source]

TextTemplateInput is an implementation of InterpretableInput for text inputs whose interpretable features are certain segments (e.g., words, phrases) of the text. It takes a template string (or function) that defines the feature segments of the input text. Its input format to the model is the completed text, while its interpretable representation is a binary tensor over the segment features whose values indicate whether each feature is “present” or “absent”.

Parameters:
  • template (str or Callable) – template string or function that takes the text segments and formats them into the text input for the model

  • values (List[str] or Dict[str, str]) – the values of the segments. It is the input to the template.

  • baselines (List[str] or Dict[str, str] or Callable or None, optional) – the baseline values for the segment features. If it is None, the empty string will be used as the baseline. Default: None

  • mask (List[int] or Dict[str, int] or None, optional) – the mask to group the segment features. It must be in the same format as the values and assign each segment a mask index. Segments with the same index are treated as a single interpretable feature, which means they are perturbed together and end up with the same attribution (see the grouped example below). Default: None

Examples:

>>> text_inp = TextTemplateInput(
>>>     template="{} feels {} right now",
>>>     values=["He", "depressed"],
>>>     baselines=["It", "neutral"],
>>> )
>>>
>>> text_inp.to_tensor()
>>> # torch.tensor([[1, 1]])
>>>
>>> text_inp.to_model_input(torch.tensor([[0, 1]]))
>>> # "It feels depressed right now"
format_attr(itp_attr)[source]

Format the attribution of the interpretable features if needed. The way of formatting depends on the specific interpretable input type. A common use is when the interpretable features are mask groups of the raw input elements: the attributions of the interpretable features can then be scattered back to the model input shape.

Parameters:

itp_attr (Tensor) – attributions of the interpretable features

Returns:

formatted attribution

Return type:

attr (Tensor)

to_model_input(perturbed_tensor=None)[source]

Get the (perturbed) input in the format required by the model based on the given (perturbed) interpretable representation.

Parameters:

perturbed_tensor (Tensor, optional) – tensor of the perturbed interpretable representation of this input. If it is None, the interpretable representation is assumed unperturbed and the original model input is returned. Default: None

Returns:

model input passed to the forward function

Return type:

model_input (Any)

to_tensor()[source]

Return the interpretable representation of this input as a tensor

Returns:

interpretable tensor

Return type:

itp_tensor (Tensor)

class captum.attr.TextTokenInput(text, tokenizer, baselines=0, skip_tokens=None)[source]

TextTokenInput is an implementation of InterpretableInput for text inputs whose interpretable features are the tokens of the text with respect to a given tokenizer. It is initialized with the string form of the input text and the corresponding tokenizer. Its input format to the model is the tokenized id tensor, while its interpretable representation is a binary tensor over the tokens whose values indicate whether each token is “present” or “absent”.

Parameters:
  • text (str) – text string for the model

  • tokenizer (Tokenizer) – tokenizer of the language model

  • baselines (int or str, optional) – the baseline value for the tokens. It can be the string of the baseline token or the integer id of the baseline token. Common choices include the unknown token and the padding token. The default value is 0, which is commonly the id of the unknown token. Default: 0

  • skip_tokens (List[int] or List[str], optional) – the tokens to skip in the input’s interpretable representation. Use this argument to exclude tokens of no interest, commonly special tokens, e.g., sos and unk. It can be a list of token strings or a list of integer token ids (see the sketch after the examples). Default: None

Examples:

>>> text_inp = TextTokenInput("This is a test.", tokenizer)
>>>
>>> text_inp.to_tensor()
>>> # the shape depends on the tokenizer;
>>> # assuming the text is broken into ["<s>", "This", "is", "a", "test", "."],
>>> # the result is torch.tensor([[1, 1, 1, 1, 1, 1]]), i.e., shape (1, 6)
>>>
>>> text_inp.to_model_input(torch.tensor([[0, 1, 1, 1, 1, 1]]))
>>> # the tokenized id tensor of the text, with the id of the first
>>> # token replaced by the baseline token id
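
The skip_tokens argument removes special tokens from the interpretable representation. A hedged sketch, assuming the same tokenization as above:

>>> text_inp = TextTokenInput(
>>>     "This is a test.", tokenizer, skip_tokens=["<s>"]
>>> )
>>>
>>> text_inp.to_tensor()
>>> # torch.tensor([[1, 1, 1, 1, 1]]), i.e., shape (1, 5):
>>> # the skipped "<s>" token is no longer an interpretable feature
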
format_attr(itp_attr)[source]

Format the attribution of the interpretable features if needed. The way of formatting depends on the specific interpretable input type. A common use is when the interpretable features are mask groups of the raw input elements: the attributions of the interpretable features can then be scattered back to the model input shape.

Parameters:

itp_attr (Tensor) – attributions of the interpretable features

Returns:

formatted attribution

Return type:

attr (Tensor)

to_model_input(perturbed_tensor=None)[source]

Get the (perturbed) input in the format required by the model based on the given (perturbed) interpretable representation.

Parameters:

perturbed_tensor (Tensor, optional) – tensor of the perturbed interpretable representation of this input. If it is None, the interpretable representation is assumed unperturbed and the original model input is returned. Default: None

Returns:

model input passed to the forward function

Return type:

model_input (Any)

to_tensor()[source]

Return the interpretable representation of this input as a tensor

Returns:

interpretable tensor

Return type:

itp_tensor (Tensor)

Visualization

captum.attr.visualization.visualize_image_attr(attr, original_image=None, method='heat_map', sign='absolute_value', plt_fig_axis=None, outlier_perc=2, cmap=None, alpha_overlay=0.5, show_colorbar=False, title=None, fig_size=(6, 6), use_pyplot=True)

Visualizes attribution for a given image by normalizing attribution values of the desired sign (positive, negative, absolute value, or all) and displaying them using the desired mode in a matplotlib figure.

Parameters:
  • attr (numpy.ndarray) – Numpy array corresponding to attributions to be visualized. Shape must be in the form (H, W, C), with channels as last dimension. Shape must also match that of the original image if provided.

  • original_image (numpy.ndarray, optional) – Numpy array corresponding to original image. Shape must be in the form (H, W, C), with channels as the last dimension. Image can be provided either with float values in range 0-1 or int values between 0-255. This is a necessary argument for any visualization method which utilizes the original image. Default: None

  • method (str, optional) –

    Chosen method for visualizing attribution. Supported options are:

    1. heat_map - Display heat map of chosen attributions

    2. blended_heat_map - Overlay heat map over greyscale version of original image. Parameter alpha_overlay corresponds to alpha of heat map.

    3. original_image - Only display original image.

    4. masked_image - Mask image (pixel-wise multiply) by normalized attribution values.

    5. alpha_scaling - Sets alpha channel of each pixel to be equal to normalized attribution value.

    Default: heat_map

  • sign (str, optional) –

    Chosen sign of attributions to visualize. Supported options are:

    1. positive - Displays only positive pixel attributions.

    2. absolute_value - Displays absolute value of attributions.

    3. negative - Displays only negative pixel attributions.

    4. all - Displays both positive and negative attribution values. This is not supported for masked_image or alpha_scaling modes, since signed information cannot be represented in these modes.

    Default: absolute_value

  • plt_fig_axis (tuple, optional) – Tuple of matplotlib.pyplot.figure and axis on which to visualize. If None is provided, then a new figure and axis are created. Default: None

  • outlier_perc (float or int, optional) – Top attribution values which correspond to a total of outlier_perc percentage of the total attribution are set to 1, and scaling is performed using the minimum of these values. For sign="all", outliers and the scale value are computed using the absolute value of the attributions. Default: 2

  • cmap (str, optional) – String corresponding to desired colormap for heatmap visualization. This defaults to “Reds” for negative sign, “Blues” for absolute value, “Greens” for positive sign, and a spectrum from red to green for all. Note that this argument is only used for visualizations displaying heatmaps. Default: None

  • alpha_overlay (float, optional) – Alpha to set for heatmap when using blended_heat_map visualization mode, which overlays the heat map over the greyscaled original image. Default: 0.5

  • show_colorbar (bool, optional) – Displays colorbar for heatmap below the visualization. If given method does not use a heatmap, then a colormap axis is created and hidden. This is necessary for appropriate alignment when visualizing multiple plots, some with colorbars and some without. Default: False

  • title (str, optional) – Title string for plot. If None, no title is set. Default: None

  • fig_size (tuple, optional) – Size of figure created. Default: (6,6)

  • use_pyplot (bool, optional) – If True, uses pyplot to create and show the figure. If False, uses the Matplotlib object-oriented API and simply returns a figure object without showing it. Default: True.

Returns:

  • figure (matplotlib.pyplot.figure):

    Figure object on which visualization is created. If plt_fig_axis argument is given, this is the same figure provided.

  • axis (matplotlib.pyplot.axis):

    Axis object on which visualization is created. If plt_fig_axis argument is given, this is the same axis provided.

Return type:

2-element tuple of figure, axis

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> net = ImageClassifier()
>>> ig = IntegratedGradients(net)
>>> # Computes integrated gradients for class 3 for a given image.
>>> attribution = ig.attribute(orig_image, target=3)
>>> # Convert the attribution and image tensors to (H, W, C) numpy arrays.
>>> attr_np = attribution[0].permute(1, 2, 0).detach().numpy()
>>> orig_np = orig_image[0].permute(1, 2, 0).detach().numpy()
>>> # Displays blended heat map visualization of computed attributions.
>>> _ = visualize_image_attr(attr_np, orig_np, "blended_heat_map")

captum.attr.visualization.visualize_image_attr_multiple(attr, original_image, methods, signs, titles=None, fig_size=(8, 6), use_pyplot=True, **kwargs)

Visualizes attribution using multiple visualization methods displayed in a 1 x k grid, where k is the number of desired visualizations.

Parameters:
  • attr (numpy.ndarray) – Numpy array corresponding to attributions to be visualized. Shape must be in the form (H, W, C), with channels as last dimension. Shape must also match that of the original image if provided.

  • original_image (numpy.ndarray) – Numpy array corresponding to the original image. Shape must be in the form (H, W, C), with channels as the last dimension. The image can be provided either with values in range 0-1 or 0-255. This is a necessary argument for any visualization method which utilizes the original image.

  • methods (list[str]) – List of strings of length k, defining method for each visualization. Each method must be a valid string argument for method to visualize_image_attr.

  • signs (list[str]) – List of strings of length k, defining signs for each visualization. Each sign must be a valid string argument for sign to visualize_image_attr.

  • titles (list[str], optional) – List of strings of length k, providing a title string for each plot. If None is provided, no titles are added to subplots. Default: None

  • fig_size (tuple, optional) – Size of figure created. Default: (8, 6)

  • use_pyplot (bool, optional) – If True, uses pyplot to create and show the figure. If False, uses the Matplotlib object-oriented API and simply returns a figure object without showing it. Default: True.

  • **kwargs (Any, optional) – Any additional arguments which will be passed to every individual visualization. Such arguments include show_colorbar, alpha_overlay, cmap, etc.

Returns:

  • figure (matplotlib.pyplot.figure):

    Figure object on which the visualizations are created.

  • axis (matplotlib.pyplot.axis):

    Axis objects on which the visualizations are created.

Return type:

2-element tuple of figure, axis

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> net = ImageClassifier()
>>> ig = IntegratedGradients(net)
>>> # Computes integrated gradients for class 3 for a given image.
>>> attribution = ig.attribute(orig_image, target=3)
>>> # Convert to (H, W, C) numpy arrays as required.
>>> attr_np = attribution[0].permute(1, 2, 0).detach().numpy()
>>> orig_np = orig_image[0].permute(1, 2, 0).detach().numpy()
>>> # Displays original image and heat map visualization of
>>> # computed attributions side by side.
>>> _ = visualize_image_attr_multiple(attr_np, orig_np,
>>>                     ["original_image", "heat_map"], ["all", "positive"])

captum.attr.visualization.visualize_timeseries_attr(attr, data, x_values=None, method='overlay_individual', sign='absolute_value', channel_labels=None, channels_last=True, plt_fig_axis=None, outlier_perc=2, cmap=None, alpha_overlay=0.7, show_colorbar=False, title=None, fig_size=(6, 6), use_pyplot=True, **pyplot_kwargs)

Visualizes attribution for a given timeseries data by normalizing attribution values of the desired sign (positive, negative, absolute value, or all) and displaying them using the desired mode in a matplotlib figure.

Parameters:
  • attr (numpy.ndarray) – Numpy array corresponding to attributions to be visualized. Shape must be in the form (N, C), with channels as the last dimension, unless channels_last is set to False. Shape must also match that of the timeseries data.

  • data (numpy.ndarray) – Numpy array corresponding to the original, equidistant timeseries data. Shape must be in the form (N, C), with channels as the last dimension, unless channels_last is set to False.

  • x_values (numpy.ndarray, optional) – Numpy array corresponding to the points on the x-axis. Shape must be in the form (N, ). If not provided, integers from 0 to N-1 are used. Default: None

  • method (str, optional) –

    Chosen method for visualizing attributions overlaid onto data. Supported options are:

    1. overlay_individual - Plot each channel individually in a separate panel, and overlay the attributions for each channel as a heat map. The alpha_overlay parameter controls the alpha of the heat map.

    2. overlay_combined - Plot all channels in the same panel, and overlay the average attributions as a heat map.

    3. colored_graph - Plot each channel in a separate panel, and color the graphs according to the attribution values. Works best with color maps that do not contain white or very bright colors.

    Default: overlay_individual

  • sign (str, optional) –

    Chosen sign of attributions to visualize. Supported options are:

    1. positive - Displays only positive attributions.

    2. absolute_value - Displays absolute value of attributions.

    3. negative - Displays only negative attributions.

    4. all - Displays both positive and negative attribution values.

    Default: absolute_value

  • channel_labels (list[str], optional) – List of labels corresponding to each channel in data. Default: None

  • channels_last (bool, optional) – If True, data is expected to have channels as the last dimension, i.e. (N, C). If False, data is expected to have channels first, i.e. (C, N). Default: True

  • plt_fig_axis (tuple, optional) – Tuple of matplotlib.pyplot.figure and axis on which to visualize. If None is provided, then a new figure and axis are created. Default: None

  • outlier_perc (float or int, optional) – Top attribution values which correspond to a total of outlier_perc percentage of the total attribution are set to 1, and scaling is performed using the minimum of these values. For sign="all", outliers and the scale value are computed using the absolute value of the attributions. Default: 2

  • cmap (str, optional) – String corresponding to desired colormap for heatmap visualization. This defaults to “Reds” for negative sign, “Blues” for absolute value, “Greens” for positive sign, and a spectrum from red to green for all. Note that this argument is only used for visualizations displaying heatmaps. Default: None

  • alpha_overlay (float, optional) – Alpha to set for the heat map when using an overlay visualization mode (overlay_individual or overlay_combined). Default: 0.7

  • show_colorbar (bool, optional) – Displays a colorbar for the heat map below the visualization. Default: False

  • title (str, optional) – Title string for plot. If None, no title is set. Default: None

  • fig_size (tuple, optional) – Size of figure created. Default: (6,6)

  • use_pyplot (bool, optional) – If True, uses pyplot to create and show the figure. If False, uses the Matplotlib object-oriented API and simply returns a figure object without showing it. Default: True.

  • pyplot_kwargs – Keyword arguments forwarded to plt.plot, for example linewidth=3, color='black', etc.

Returns:

  • figure (matplotlib.pyplot.figure):

    Figure object on which visualization is created. If plt_fig_axis argument is given, this is the same figure provided.

  • axis (matplotlib.pyplot.axis):

    Axis object on which visualization is created. If plt_fig_axis argument is given, this is the same axis provided.

Return type:

2-element tuple of figure, axis

Examples:

>>> # Classifier takes input of shape (batch, length, channels)
>>> model = Classifier()
>>> dl = DeepLift(model)
>>> attribution = dl.attribute(data, target=0)
>>> # Pick the first sample, convert to numpy arrays, and plot each
>>> # channel in a separate panel, with attributions overlaid
>>> visualize_timeseries_attr(attribution[0].detach().numpy(),
>>>                           data[0].numpy(),
>>>                           method="overlay_individual")

Interpretable Embeddings

class captum.attr.InterpretableEmbeddingBase(embedding, full_name)[source]

Since some embedding vectors, e.g. word embeddings, are created and assigned in the embedding layers of PyTorch models, we need a way to access those layers, generate the embeddings, and subtract the baseline. To do so, we separate the embedding layers from the model, compute the embeddings separately, and do all the operations needed outside of the model. The original embedding layer is replaced by an InterpretableEmbeddingBase layer which passes the already precomputed embedding vectors to the subsequent layers of the model.

forward(*inputs, **kwargs)[source]

The forward function of a wrapper embedding layer that takes and returns an embedding tensor. It allows embeddings to be created outside of the model and passed seamlessly to the subsequent layers of the model.

Parameters:
  • *inputs (Any, optional) – A sequence of input arguments that the forward function takes. Since forward functions can take any type and number of arguments, this ensures that we can execute the forward pass using the interpretable embedding layer. Note that if inputs are specified, it is assumed that the first argument is the embedding tensor generated by the self.embedding layer using all the input arguments provided in inputs and kwargs.

  • **kwargs (Any, optional) – Similar to inputs, we want to make sure that our forward pass supports an arbitrary number and type of key-value arguments. If inputs is not provided, kwargs must be provided and its first argument corresponds to the embedding tensor generated by self.embedding. Note that we assume here that kwargs is an ordered dict, which became the case in Python 3.6, and it is not guaranteed to consistently remain so in newer versions. In case the current implementation doesn’t work for special use cases, it is encouraged to override InterpretableEmbeddingBase and address those specifics in descendant classes.

Returns:

Returns a tensor which is the same as the first argument passed to the forward function. It passes precomputed embedding tensors to the lower layers without any modification.

Return type:

embedding_tensor (Tensor)

indices_to_embeddings(*input, **kwargs)[source]

Maps indices to corresponding embedding vectors, e.g. word embeddings.

Parameters:
  • *input (Any, optional) – This can be a tensor(s) of input indices or any other variable necessary to compute the embeddings. A typical example of input indices are word or token indices.

  • **kwargs (Any, optional) – Similar to input, this can be any sequence of key-value arguments necessary to compute the final embedding tensor.

Returns:

A tensor of word embeddings corresponding to the indices specified in the input

Return type:

tensor

captum.attr.configure_interpretable_embedding_layer(model, embedding_layer_name='embedding')[source]

This method wraps a model’s embedding layer with an interpretable embedding layer that allows us to access the embeddings through their indices.

Parameters:
  • model (torch.nn.Module) – An instance of PyTorch model that contains embeddings.

  • embedding_layer_name (str, optional) – The name of the embedding layer in the model that we would like to make interpretable. Default: 'embedding'

Returns:

An instance of InterpretableEmbeddingBase that wraps the model’s embedding layer, which is accessed through embedding_layer_name.

Return type:

interpretable_emb (InterpretableEmbeddingBase)

Examples:

>>> # Let's assume that we have a DocumentClassifier model that
>>> # has a word embedding layer named 'embedding'.
>>> # To make that layer interpretable we need to execute the
>>> # following command:
>>> net = DocumentClassifier()
>>> interpretable_emb = configure_interpretable_embedding_layer(net,
>>>    'embedding')
>>> # then we can use interpretable embedding to convert our
>>> # word indices into embeddings.
>>> # Let's assume that we have the following word indices
>>> input_indices = torch.tensor([1, 0, 2])
>>> # we can access word embeddings for those indices with the command
>>> # line stated below.
>>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
>>> # Let's assume that we want to apply integrated gradients to
>>> # our model and that target attribution class is 3
>>> ig = IntegratedGradients(net)
>>> attribution = ig.attribute(input_emb, target=3)
>>> # after we finish the interpretation we need to remove
>>> # interpretable embedding layer with the following command:
>>> remove_interpretable_embedding_layer(net, interpretable_emb)

captum.attr.remove_interpretable_embedding_layer(model, interpretable_emb)[source]

Removes the interpretable embedding layer and restores the original embedding layer in the model.

Parameters:
  • model (torch.nn.Module) – An instance of PyTorch model that contains embeddings

  • interpretable_emb (InterpretableEmbeddingBase) – An instance of InterpretableEmbeddingBase that was originally created in configure_interpretable_embedding_layer function and has to be removed after interpretation is finished.

Return type:

None

Examples:

>>> # Let's assume that we have a DocumentClassifier model that
>>> # has a word embedding layer named 'embedding'.
>>> # To make that layer interpretable we need to execute the
>>> # following command:
>>> net = DocumentClassifier()
>>> interpretable_emb = configure_interpretable_embedding_layer(net,
>>>    'embedding')
>>> # then we can use interpretable embedding to convert our
>>> # word indices into embeddings.
>>> # Let's assume that we have the following word indices
>>> input_indices = torch.tensor([1, 0, 2])
>>> # we can access word embeddings for those indices with the command
>>> # line stated below.
>>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
>>> # Let's assume that we want to apply integrated gradients to
>>> # our model and that target attribution class is 3
>>> ig = IntegratedGradients(net)
>>> attribution = ig.attribute(input_emb, target=3)
>>> # after we finish the interpretation we need to remove
>>> # interpretable embedding layer with the following command:
>>> remove_interpretable_embedding_layer(net, interpretable_emb)

Token Reference Base

class captum.attr.TokenReferenceBase(reference_token_idx=0)[source]

A base class for creating a reference (aka baseline) tensor for a sequence of tokens. A typical example of such a token is PAD. Users need to provide the index of the reference token in the vocabulary as an argument to the TokenReferenceBase class.

generate_reference(sequence_length, device)[source]

Generates a reference tensor of the given sequence_length using reference_token_idx.

Parameters:
  • sequence_length (int) – The length of the reference sequence

  • device (torch.device) – The device on which the reference tensor will be created.

Returns:

A sequence of reference tokens, with shape [sequence_length]

Return type:

tensor
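
A hedged usage sketch of generate_reference (the PAD index and sequence length are illustrative assumptions):

>>> import torch
>>> from captum.attr import TokenReferenceBase
>>>
>>> # assume PAD has index 0 in the vocabulary
>>> token_reference = TokenReferenceBase(reference_token_idx=0)
>>> reference_indices = token_reference.generate_reference(
>>>     sequence_length=7, device=torch.device("cpu")
>>> ).unsqueeze(0)
>>> # a (1, 7) tensor of PAD indices, usable as a baseline input for
>>> # attribution methods such as LayerIntegratedGradients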

Linear Models

class captum._utils.models.model.Model[source]

Abstract class describing the interface of a trainable model to be used within the algorithms of Captum.

Please note that this is an experimental feature.

abstract fit(train_data, **kwargs)[source]

Override this method to actually train your model.

The specification of the dataloader will be supplied by the algorithm you are using within captum. This will likely be a supervised learning task, thus you should expect batched (x, y) pairs or (x, y, w) triples.

Parameters:

train_data (DataLoader) – The data to train on

Return type:

Optional[Dict[str, Union[int, float, Tensor]]]

Returns:

Optional statistics about training, e.g. iterations it took to train, training loss, etc.

abstract representation()[source]

Returns the underlying representation of the interpretable model. For a linear model this is simply a tensor (the concatenation of weights and bias). For something slightly more complicated, such as a decision tree, this could be the nodes of a decision tree.

Return type:

Tensor

Returns:

A Tensor describing the representation of the model.
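
As a hedged sketch of this interface, the hypothetical subclass below fits a linear model in closed form instead of iterative training; it is illustrative only and not part of Captum:

>>> import torch
>>> from captum._utils.models.model import Model
>>>
>>> class LstSqLinearModel(Model):  # hypothetical subclass
>>>     def fit(self, train_data, **kwargs):
>>>         xs, ys = [], []
>>>         for x, y in train_data:  # batched (x, y) pairs
>>>             xs.append(x)
>>>             ys.append(y)
>>>         x, y = torch.cat(xs), torch.cat(ys)
>>>         # closed-form least squares in place of iterative training
>>>         self._weights = torch.linalg.lstsq(x, y.unsqueeze(-1)).solution
>>>         mse = ((x @ self._weights).squeeze(-1) - y).pow(2).mean()
>>>         return {"train_mse": mse}
>>>
>>>     def representation(self):
>>>         return self._weights.flatten()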

class captum._utils.models.linear_model.SkLearnLinearModel(sklearn_module, **kwargs)[source]

Factory class to construct a LinearModel with sklearn training method.

Please note that this assumes:

  1. You have sklearn and numpy installed

  2. The dataset can fit into memory

SkLearn support does introduce some slight overhead, as we convert the tensors to numpy arrays and then convert the resulting trained sklearn model to a LinearModel object. However, this overhead should be negligible.

Parameters:
  • sklearn_module (str) –

    The module under sklearn to construct and use for training, e.g. use “svm.LinearSVC” for an SVM or “linear_model.Lasso” for Lasso.

    There are factory classes defined for you for common use cases, such as SkLearnLasso.

  • kwargs – The kwargs to pass to the construction of the sklearn model

fit(train_data, **kwargs)[source]
Parameters:
  • train_data (DataLoader) – Train data to use

  • kwargs – Arguments to feed to .fit method for sklearn
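
A hedged usage sketch with a toy dataset (the data is illustrative; the module string "linear_model.LinearRegression" follows the naming convention described above):

>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from captum._utils.models.linear_model import SkLearnLinearModel
>>>
>>> x = torch.randn(100, 4)
>>> y = x @ torch.tensor([1.0, -2.0, 0.5, 0.0]) + 0.1
>>> train_data = DataLoader(TensorDataset(x, y), batch_size=20)
>>>
>>> model = SkLearnLinearModel("linear_model.LinearRegression")
>>> model.fit(train_data)
>>> model.representation()  # concatenated weights and bias, per the Model interface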

class captum._utils.models.linear_model.SkLearnLinearRegression(**kwargs)[source]

Factory class. Trains a model with sklearn.linear_model.LinearRegression.

Any arguments provided to the sklearn constructor can be provided as kwargs here.

fit(train_data, **kwargs)[source]
Parameters:
  • train_data (DataLoader) – Train data to use

  • kwargs – Arguments to feed to .fit method for sklearn

class captum._utils.models.linear_model.SkLearnLasso(**kwargs)[source]

Factory class. Trains a LinearModel model with sklearn.linear_model.Lasso. You will need sklearn version >= 0.23 to support sample weights.

fit(train_data, **kwargs)[source]
Parameters:
  • train_data (DataLoader) – Train data to use

  • kwargs – Arguments to feed to .fit method for sklearn

class captum._utils.models.linear_model.SkLearnRidge(**kwargs)[source]

Factory class. Trains a model with sklearn.linear_model.Ridge.

Any arguments provided to the sklearn constructor can be provided as kwargs here.

fit(train_data, **kwargs)[source]
Parameters:
  • train_data (DataLoader) – Train data to use

  • kwargs – Arguments to feed to .fit method for sklearn

class captum._utils.models.linear_model.SGDLinearModel(**kwargs)[source]

Factory class. Constructs a LinearModel with sgd_train_linear_model as the train method.

Parameters:

kwargs – Arguments sent to self._construct_model_params after self.fit is called. Please refer to that method for parameter documentation.

class captum._utils.models.linear_model.SGDLasso(**kwargs)[source]

Factory class to train a LinearModel with SGD (sgd_train_linear_model) whilst setting appropriate parameters to optimize for lasso regression loss. This optimizes L2 loss + alpha * L1 regularization.

Please note that with SGD it is not guaranteed that weights will converge to 0.

fit(train_data, **kwargs)[source]

Calls self.train_fn

class captum._utils.models.linear_model.SGDRidge(**kwargs)[source]

Factory class to train a LinearModel with SGD (sgd_train_linear_model) whilst setting appropriate parameters to optimize for ridge regression loss. This optimizes L2 loss + alpha * L2 regularization.

fit(train_data, **kwargs)[source]

Calls self.train_fn
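
These factory classes are typically supplied as the interpretable (surrogate) model of perturbation-based methods. A hedged sketch with Lime (net and input_tensor are placeholders as in earlier examples, and the alpha value is an illustrative choice):

>>> from captum.attr import Lime
>>> from captum._utils.models.linear_model import SkLearnLasso
>>>
>>> net = ImageClassifier()
>>> # train a sparse Lasso surrogate on the perturbed samples
>>> lime = Lime(net, interpretable_model=SkLearnLasso(alpha=0.01))
>>> attribution = lime.attribute(input_tensor, target=3, n_samples=50)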

Baselines

class captum.attr.ProductBaselines(baseline_values)[source]

A callable baselines class that returns a sample from the Cartesian product of the inputs’ available baselines.

Parameters:

baseline_values (List or Dict) – A list, or a dict of lists, containing the possible values for each feature. If a dict is provided, a key can be a string with the feature name and the corresponding value a list of available baselines. A key can also be a tuple of strings, grouping multiple features whose baselines are not independent of each other; in that case, the value must be a list of tuples of the corresponding values.
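
A hedged sketch, assuming three string features where the last two are sampled jointly (all names and values are illustrative):

>>> from captum.attr import ProductBaselines
>>>
>>> baselines = ProductBaselines({
>>>     "name": ["Alice", "Bob"],
>>>     ("city", "state"): [("Houston", "TX"), ("Denver", "CO")],
>>> })
>>>
>>> baselines()
>>> # a sample from the Cartesian product, e.g.,
>>> # {"name": "Bob", "city": "Houston", "state": "TX"}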