Layer Attribution

Layer Conductance

class captum.attr.LayerConductance(forward_func, layer, device_ids=None)[source]

Computes conductance with respect to the given layer. The returned output is in the shape of the layer’s output, showing the total conductance of each hidden layer neuron.

The details of the approach can be found in the following papers: https://arxiv.org/abs/1805.12233 and https://arxiv.org/abs/1807.09946

Note that this provides the total conductance of each neuron in the layer’s output. To obtain the breakdown of a neuron’s conductance by input features, utilize NeuronConductance instead, and provide the target neuron index.
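For contrast, a minimal sketch of the NeuronConductance usage mentioned above; ImageClassifier, conv1 and the chosen neuron index are hypothetical (matching the examples later on this page), and the neuron_selector parameter name follows recent Captum releases:

>>> from captum.attr import NeuronConductance
>>> net = ImageClassifier()
>>> neuron_cond = NeuronConductance(net, net.conv1)
>>> input = torch.randn(2, 3, 32, 32)
>>> # Breakdown of the conductance of the conv1 neuron at index
>>> # (4, 1, 2) by input features; the result matches the input shape.
>>> attr = neuron_cond.attribute(input, neuron_selector=(4, 1, 2),
>>>     target=3)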

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (torch.nn.Module) – Layer for which attributions are computed. Output size of attribute matches this layer’s input or output dimensions, depending on whether we attribute to the inputs or outputs of the layer, corresponding to attribution of each neuron in the input or output of this layer.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself, then it is not necessary to provide this argument.

attribute(inputs, baselines=None, target=None, additional_forward_args=None, n_steps=50, method='gausslegendre', internal_batch_size=None, return_convergence_delta=False, attribute_to_layer_input=False)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which layer conductance is computed. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • baselines (scalar, Tensor, tuple of scalar, or Tensor, optional) –

    Baselines define the starting point from which integral is computed and can be provided as:

    • a single tensor, if inputs is a single tensor, with exactly the same dimensions as inputs or the first dimension is one and the remaining dimensions match with inputs.

    • a single scalar, if inputs is a single tensor, which will be broadcasted for each input value in input tensor.

    • a tuple of tensors or scalars, the baseline corresponding to each tensor in the inputs’ tuple can be:

      • either a tensor with matching dimensions to corresponding tensor in the inputs’ tuple or the first dimension is one and the remaining dimensions match with the corresponding input tensor.

      • or a scalar, corresponding to a tensor in the inputs’ tuple. This scalar value is broadcasted for corresponding input tensor.

    In the case when baselines are not provided, we internally use a zero scalar corresponding to each input tensor.

    Default: None

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. For a tensor, the first dimension of the tensor must correspond to the number of examples. It will be repeated for each of n_steps along the integrated path. For all other types, the given argument is used for all forward evaluations. Note that attributions are not computed with respect to these arguments. Default: None

  • n_steps (int, optional) – The number of steps used by the approximation method. Default: 50.

  • method (str, optional) – Method for approximating the integral, one of riemann_right, riemann_left, riemann_middle, riemann_trapezoid or gausslegendre. Default: gausslegendre if no method is provided.

  • internal_batch_size (int, optional) – Divides total #steps * #examples data points into chunks of size at most internal_batch_size, which are computed (forward / backward passes) sequentially. internal_batch_size must be at least equal to 2 * #examples. For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain internal_batch_size / num_devices examples. If internal_batch_size is None, then all evaluations are processed in one batch. Default: None

  • return_convergence_delta (bool, optional) – Indicates whether to return convergence delta or not. If return_convergence_delta is set to True convergence delta will be returned in a tuple following attributions. Default: False

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer inputs, otherwise it will be computed with respect to layer outputs. Note that currently it is assumed that either the input or the output of internal layer, depending on whether we attribute to the input or output, is a single tensor. Support for multiple tensors will be added later. Default: False

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Conductance of each neuron in given layer input or output. Attributions will always be the same size as the input or output of the given layer, depending on whether we attribute to the inputs or outputs of the layer which is decided by the input flag attribute_to_layer_input. Attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned.

  • delta (Tensor, returned if return_convergence_delta=True):

    The difference between the total approximated and true conductance. This is computed using the property that the total sum of forward_func(inputs) - forward_func(baselines) must equal the total sum of the attributions. Delta is calculated per example, meaning that the number of elements in returned delta tensor is equal to the number of examples in inputs.

Return type:

attributions or 2-element tuple of attributions, delta

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
>>> # and the output of this layer has dimensions Nx12x32x32.
>>> net = ImageClassifier()
>>> layer_cond = LayerConductance(net, net.conv1)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # Computes layer conductance for class 3.
>>> # attribution size matches layer output, Nx12x32x32
>>> attribution = layer_cond.attribute(input, target=3)
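A short sketch extending the example above to exercise baselines and return_convergence_delta from the parameter list; the explicit zero baseline shown matches the internal default:

>>> baseline = torch.zeros(2, 3, 32, 32)
>>> # Returns a 2-element tuple; delta holds one value per example.
>>> attribution, delta = layer_cond.attribute(input, baselines=baseline,
>>>     target=3, n_steps=100, return_convergence_delta=True)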
has_convergence_delta()[source]

This method informs the user whether the attribution algorithm provides a convergence delta (aka an approximation error) or not. Convergence delta may serve as a proxy of correctness of the attribution algorithm's approximation. If a deriving attribution class provides a compute_convergence_delta method, it should override both compute_convergence_delta and has_convergence_delta methods.

Returns:

Returns whether the attribution algorithm provides a convergence delta (aka approximation error) or not.

Return type:

bool

Layer Activation

class captum.attr.LayerActivation(forward_func, layer, device_ids=None)[source]

Computes activation of selected layer for given input.

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (torch.nn.Module or list of torch.nn.Module) – Layer or layers for which attributions are computed. Output size of attribute matches this layer’s input or output dimensions, depending on whether we attribute to the inputs or outputs of the layer, corresponding to attribution of each neuron in the input or output of this layer. If multiple layers are provided, attributions are returned as a list, each element corresponding to the activations of the corresponding layer.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself, then it is not necessary to provide this argument.

attribute(inputs, additional_forward_args=None, attribute_to_layer_input=False)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which layer activation is computed. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer input, otherwise it will be computed with respect to layer output. Note that currently it is assumed that either the input or the output of internal layer, depending on whether we attribute to the input or output, is a single tensor. Support for multiple tensors will be added later. Default: False

Returns:

  • attributions (Tensor, tuple[Tensor, …], or list):

    Activation of each neuron in given layer output. Attributions will always be the same size as the output of the given layer. Attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned. If multiple layers are provided, attributions are returned as a list, each element corresponding to the activations of the corresponding layer.

Return type:

Tensor or tuple[Tensor, …] or list of attributions

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
>>> # and the output of this layer has dimensions Nx12x32x32.
>>> net = ImageClassifier()
>>> layer_act = LayerActivation(net, net.conv1)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # Computes layer activation.
>>> # attribution is layer output, with size Nx12x32x32
>>> attribution = layer_act.attribute(input)
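Because layer may also be a list of modules, here is a minimal sketch of the multi-layer form; conv2 is an assumed second layer of the same hypothetical model:

>>> # Activations for several layers at once; returns a list with one
>>> # tensor per layer, in the order the layers were given.
>>> multi_act = LayerActivation(net, [net.conv1, net.conv2])
>>> activations = multi_act.attribute(input)
>>> # activations[0] matches conv1's output shape, activations[1] conv2's.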

Internal Influence

class captum.attr.InternalInfluence(forward_func, layer, device_ids=None)[source]

Computes internal influence by approximating the integral of gradients for a particular layer along the path from a baseline input to the given input. If no baseline is provided, the default baseline is the zero tensor. More details on this approach can be found here: https://arxiv.org/abs/1802.03788

Note that this method is similar to applying integrated gradients with the layer treated as the input, integrating the gradients of the output with respect to the layer.

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (torch.nn.Module) – Layer for which attributions are computed. Output size of attribute matches this layer’s input or output dimensions, depending on whether we attribute to the inputs or outputs of the layer, corresponding to attribution of each neuron in the input or output of this layer.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself, then it is not necessary to provide this argument.

attribute(inputs, baselines=None, target=None, additional_forward_args=None, n_steps=50, method='gausslegendre', internal_batch_size=None, attribute_to_layer_input=False)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which internal influence is computed. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • baselines (scalar, Tensor, tuple of scalar, or Tensor, optional) –

    Baselines define a starting point from which integral is computed and can be provided as:

    • a single tensor, if inputs is a single tensor, with exactly the same dimensions as inputs or the first dimension is one and the remaining dimensions match with inputs.

    • a single scalar, if inputs is a single tensor, which will be broadcasted for each input value in input tensor.

    • a tuple of tensors or scalars, the baseline corresponding to each tensor in the inputs’ tuple can be:

      • either a tensor with matching dimensions to corresponding tensor in the inputs’ tuple or the first dimension is one and the remaining dimensions match with the corresponding input tensor.

      • or a scalar, corresponding to a tensor in the inputs’ tuple. This scalar value is broadcasted for corresponding input tensor.

    In the case when baselines are not provided, we internally use a zero scalar corresponding to each input tensor.

    Default: None

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. For a tensor, the first dimension of the tensor must correspond to the number of examples. It will be repeated for each of n_steps along the integrated path. For all other types, the given argument is used for all forward evaluations. Note that attributions are not computed with respect to these arguments. Default: None

  • n_steps (int, optional) – The number of steps used by the approximation method. Default: 50.

  • method (str, optional) – Method for approximating the integral, one of riemann_right, riemann_left, riemann_middle, riemann_trapezoid or gausslegendre. Default: gausslegendre if no method is provided.

  • internal_batch_size (int, optional) – Divides total #steps * #examples data points into chunks of size at most internal_batch_size, which are computed (forward / backward passes) sequentially. internal_batch_size must be at least equal to #examples. For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain internal_batch_size / num_devices examples. If internal_batch_size is None, then all evaluations are processed in one batch. Default: None

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer inputs, otherwise it will be computed with respect to layer outputs. Note that currently it is assumed that either the input or the output of internal layer, depending on whether we attribute to the input or output, is a single tensor. Support for multiple tensors will be added later. Default: False

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Internal influence of each neuron in given layer output. Attributions will always be the same size as the output or input of the given layer depending on whether attribute_to_layer_input is set to False or True respectively. Attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned.

Return type:

Tensor or tuple[Tensor, …] of attributions

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
>>> # and the output of this layer has dimensions Nx12x32x32.
>>> net = ImageClassifier()
>>> layer_int_inf = InternalInfluence(net, net.conv1)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # Computes layer internal influence.
>>> # attribution size matches layer output, Nx12x32x32
>>> attribution = layer_int_inf.attribute(input)
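A sketch exercising the baselines and method parameters described above; the mid-gray baseline, target class and step count are arbitrary illustrative choices:

>>> # Non-zero baseline with a trapezoid Riemann approximation.
>>> baseline = 0.5 * torch.ones(2, 3, 32, 32)
>>> attribution = layer_int_inf.attribute(input, baselines=baseline,
>>>     target=3, n_steps=100, method='riemann_trapezoid')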

Layer Gradient X Activation

class captum.attr.LayerGradientXActivation(forward_func, layer, device_ids=None, multiply_by_inputs=True)[source]

Computes element-wise product of gradient and activation for selected layer on given inputs.

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (torch.nn.Module or list of torch.nn.Module) – Layer or layers for which attributions are computed. Output size of attribute matches this layer’s input or output dimensions, depending on whether we attribute to the inputs or outputs of the layer, corresponding to attribution of each neuron in the input or output of this layer. If multiple layers are provided, attributions are returned as a list, each element corresponding to the attributions of the corresponding layer.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself, then it is not necessary to provide this argument.

  • multiply_by_inputs (bool, optional) –

    Indicates whether to factor model inputs’ multiplier in the final attribution scores. In the literature this is also known as local vs global attribution. If the inputs’ multiplier isn’t factored in, then this type of attribution method is also called local attribution. If it is, then that type of attribution method is called global. More details can be found here: https://arxiv.org/abs/1711.06104

    In case of layer gradient x activation, if multiply_by_inputs is set to True, final sensitivity scores are multiplied by the layer activations for inputs.

attribute(inputs, target=None, additional_forward_args=None, attribute_to_layer_input=False)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which attributions are computed. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer input, otherwise it will be computed with respect to layer output. Default: False

Returns:

  • attributions (Tensor, tuple[Tensor, …], or list):

    Product of gradient and activation for each neuron in given layer output. Attributions will always be the same size as the output of the given layer. Attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned. If multiple layers are provided, attributions are returned as a list, each element corresponding to the attributions of the corresponding layer.

Return type:

Tensor or tuple[Tensor, …] or list of attributions

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
>>> # and the output of this layer has dimensions Nx12x32x32.
>>> net = ImageClassifier()
>>> layer_ga = LayerGradientXActivation(net, net.conv1)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # Computes layer gradient x activation for class 3.
>>> # attribution size matches layer output, Nx12x32x32
>>> attribution = layer_ga.attribute(input, 3)
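A sketch of the multiply_by_inputs flag from the constructor parameters above; with False, the purely local (gradient-only) scores are returned instead of gradient times activation:

>>> # Gradient-only (local) attribution, per multiply_by_inputs above.
>>> layer_g = LayerGradientXActivation(net, net.conv1,
>>>     multiply_by_inputs=False)
>>> local_attr = layer_g.attribute(input, target=3)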

GradCAM

class captum.attr.LayerGradCam(forward_func, layer, device_ids=None)[source]

Computes GradCAM attribution for chosen layer. GradCAM is designed for convolutional neural networks, and is usually applied to the last convolutional layer.

GradCAM computes the gradients of the target output with respect to the given layer, averages the gradients spatially for each output channel (dimension 1 of the layer output), and multiplies the average gradient for each channel by the layer activations. The results are summed over all channels.

Note that in the original GradCAM algorithm described in the paper, ReLU is applied to the output, returning only non-negative attributions. To provide more flexibility, we choose not to perform the ReLU internally by default, so the sign information is preserved. To match the original GradCAM algorithm, pass relu_attributions=True to apply ReLU on the final attributions, or alternatively visualize only the positive attributions.

Note: this procedure sums over the channel dimension (dimension 1), so GradCAM attributions will have size 1 in dimension 1, while all other dimensions match those of the layer output.

GradCAM attributions are generally upsampled and can be viewed as a mask to the input, since a convolutional layer output generally matches the input image spatially. This upsampling can be performed using LayerAttribution.interpolate, as shown in the example below.

More details regarding the GradCAM method can be found in the original paper here: https://arxiv.org/abs/1610.02391

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (torch.nn.Module) – Layer for which attributions are computed. Output size of attribute matches this layer’s output dimensions, except for the channel dimension (dimension 1), which will be 1, since GradCAM sums over channels.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself, then it is not necessary to provide this argument.

attribute(inputs, target=None, additional_forward_args=None, attribute_to_layer_input=False, relu_attributions=False, attr_dim_summation=True)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which attributions are computed. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attributions with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to the layer input, otherwise it will be computed with respect to layer output. Note that currently it is assumed that either the input or the outputs of internal layers, depending on whether we attribute to the input or output, are single tensors. Support for multiple tensors will be added later. Default: False

  • relu_attributions (bool, optional) – Indicates whether to apply a ReLU operation on the final attribution, returning only non-negative attributions. Setting this flag to True matches the original GradCAM algorithm, otherwise, by default, both positive and negative attributions are returned. Default: False

  • attr_dim_summation (bool, optional) – Indicates whether to sum attributions along dimension 1 (usually channel). The default (True) means to sum along dimension 1. Default: True

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Attributions based on GradCAM method. Attributions will be the same size as the output of the given layer, except for the channel dimension (dimension 1), which will be 1 due to summing over channels. Attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned.

Return type:

Tensor or tuple[Tensor, …] of attributions

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> # It contains a layer conv4, which is an instance of nn.Conv2d,
>>> # and the output of this layer has dimensions Nx50x8x8.
>>> # It is the last convolution layer, which is the recommended
>>> # use case for GradCAM.
>>> net = ImageClassifier()
>>> layer_gc = LayerGradCam(net, net.conv4)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # Computes layer GradCAM for class 3.
>>> # attribution size matches layer output except for dimension
>>> # 1, so dimensions of attr would be Nx1x8x8.
>>> attr = layer_gc.attribute(input, 3)
>>> # GradCAM attributions are often upsampled and viewed as a
>>> # mask to the input, since the convolutional layer output
>>> # spatially matches the original input image.
>>> # This can be done with LayerAttribution's interpolate method.
>>> upsampled_attr = LayerAttribution.interpolate(attr, (32, 32))
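To reproduce the original GradCAM formulation discussed above, a sketch that passes relu_attributions=True before upsampling:

>>> # Matches the original GradCAM: keep only non-negative attributions,
>>> # then upsample to the input's spatial size as before.
>>> pos_attr = layer_gc.attribute(input, 3, relu_attributions=True)
>>> upsampled_pos = LayerAttribution.interpolate(pos_attr, (32, 32))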

Layer DeepLift

class captum.attr.LayerDeepLift(model, layer, multiply_by_inputs=True)[source]

Implements DeepLIFT algorithm for the layer based on the following paper: Learning Important Features Through Propagating Activation Differences, Avanti Shrikumar et al. https://arxiv.org/abs/1704.02685

and the gradient formulation proposed in: Towards better understanding of gradient-based attribution methods for deep neural networks, Marco Ancona et al. https://openreview.net/pdf?id=Sy21R9JAW

This implementation supports only the Rescale rule; the RevealCancel rule will be supported in later releases. Although DeepLIFT’s (Rescale rule) attribution quality is comparable with Integrated Gradients, it runs significantly faster than Integrated Gradients and is preferred for large datasets.

Currently we only support a limited number of non-linear activations but the plan is to expand the list in the future.

Note: Currently we cannot access the building blocks of PyTorch’s built-in LSTMs, RNNs and GRUs, such as Tanh and Sigmoid. Nonetheless, it is possible to build custom LSTMs, RNNs and GRUs with performance similar to the built-in ones using TorchScript. More details on how to build custom RNNs can be found here: https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

Parameters:
  • model (nn.Module) – The reference to PyTorch model instance.

  • layer (torch.nn.Module) – Layer for which attributions are computed. The size and dimensionality of the attributions corresponds to the size and dimensionality of the layer’s input or output depending on whether we attribute to the inputs or outputs of the layer.

  • multiply_by_inputs (bool, optional) –

    Indicates whether to factor model inputs’ multiplier in the final attribution scores. In the literature this is also known as local vs global attribution. If the inputs’ multiplier isn’t factored in, then that type of attribution method is also called local attribution. If it is, then that type of attribution method is called global. More details can be found here: https://arxiv.org/abs/1711.06104

    In case of Layer DeepLift, if multiply_by_inputs is set to True, final sensitivity scores are multiplied by (layer activations for inputs - layer activations for baselines). This flag applies only if custom_attribution_func is set to None.

attribute(inputs, baselines=None, target=None, additional_forward_args=None, return_convergence_delta=False, attribute_to_layer_input=False, custom_attribution_func=None)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which layer attributions are computed. If model takes a single tensor as input, a single input tensor should be provided. If model takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples (aka batch size), and if multiple input tensors are provided, the examples must be aligned appropriately.

  • baselines (scalar, Tensor, tuple of scalar, or Tensor, optional) –

    Baselines define reference samples that are compared with the inputs. In order to assign attribution scores DeepLift computes the differences between the inputs/outputs and corresponding references. Baselines can be provided as:

    • a single tensor, if inputs is a single tensor, with exactly the same dimensions as inputs or the first dimension is one and the remaining dimensions match with inputs.

    • a single scalar, if inputs is a single tensor, which will be broadcasted for each input value in input tensor.

    • a tuple of tensors or scalars, the baseline corresponding to each tensor in the inputs’ tuple can be:

      • either a tensor with matching dimensions to corresponding tensor in the inputs’ tuple or the first dimension is one and the remaining dimensions match with the corresponding input tensor.

      • or a scalar, corresponding to a tensor in the inputs’ tuple. This scalar value is broadcasted for corresponding input tensor.

    In the case when baselines are not provided, we internally use a zero scalar corresponding to each input tensor.

    Default: None

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to model in order, following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • return_convergence_delta (bool, optional) – Indicates whether to return convergence delta or not. If return_convergence_delta is set to True convergence delta will be returned in a tuple following attributions. Default: False

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer input, otherwise it will be computed with respect to layer output. Note that currently it is assumed that either the input or the output of internal layer, depending on whether we attribute to the input or output, is a single tensor. Support for multiple tensors will be added later. Default: False

  • custom_attribution_func (Callable, optional) –

    A custom function for computing final attribution scores. This function can take at least one and at most three arguments with the following signature:

    • custom_attribution_func(multipliers)

    • custom_attribution_func(multipliers, inputs)

    • custom_attribution_func(multipliers, inputs, baselines)

    In case this function is not provided, we use the default logic defined as: multipliers * (inputs - baselines). It is assumed that all input arguments (multipliers, inputs and baselines) are provided in tuples of the same length. custom_attribution_func returns a tuple of attribution tensors that have the same length as the inputs (see the sketch after the example below). Default: None

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Attribution score computed based on DeepLift’s rescale rule with respect to layer’s inputs or outputs. Attributions will always be the same size as the provided layer’s inputs or outputs, depending on whether we attribute to the inputs or outputs of the layer. If the layer input / output is a single tensor, then just a tensor is returned; if the layer input / output has multiple tensors, then a corresponding tuple of tensors is returned.

  • delta (Tensor, returned if return_convergence_delta=True):

    This is computed using the property that the total sum of model(inputs) - model(baselines) must equal the total sum of the attributions computed based on DeepLift’s rescale rule. Delta is calculated per example, meaning that the number of elements in returned delta tensor is equal to the number of examples in input. Note that the logic described for deltas is guaranteed when the default logic for attribution computations is used, meaning that the custom_attribution_func=None, otherwise it is not guaranteed and depends on the specifics of the custom_attribution_func.

Return type:

attributions or 2-element tuple of attributions, delta

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> net = ImageClassifier()
>>> # creates an instance of LayerDeepLift to interpret target
>>> # class 1 with respect to conv4 layer.
>>> dl = LayerDeepLift(net, net.conv4)
>>> input = torch.randn(1, 3, 32, 32, requires_grad=True)
>>> # Computes deeplift attribution scores for conv4 layer and class 1.
>>> attribution = dl.attribute(input, target=1)
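As referenced in the custom_attribution_func parameter above, a minimal sketch that reimplements the default multipliers * (inputs - baselines) rule; the function name is illustrative:

>>> # custom_attribution_func receives same-length tuples and must
>>> # return a tuple of attribution tensors of the same length.
>>> def rescale_attr(multipliers, inputs, baselines):
>>>     return tuple(m * (inp - base)
>>>         for m, inp, base in zip(multipliers, inputs, baselines))
>>> attribution = dl.attribute(input, target=1,
>>>     custom_attribution_func=rescale_attr)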

Layer DeepLiftShap

class captum.attr.LayerDeepLiftShap(model, layer, multiply_by_inputs=True)[source]

Extends LayerDeepLift and DeepLiftShap algorithms and approximates SHAP values for the given layer. For each input sample - baseline pair, it computes DeepLift attributions with respect to the inputs or outputs of the given layer and averages the resulting attributions across baselines. Whether to compute the attributions with respect to the inputs or outputs of the layer is defined by the input flag attribute_to_layer_input. More details about the algorithm can be found here:

https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf

Note that the explanation model:

  1. Assumes that input features are independent of one another

  2. Is linear, meaning that the explanations are modeled through the additive composition of feature effects.

Although it assumes a linear model for each explanation, the overall model across multiple explanations can be complex and non-linear.

Parameters:
  • model (nn.Module) – The reference to PyTorch model instance.

  • layer (torch.nn.Module) – Layer for which attributions are computed. The size and dimensionality of the attributions corresponds to the size and dimensionality of the layer’s input or output depending on whether we attribute to the inputs or outputs of the layer.

  • multiply_by_inputs (bool, optional) –

    Indicates whether to factor model inputs’ multiplier in the final attribution scores. In the literature this is also known as local vs global attribution. If the inputs’ multiplier isn’t factored in, then that type of attribution method is also called local attribution. If it is, then that type of attribution method is called global. More details can be found here: https://arxiv.org/abs/1711.06104

    In case of LayerDeepLiftShap, if multiply_by_inputs is set to True, final sensitivity scores are multiplied by (layer activations for inputs - layer activations for baselines). This flag applies only if custom_attribution_func is set to None.

attribute(inputs, baselines, target=None, additional_forward_args=None, return_convergence_delta=False, attribute_to_layer_input=False, custom_attribution_func=None)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which layer attributions are computed. If model takes a single tensor as input, a single input tensor should be provided. If model takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples (aka batch size), and if multiple input tensors are provided, the examples must be aligned appropriately.

  • baselines (Tensor, tuple[Tensor, ...], or Callable) –

    Baselines define reference samples that are compared with the inputs. In order to assign attribution scores DeepLift computes the differences between the inputs/outputs and corresponding references. Baselines can be provided as:

    • a single tensor, if inputs is a single tensor, with the first dimension equal to the number of examples in the baselines’ distribution. The remaining dimensions must match with input tensor’s dimension starting from the second dimension.

    • a tuple of tensors, if inputs is a tuple of tensors, with the first dimension of any tensor inside the tuple equal to the number of examples in the baseline’s distribution. The remaining dimensions must match the dimensions of the corresponding input tensor starting from the second dimension.

    • callable function, optionally takes inputs as an argument and either returns a single tensor or a tuple of those.

    It is recommended that the number of samples in the baselines’ tensors is larger than one.

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to model in order, following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • return_convergence_delta (bool, optional) – Indicates whether to return convergence delta or not. If return_convergence_delta is set to True convergence delta will be returned in a tuple following attributions. Default: False

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attributions with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer inputs, otherwise it will be computed with respect to layer outputs. Note that currently it assumes that both the inputs and outputs of internal layers are single tensors. Support for multiple tensors will be added later. Default: False

  • custom_attribution_func (Callable, optional) –

    A custom function for computing final attribution scores. This function can take at least one and at most three arguments with the following signature:

    • custom_attribution_func(multipliers)

    • custom_attribution_func(multipliers, inputs)

    • custom_attribution_func(multipliers, inputs, baselines)

    In case this function is not provided, we use the default logic defined as: multipliers * (inputs - baselines). It is assumed that all input arguments (multipliers, inputs and baselines) are provided in tuples of the same length. custom_attribution_func returns a tuple of attribution tensors that have the same length as the inputs. Default: None

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Attribution score computed based on DeepLift’s rescale rule with respect to layer’s inputs or outputs. Attributions will always be the same size as the provided layer’s inputs or outputs, depending on whether we attribute to the inputs or outputs of the layer. Attributions are returned in a tuple based on whether the layer inputs / outputs are contained in a tuple from a forward hook. For standard modules, inputs of a single tensor are usually wrapped in a tuple, while outputs of a single tensor are not.

  • delta (Tensor, returned if return_convergence_delta=True):

    This is computed using the property that the total sum of model(inputs) - model(baselines) must be very close to the total sum of attributions computed based on approximated SHAP values using DeepLift’s rescale rule. Delta is calculated for each example input and baseline pair, meaning that the number of elements in the returned delta tensor is equal to the number of examples in input * number of examples in baseline. The deltas are ordered first by input example and then by baseline. Note that the logic described for deltas is guaranteed when the default logic for attribution computations is used, meaning that custom_attribution_func=None; otherwise it is not guaranteed and depends on the specifics of the custom_attribution_func.

Return type:

attributions or 2-element tuple of attributions, delta

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> net = ImageClassifier()
>>> # creates an instance of LayerDeepLiftShap to interpret target
>>> # class 3 with respect to conv4 layer.
>>> dl = LayerDeepLiftShap(net, net.conv4)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # DeepLiftShap requires a distribution of baselines; more than
>>> # one baseline sample is recommended.
>>> baselines = torch.randn(6, 3, 32, 32)
>>> # Computes shap values using deeplift for class 3.
>>> attribution = dl.attribute(input, baselines, target=3)
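A sketch of the per-pair convergence delta described in the returns above; with 2 input examples and 6 baselines, delta has 2 * 6 = 12 entries, ordered first by input example and then by baseline:

>>> # One delta per (input example, baseline) pair: 2 * 6 == 12 here.
>>> attribution, delta = dl.attribute(input, baselines, target=3,
>>>     return_convergence_delta=True)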

Layer GradientShap

class captum.attr.LayerGradientShap(forward_func, layer, device_ids=None, multiply_by_inputs=True)[source]

Implements GradientSHAP for a given layer, based on the implementation from SHAP’s primary author. For reference, please view:

https://github.com/slundberg/shap#deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models

A Unified Approach to Interpreting Model Predictions: https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf

GradientShap approximates SHAP values by computing the expectations of gradients by randomly sampling from the distribution of baselines/references. It adds white noise to each input sample n_samples times, selects a random baseline from baselines’ distribution and a random point along the path between the baseline and the input, and computes the gradient of outputs with respect to selected random points in chosen layer. The final SHAP values represent the expected values of gradients * (layer_attr_inputs - layer_attr_baselines).

GradientShap makes an assumption that the input features are independent and that the explanation model is linear, meaning that the explanations are modeled through the additive composition of feature effects. Under those assumptions, SHAP values can be approximated as the expectation of gradients computed for n_samples randomly generated input samples, obtained by adding Gaussian noise to each input, with baselines drawn from the given distribution of baselines/references.

In some sense it can be viewed as an approximation of integrated gradients by computing the expectations of gradients for different baselines.

Current implementation uses Smoothgrad from NoiseTunnel in order to randomly draw samples from the distribution of baselines, add noise to input samples and compute the expectation (smoothgrad).

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (torch.nn.Module) – Layer for which attributions are computed. Output size of attribute matches this layer’s input or output dimensions, depending on whether we attribute to the inputs or outputs of the layer, corresponding to attribution of each neuron in the input or output of this layer.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself, then it is not necessary to provide this argument.

  • multiply_by_inputs (bool, optional) –

    Indicates whether to factor model inputs’ multiplier in the final attribution scores. In the literature this is also known as local vs global attribution. If the inputs’ multiplier isn’t factored in, then this type of attribution method is also called local attribution. If it is, then that type of attribution method is called global. More details can be found here: https://arxiv.org/abs/1711.06104

    In case of layer gradient shap, if multiply_by_inputs is set to True, the sensitivity scores for scaled inputs are multiplied by (layer activations for inputs - layer activations for baselines).

attribute(inputs, baselines, n_samples=5, stdevs=0.0, target=None, additional_forward_args=None, return_convergence_delta=False, attribute_to_layer_input=False)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Inputs which are used to compute SHAP attribution values for a given layer. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • baselines (Tensor, tuple[Tensor, ...], or Callable) –

    Baselines define the starting point from which expectation is computed and can be provided as:

    • a single tensor, if inputs is a single tensor, with the first dimension equal to the number of examples in the baselines’ distribution. The remaining dimensions must match with input tensor’s dimension starting from the second dimension.

    • a tuple of tensors, if inputs is a tuple of tensors, with the first dimension of any tensor inside the tuple equal to the number of examples in the baseline’s distribution. The remaining dimensions must match the dimensions of the corresponding input tensor starting from the second dimension.

    • callable function, optionally takes inputs as an argument and either returns a single tensor or a tuple of those.

    It is recommended that the number of samples in the baselines’ tensors is larger than one.

  • n_samples (int, optional) – The number of randomly generated examples per sample in the input batch. Random examples are generated by adding gaussian random noise to each sample. Default: 5 if n_samples is not provided.

  • stdevs (float or tuple of float, optional) – The standard deviation of gaussian noise with zero mean that is added to each input in the batch. If stdevs is a single float value then that same value is used for all inputs. If it is a tuple, then it must have the same length as the inputs tuple. In this case, each stdev value in the stdevs tuple corresponds to the input with the same index in the inputs tuple. Default: 0.0

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It can contain a tuple of ND tensors or any arbitrary python types. In case of an ND tensor, the first dimension of the tensor must correspond to the batch size; it will be repeated for each of the n_samples randomly generated input samples. Note that the attributions are not computed with respect to these arguments. Default: None

  • return_convergence_delta (bool, optional) – Indicates whether to return convergence delta or not. If return_convergence_delta is set to True convergence delta will be returned in a tuple following attributions. Default: False

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer input, otherwise it will be computed with respect to layer output. Note that currently it is assumed that either the input or the output of internal layer, depending on whether we attribute to the input or output, is a single tensor. Support for multiple tensors will be added later. Default: False

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Attribution score computed based on GradientSHAP with respect to layer’s input or output. Attributions will always be the same size as the provided layer’s inputs or outputs, depending on whether we attribute to the inputs or outputs of the layer. Attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned.

  • delta (Tensor, returned if return_convergence_delta=True):

    This is computed using the property that the total sum of forward_func(inputs) - forward_func(baselines) must be very close to the total sum of the attributions based on layer gradient SHAP. Delta is calculated for each example in the input after adding n_samples times gaussian noise to each of them. Therefore, the dimensionality of the deltas tensor is equal to the number of examples in the input * n_samples. The deltas are ordered by input example, with the n_samples noisy samples generated for each example grouped together.

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> net = ImageClassifier()
>>> layer_grad_shap = LayerGradientShap(net, net.linear1)
>>> input = torch.randn(3, 3, 32, 32, requires_grad=True)
>>> # choosing baselines randomly
>>> baselines = torch.randn(20, 3, 32, 32)
>>> # Computes gradient SHAP of output layer when target is equal
>>> # to 5 with respect to the layer linear1.
>>> # Attribution size matches the output size of the linear1 layer
>>> attribution = layer_grad_shap.attribute(input, baselines,
>>>                                         target=5)
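
Baselines can also be supplied as a callable, and the convergence delta can be requested alongside the attributions. Below is a minimal sketch extending the example above; the zero-argument baseline generator and the particular n_samples / stdevs values are illustrative assumptions, not prescribed settings.

>>> # A callable baseline generator: takes no arguments and returns
>>> # a tensor whose trailing dimensions match the input.
>>> def rand_baselines():
>>>     return torch.randn(20, 3, 32, 32)
>>> attribution, delta = layer_grad_shap.attribute(
>>>     input, rand_baselines, target=5,
>>>     n_samples=10, stdevs=0.1,
>>>     return_convergence_delta=True)
>>> # delta holds #examples * n_samples = 3 * 10 = 30 values, ordered
>>> # by input example and then by the noisy samples drawn for it.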

Return type:

attributions or 2-element tuple of attributions, delta

has_convergence_delta()[source]

This method informs the user whether the attribution algorithm provides a convergence delta (aka an approximation error) or not. Convergence delta may serve as a proxy of correctness of attribution algorithm’s approximation. If deriving attribution class provides a compute_convergence_delta method, it should override both compute_convergence_delta and has_convergence_delta methods.

Returns:

Returns whether the attribution algorithm provides a convergence delta (aka approximation error) or not.

Return type:

bool

Layer Integrated Gradients

class captum.attr.LayerIntegratedGradients(forward_func, layer, device_ids=None, multiply_by_inputs=True)[source]

Layer Integrated Gradients is a variant of Integrated Gradients that assigns an importance score to layer inputs or outputs, depending on whether we attribute to the inputs or to the outputs of the layer.

Integrated Gradients is an axiomatic model interpretability algorithm that attributes / assigns an importance score to each input feature by approximating the integral of gradients of the model’s output with respect to the inputs along the path (straight line) from given baselines / references to inputs.

Baselines can be provided as input arguments to attribute method. To approximate the integral we can choose to use either a variant of Riemann sum or Gauss-Legendre quadrature rule.

More details regarding the integrated gradients method can be found in the original paper: https://arxiv.org/abs/1703.01365

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (ModuleOrModuleList) –

    Layer or list of layers for which attributions are computed. For each layer the output size of the attribute matches this layer’s input or output dimensions, depending on whether we attribute to the inputs or outputs of the layer, corresponding to the attribution of each neuron in the input or output of this layer.

    Please note that the layers to attribute cannot depend on each other; that is, no subset of the given layers may produce the inputs for another of the given layers.

    For example, if your model has a simple chain structure (think nn.Sequential), e.g. x -> l1 -> l2 -> l3 -> output, then passing in any one of those layers rules out the others due to the dependence: if you pass in l2, you cannot also pass in l1 or l3. A minimal sketch of a valid multi-layer setup follows the parameter list.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself, then it is not necessary to provide this argument.

  • multiply_by_inputs (bool, optional) –

    Indicates whether to factor the model inputs’ multiplier into the final attribution scores. In the literature this is also known as local vs. global attribution. If the inputs’ multiplier isn’t factored in, the attribution method is called local; if it is, the method is called global. More details can be found here: https://arxiv.org/abs/1711.06104

    In the case of layer integrated gradients, if multiply_by_inputs is set to True, the final sensitivity scores are multiplied by (layer activations for inputs - layer activations for baselines).
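
To illustrate the independence requirement on multiple layers (and the multiply_by_inputs flag), here is a minimal sketch; TwoBranchNet and its layer names are hypothetical, chosen so that neither attributed layer feeds the other.

>>> import torch
>>> import torch.nn as nn
>>> class TwoBranchNet(nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         # branch_a and branch_b run side by side on the same
>>>         # input, so neither produces the inputs of the other.
>>>         self.branch_a = nn.Linear(8, 4)
>>>         self.branch_b = nn.Linear(8, 4)
>>>         self.head = nn.Linear(8, 3)
>>>     def forward(self, x):
>>>         return self.head(torch.cat([self.branch_a(x),
>>>                                     self.branch_b(x)], dim=1))
>>> net = TwoBranchNet()
>>> # multiply_by_inputs=False yields "local" scores: the integrated
>>> # gradients without the (activation - baseline) factor.
>>> lig = LayerIntegratedGradients(net, [net.branch_a, net.branch_b],
>>>                                multiply_by_inputs=False)
>>> # For a list of layers, attributions come back as a list, in the
>>> # same order as the layers given to the constructor.
>>> attrs = lig.attribute(torch.randn(2, 8), target=0)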

attribute(inputs, baselines=None, target=None, additional_forward_args=None, n_steps=50, method='gausslegendre', internal_batch_size=None, return_convergence_delta=False, attribute_to_layer_input=False)[source]

This method attributes the output of the model with given target index (in case it is provided, otherwise it assumes that output is a scalar) to layer inputs or outputs of the model, depending on whether attribute_to_layer_input is set to True or False, using the approach described above.

In addition, if return_convergence_delta is set to True, it returns an integral approximation delta based on the completeness property of integrated gradients.

Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which layer integrated gradients are computed. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • baselines (scalar, Tensor, tuple of scalar, or Tensor, optional) –

    Baselines define the starting point from which integral is computed and can be provided as:

    • a single tensor, if inputs is a single tensor, with exactly the same dimensions as inputs or the first dimension is one and the remaining dimensions match with inputs.

    • a single scalar, if inputs is a single tensor, which will be broadcasted for each input value in input tensor.

    • a tuple of tensors or scalars, the baseline corresponding to each tensor in the inputs’ tuple can be:

      • either a tensor with matching dimensions to corresponding tensor in the inputs’ tuple or the first dimension is one and the remaining dimensions match with the corresponding input tensor.

      • or a scalar, corresponding to a tensor in the inputs’ tuple. This scalar value is broadcasted for corresponding input tensor.

    In the cases when baselines is not provided, we internally use zero scalar corresponding to each input tensor.

    Default: None

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) –

    If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs.

    For a tensor, the first dimension of the tensor must correspond to the number of examples. It will be repeated for each of n_steps along the integrated path. For all other types, the given argument is used for all forward evaluations.

    Note that attributions are not computed with respect to these arguments. Default: None

  • n_steps (int, optional) – The number of steps used by the approximation method. Default: 50.

  • method (str, optional) – Method for approximating the integral, one of riemann_right, riemann_left, riemann_middle, riemann_trapezoid or gausslegendre. Default: gausslegendre

  • internal_batch_size (int, optional) –

    Divides total #steps * #examples data points into chunks of size at most internal_batch_size, which are computed (forward / backward passes) sequentially. internal_batch_size must be at least equal to #examples.

    For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain internal_batch_size / num_devices examples. If internal_batch_size is None, then all evaluations are processed in one batch. Default: None

  • return_convergence_delta (bool, optional) – Indicates whether to return convergence delta or not. If return_convergence_delta is set to True convergence delta will be returned in a tuple following attributions. Default: False

  • attribute_to_layer_input (bool, optional) –

    Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer input, otherwise it will be computed with respect to layer output.

    Note that currently it is assumed that either the input or the output of internal layer, depending on whether we attribute to the input or output, is a single tensor. Support for multiple tensors will be added later. Default: False

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Integrated gradients with respect to layer’s inputs or outputs. Attributions will always be the same size and dimensionality as the input or output of the given layer, depending on whether we attribute to the inputs or outputs of the layer which is decided by the input flag attribute_to_layer_input.

    For a single layer, attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned.

    For multiple layers, attributions will always be returned as a list with one element per layer, each formatted as for a single layer: if a given layer’s inputs / outputs consist of multiple tensors, the corresponding list element is a tuple of tensors. The list follows the same order as the layers given to the constructor.

  • delta (Tensor, returned if return_convergence_delta=True):

    The difference between the total approximated and true integrated gradients. This is computed using the property that the total sum of forward_func(inputs) - forward_func(baselines) must equal the total sum of the integrated gradient. Delta is calculated per example, meaning that the number of elements in returned delta tensor is equal to the number of examples in inputs.

Return type:

attributions or 2-element tuple of attributions, delta

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
>>> # and the output of this layer has dimensions Nx12x32x32.
>>> net = ImageClassifier()
>>> lig = LayerIntegratedGradients(net, net.conv1)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # Computes layer integrated gradients for class 3.
>>> # attribution size matches layer output, Nx12x32x32
>>> attribution = lig.attribute(input, target=3)
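
If memory is a constraint, the #steps * #examples evaluations can be chunked via internal_batch_size, and the completeness-based delta can be requested as well. A sketch reusing the net and input assumed above; the particular n_steps and batch size are illustrative:

>>> attribution, delta = lig.attribute(input, target=3,
>>>                                    n_steps=100,
>>>                                    internal_batch_size=10,
>>>                                    return_convergence_delta=True)
>>> # delta contains one approximation error per example (here 2),
>>> # reflecting how far the summed attributions are from
>>> # forward_func(inputs) - forward_func(baselines).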

has_convergence_delta()[source]

This method informs the user whether the attribution algorithm provides a convergence delta (aka an approximation error) or not. Convergence delta may serve as a proxy of correctness of attribution algorithm’s approximation. If deriving attribution class provides a compute_convergence_delta method, it should override both compute_convergence_delta and has_convergence_delta methods.

Returns:

Returns whether the attribution algorithm provides a convergence delta (aka approximation error) or not.

Return type:

bool

Layer Feature Ablation

class captum.attr.LayerFeatureAblation(forward_func, layer, device_ids=None)[source]

A perturbation based approach to computing layer attribution, involving replacing values in the input / output of a layer with a given baseline / reference and computing the difference in output. By default, each neuron (scalar input / output value) within the layer is replaced independently. Passing a layer mask allows neurons to be grouped and ablated together. Each neuron in the group will be given the same attribution value, equal to the change in target as a result of ablating the entire neuron group.

Parameters:
  • forward_func (Callable) – The forward function of the model or any modification of it

  • layer (torch.nn.Module) – Layer for which attributions are computed. Output size of attribute matches this layer’s input or output dimensions, depending on whether we attribute to the inputs or outputs of the layer, corresponding to attribution of each neuron in the input or output of this layer.

  • device_ids (list[int]) – Device ID list, necessary only if forward_func applies a DataParallel model. This allows reconstruction of intermediate outputs from batched results across devices. If forward_func is given as the DataParallel model itself (or otherwise has a device_ids attribute with the device ID list), then it is not necessary to provide this argument.

attribute(inputs, layer_baselines=None, target=None, additional_forward_args=None, layer_mask=None, attribute_to_layer_input=False, perturbations_per_eval=1)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which layer attributions are computed. If forward_func takes a single tensor as input, a single input tensor should be provided. If forward_func takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • layer_baselines (scalar, Tensor, tuple of scalar, or Tensor, optional) – Layer baselines define reference values which replace each layer input / output value when ablated. Layer baselines should be a single tensor with dimensions matching the input / output of the target layer (or broadcastable to match it), based on whether we are attributing to the input or output of the target layer. In the cases when baselines is not provided, we internally use zero as the baseline for each neuron. Default: None

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • layer_mask (Tensor or tuple[Tensor, ...], optional) – layer_mask defines a mask for the layer, grouping elements of the layer input / output which should be ablated together. layer_mask should be a single tensor with dimensions matching the input / output of the target layer (or broadcastable to match it), based on whether we are attributing to the input or output of the target layer. layer_mask should contain integers in the range 0 to num_groups - 1, and all elements with the same value are considered to be in the same group. If None, then a layer mask is constructed which assigns each neuron within the layer as a separate group, which is ablated independently. Default: None

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attributions with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer’s inputs, otherwise it will be computed with respect to layer’s outputs. Note that currently it is assumed that either the input or the output of the layer, depending on whether we attribute to the input or output, is a single tensor. Support for multiple tensors will be added later. Default: False

  • perturbations_per_eval (int, optional) – Allows ablation of multiple neurons (or neuron groups) to be processed simultaneously in one call to forward_func. Each forward pass will contain a maximum of perturbations_per_eval * #examples samples. For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain at most (perturbations_per_eval * #examples) / num_devices samples. Default: 1

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Attribution of each neuron in the given layer input or output. Attributions will always be the same size as the input or output of the given layer, depending on whether we attribute to the inputs or outputs of the layer, which is decided by the input flag attribute_to_layer_input. Attributions are returned in a tuple if the layer inputs / outputs contain multiple tensors, otherwise a single tensor is returned.

Return type:

Tensor or tuple[Tensor, …] of attributions

Examples:

>>> # SimpleClassifier takes a single input tensor of size Nx4x4,
>>> # and returns an Nx3 tensor of class probabilities.
>>> # It contains an attribute conv1, which is an instance of nn.Conv2d,
>>> # and the output of this layer has dimensions Nx12x3x3.
>>> net = SimpleClassifier()
>>> # Generating random input with size 2 x 4 x 4
>>> input = torch.randn(2, 4, 4)
>>> # Defining LayerFeatureAblation interpreter
>>> ablator = LayerFeatureAblation(net, net.conv1)
>>> # Computes ablation attribution, ablating each of the 108
>>> # neurons independently.
>>> attr = ablator.attribute(input, target=1)
>>> # Alternatively, we may want to ablate neurons in groups, e.g.
>>> # grouping all the layer outputs in the same row.
>>> # This can be done by creating a layer mask as follows, which
>>> # defines the groups of layer inputs / outputs, e.g.:
>>> # +---+---+---+
>>> # | 0 | 0 | 0 |
>>> # +---+---+---+
>>> # | 1 | 1 | 1 |
>>> # +---+---+---+
>>> # | 2 | 2 | 2 |
>>> # +---+---+---+
>>> # With this mask, the 36 neurons sharing a row index (3 columns
>>> # across all 12 channels) are ablated simultaneously, and the
>>> # attribution for each neuron in the same group (0 - 2) per
>>> # example is the same.
>>> # The attributions can be calculated as follows:
>>> # layer mask has dimensions 1 x 3 x 3
>>> layer_mask = torch.tensor([[[0,0,0],[1,1,1],
>>>                             [2,2,2]]])
>>> attr = ablator.attribute(input, target=1,
>>>                          layer_mask=layer_mask)
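
Ablation groups can also be evaluated several at a time to cut down on forward passes. A minimal sketch reusing the ablator, input, and layer_mask from above; the value 3 is an illustrative choice:

>>> attr = ablator.attribute(input, target=1,
>>>                          layer_mask=layer_mask,
>>>                          perturbations_per_eval=3)
>>> # Each forward pass now evaluates up to 3 ablated groups * 2
>>> # examples = 6 samples, so the 3 row groups can be handled in a
>>> # single ablation pass instead of three.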

Layer LRP

class captum.attr.LayerLRP(model, layer)[source]

Layer-wise relevance propagation is based on a backward propagation mechanism applied sequentially to all layers of the model. Here, the model output score represents the initial relevance which is decomposed into values for each neuron of the underlying layers. The decomposition is defined by rules that are chosen for each layer, involving its weights and activations. Details on the method can be found in the original paper [https://doi.org/10.1371/journal.pone.0130140]. The implementation is inspired by the tutorial of the same group [https://doi.org/10.1016/j.dsp.2017.10.011] and the publication by Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].

Parameters:
  • model (Module) – The model to attribute, or any modification of it. Custom rules for a given layer need to be defined as attribute module.rule and need to be of type PropagationRule (see the sketch under Examples below).

  • layer (torch.nn.Module or list(torch.nn.Module)) – Layer or layers for which attributions are computed. The size and dimensionality of the attributions corresponds to the size and dimensionality of the layer’s input or output depending on whether we attribute to the inputs or outputs of the layer. If value is None, the relevance for all layers is returned in attribution.

attribute(inputs, target=None, additional_forward_args=None, return_convergence_delta=False, attribute_to_layer_input=False, verbose=False)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which relevance is propagated. If model takes a single tensor as input, a single input tensor should be provided. If model takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples, and if multiple input tensors are provided, the examples must be aligned appropriately.

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (tuple, optional) – If the forward function requires additional arguments other than the inputs for which attributions should not be computed, this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to model in order, following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • return_convergence_delta (bool, optional) – Indicates whether to return convergence delta or not. If return_convergence_delta is set to True convergence delta will be returned in a tuple following attributions. Default: False

  • attribute_to_layer_input (bool, optional) – Indicates whether to compute the attribution with respect to the layer input or output. If attribute_to_layer_input is set to True then the attributions will be computed with respect to layer input, otherwise it will be computed with respect to layer output. Default: False

  • verbose (bool, optional) – Indicates whether information on application of rules is printed during propagation. Default: False

Return type:

Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]], Tuple[Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], Union[Tensor, List[Tensor]]]]

Returns:

Tensor or tuple[Tensor, …] of attributions or 2-element tuple of attributions, delta or list of attributions and delta:

  • attributions (Tensor or tuple[Tensor, …]):

    The propagated relevance values with respect to each input feature. Attributions will always be the same size as the provided inputs, with each value providing the attribution of the corresponding input index. If a single tensor is provided as inputs, a single tensor is returned. If a tuple is provided for inputs, a tuple of corresponding sized tensors is returned. The sum of attributions is one, rather than equal to the prediction score as in other implementations. If attributions for all layers are returned (layer=None), a list of tensors or tuples of tensors is returned with entries for each layer.

  • delta (Tensor or list of Tensor, returned if return_convergence_delta=True):

    Delta is calculated per example, meaning that the number of elements in the returned delta tensor is equal to the number of examples in the input. If attributions for all layers are returned (layer=None), a list of tensors is returned with entries for each layer.

Examples:

>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities. It has one
>>> # Conv2D and a ReLU layer.
>>> net = ImageClassifier()
>>> layer_lrp = LayerLRP(net, net.conv1)
>>> input = torch.randn(3, 3, 32, 32)
>>> # Attribution size matches input size: 3x3x32x32
>>> attribution = layer_lrp.attribute(input, target=5)
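
Custom propagation rules can be assigned per layer by setting the module’s rule attribute before calling attribute, and a convergence delta can be requested as well. A minimal sketch assuming captum’s built-in EpsilonRule (the import path may vary across captum versions):

>>> from captum.attr._utils.lrp_rules import EpsilonRule
>>> # Assign an epsilon-stabilized rule to conv1; layers without an
>>> # explicit rule use the implementation's defaults.
>>> net.conv1.rule = EpsilonRule()
>>> attribution, delta = layer_lrp.attribute(input, target=5,
>>>                                          return_convergence_delta=True)
>>> # delta contains one value per example (here 3)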