DeepLift

class captum.attr.DeepLift(model, multiply_by_inputs=True, eps=1e-10)

Implements the DeepLIFT algorithm based on the following paper: Learning Important Features Through Propagating Activation Differences, Avanti Shrikumar et al. https://arxiv.org/abs/1704.02685

and the gradient formulation proposed in: Towards better understanding of gradient-based attribution methods for deep neural networks, Marco Ancona et al. https://openreview.net/pdf?id=Sy21R9JAW

This implementation supports only the Rescale rule; the RevealCancel rule will be supported in later releases. In addition, to keep the implementation cleaner, DeepLIFT for internal neurons and layers extends the current implementation and is implemented separately in LayerDeepLift and NeuronDeepLift. Although the attribution quality of DeepLIFT's Rescale rule is comparable to that of Integrated Gradients, it runs significantly faster, because it needs a single forward and backward pass over the inputs and baselines instead of one gradient evaluation per integration step. It is therefore preferred for large datasets.
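
To make the Rescale rule concrete, the following pure-Python sketch hand-computes DeepLIFT attributions for a tiny two-layer ReLU network. It is an illustration, not Captum's implementation: the weights and the helper names `rescale_multiplier` and `deeplift_rescale` are hypothetical, and the fallback to the ordinary ReLU derivative when the input difference falls below `eps` is our reading of how the eps threshold is used. Each hidden unit's multiplier is Δoutput/Δinput; the chain rule carries these multipliers through the linear layers, and the resulting attributions sum to f(x) - f(baseline).

```python
# Hand-rolled DeepLIFT (Rescale rule) for y = w2 . relu(W1 x + b1) + b2.
# Illustrative sketch only; all names and numbers are hypothetical.

W1 = [[1.0, -2.0], [0.5, 1.0]]  # hidden x input weights
b1 = [0.0, -1.0]
w2 = [2.0, -1.0]                # output weights
b2 = 0.5


def relu(v):
    return max(v, 0.0)


def hidden_pre(x):
    # Pre-activations z_j = W1[j] . x + b1[j]
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W1, b1)]


def forward(x):
    return sum(w * relu(z) for w, z in zip(w2, hidden_pre(x))) + b2


def rescale_multiplier(z, z_ref, eps=1e-10):
    # Rescale rule: ratio of the output difference to the input difference.
    dz = z - z_ref
    if abs(dz) < eps:
        # Difference too small to divide by: use the plain ReLU derivative.
        return 1.0 if z > 0 else 0.0
    return (relu(z) - relu(z_ref)) / dz


def deeplift_rescale(x, baseline, eps=1e-10):
    z, z_ref = hidden_pre(x), hidden_pre(baseline)
    m = [rescale_multiplier(a, b, eps) for a, b in zip(z, z_ref)]
    # Chain rule through the linear layers gives one multiplier per input.
    mult = [sum(w2[j] * m[j] * W1[j][i] for j in range(len(W1)))
            for i in range(len(x))]
    # Global attribution: scale the multipliers by (inputs - baselines).
    return [mi * (xi - bi) for mi, xi, bi in zip(mult, x, baseline)]


x, baseline = [3.0, 1.0], [0.0, 0.0]
attr = deeplift_rescale(x, baseline)
print(attr)                                       # approximately [5.1, -4.6]
print(sum(attr), forward(x) - forward(baseline))  # both approximately 0.5
```

Note that the second hidden unit is inactive at the baseline (pre-activation -1) and active at the input (1.5), so its multiplier is 1.5 / 2.5 = 0.6 rather than the gradient 1.0; this is what lets the attributions add up exactly to the change in output.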

Currently, only a limited number of non-linear activations are supported, but the plan is to expand the list in the future.
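
Because DeepLift applies its rules through hooks registered on activation modules, a model is easiest to explain when each non-linearity is an `nn.Module` attribute (for example `nn.ReLU`) rather than a call to `torch.nn.functional.relu`, which a module hook cannot see. The sketch below shows that pattern with a hypothetical `TwoLayerNet`. Giving each call site its own ReLU instance, instead of sharing one, is a precaution we assume is safest, since the hooks store per-module state such as the module's recorded input and output.

```python
import torch
import torch.nn as nn


class TwoLayerNet(nn.Module):
    # Hypothetical model: every non-linearity is its own nn.Module instance,
    # so hook-based methods such as DeepLift can find and wrap it.
    def __init__(self, in_features: int = 4, hidden: int = 8, classes: int = 3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.relu1 = nn.ReLU()   # a module, not torch.nn.functional.relu
        self.fc2 = nn.Linear(hidden, hidden)
        self.relu2 = nn.ReLU()   # a separate instance, not a reused self.relu1
        self.fc3 = nn.Linear(hidden, classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)


model = TwoLayerNet().eval()
# List the activation modules that hooks can attach to.
relus = [m for m in model.modules() if isinstance(m, nn.ReLU)]
print(len(relus))  # 2
```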

Note: Currently, the building blocks of PyTorch's built-in LSTMs, RNNs, and GRUs, such as Tanh and Sigmoid, cannot be accessed, so DeepLift cannot apply its rules inside them. Nonetheless, it is possible to build custom LSTMs, RNNs, and GRUs with performance similar to the built-in ones using TorchScript. More details on how to build custom RNNs can be found here: https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
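
As a minimal illustration of the custom-RNN route, the sketch below hand-writes an Elman RNN whose Tanh activations are ordinary `nn.Tanh` modules, then checks that it reproduces `nn.RNN` when given the same weights. It is a simplified, fixed-sequence-length sketch under our own assumptions: `UnrolledTanhRNN` is a hypothetical name, TorchScript compilation (covered in the linked post) is omitted, and one `nn.Tanh` instance per time step is used on the assumption that sharing a single activation instance across call sites is unsafe for hook-based methods.

```python
import torch
import torch.nn as nn


class UnrolledTanhRNN(nn.Module):
    # Hypothetical hand-written Elman RNN: h_t = tanh(W_ih x_t + W_hh h_{t-1} + b).
    # Each time step gets its own nn.Tanh instance, so module hooks
    # (as used by DeepLift) can see every non-linearity.
    def __init__(self, input_size: int, hidden_size: int, seq_len: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.ih = nn.Linear(input_size, hidden_size)
        self.hh = nn.Linear(hidden_size, hidden_size)
        self.tanhs = nn.ModuleList([nn.Tanh() for _ in range(seq_len)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, input_size); returns the last hidden state.
        h = x.new_zeros(x.shape[0], self.hidden_size)
        for t, tanh in enumerate(self.tanhs):
            h = tanh(self.ih(x[:, t]) + self.hh(h))
        return h


torch.manual_seed(0)
custom = UnrolledTanhRNN(input_size=3, hidden_size=5, seq_len=4)
builtin = nn.RNN(input_size=3, hidden_size=5, batch_first=True)

# Copy the built-in weights so that both models compute the same function.
with torch.no_grad():
    custom.ih.weight.copy_(builtin.weight_ih_l0)
    custom.ih.bias.copy_(builtin.bias_ih_l0)
    custom.hh.weight.copy_(builtin.weight_hh_l0)
    custom.hh.bias.copy_(builtin.bias_hh_l0)

x = torch.randn(2, 4, 3)
_, h_n = builtin(x)  # h_n: (num_layers, batch, hidden)
print(torch.allclose(custom(x), h_n[0], atol=1e-5))
```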

Parameters:
  • model (nn.Module) – A reference to the PyTorch model instance.

  • multiply_by_inputs (bool, optional) –

    Indicates whether to factor the model inputs' multiplier into the final attribution scores. In the literature this is also known as local vs. global attribution. If the inputs' multiplier isn't factored in, the attribution method is called local; if it is, the method is called global. More details can be found here: https://arxiv.org/abs/1711.06104

    In the case of DeepLift, if multiply_by_inputs is set to True, the final sensitivity scores are multiplied by (inputs - baselines). This flag applies only if custom_attribution_func is set to None. Default: True

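  • eps (float, optional) – Threshold below which the difference between a non-linear layer's input and its baseline counterpart is considered insignificant; for such elements the ordinary gradient is used instead of the ratio of output to input differences. Adjusting it to your model's numerical precision (bit depth) helps avoid numerical issues, such as division by near-zero values, during the gradient computation. Default: 1e-10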

attribute(inputs, baselines=None, target=None, additional_forward_args=None, return_convergence_delta=False, custom_attribution_func=None)
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Input for which attributions are computed. If model takes a single tensor as input, a single input tensor should be provided. If model takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples (aka batch size), and if multiple input tensors are provided, the examples must be aligned appropriately.

  • baselines (scalar, Tensor, or tuple of scalars or Tensors, optional) –

    Baselines define reference samples that are compared with the inputs. In order to assign attribution scores, DeepLift computes the differences between the inputs/outputs and the corresponding references. Baselines can be provided as:

    • a single tensor, if inputs is a single tensor, with either exactly the same dimensions as inputs or a first dimension of one and remaining dimensions matching inputs.

    • a single scalar, if inputs is a single tensor, which will be broadcast to every value in the input tensor.

    • a tuple of tensors or scalars, where the baseline corresponding to each tensor in the inputs' tuple can be:

      • either a tensor with dimensions matching the corresponding tensor in the inputs' tuple, or with a first dimension of one and remaining dimensions matching the corresponding input tensor.

      • or a scalar corresponding to a tensor in the inputs' tuple. This scalar value is broadcast to the corresponding input tensor.

    When baselines is not provided, a zero scalar is used internally for each input tensor.

    Default: None

  • target (int, tuple, Tensor, or list, optional) –

    Output indices for which gradients are computed (for classification cases, this is usually the target class). If the network returns a scalar value per example, no target index is necessary. For general 2D outputs, targets can be either:

    • a single integer or a tensor containing a single integer, which is applied to all input examples

    • a list of integers or a 1D tensor, with length matching the number of examples in inputs (dim 0). Each integer is applied as the target for the corresponding example.

    For outputs with > 2 dimensions, targets can be either:

    • A single tuple, which contains #output_dims - 1 elements. This target index is applied to all examples.

    • A list of tuples with length equal to the number of examples in inputs (dim 0), and each tuple containing #output_dims - 1 elements. Each tuple is applied as the target for the corresponding example.

    Default: None

  • additional_forward_args (Any, optional) – If the forward function requires arguments other than the inputs, and attributions should not be computed for those arguments, they can be provided here. This must be either a single additional argument of a Tensor or arbitrary (non-tuple) type, or a tuple containing multiple additional arguments, including tensors or arbitrary Python types. These arguments are provided to the model in order, following the arguments in inputs. Note that attributions are not computed with respect to these arguments. Default: None

  • return_convergence_delta (bool, optional) – Indicates whether to return the convergence delta. If return_convergence_delta is set to True, the convergence delta is returned in a tuple following the attributions. Default: False

  • custom_attribution_func (Callable, optional) –

    A custom function for computing final attribution scores. This function can take at least one and at most three arguments, with one of the following signatures:

    • custom_attribution_func(multipliers)

    • custom_attribution_func(multipliers, inputs)

    • custom_attribution_func(multipliers, inputs, baselines)

    If this function is not provided, the default logic is used: multipliers * (inputs - baselines) when multiply_by_inputs is True, and the multipliers alone otherwise. It is assumed that all input arguments (multipliers, inputs, and baselines) are provided as tuples of the same length. custom_attribution_func must return a tuple of attribution tensors with the same length as inputs.

    Default: None

Returns:

  • attributions (Tensor or tuple[Tensor, …]):

    Attribution scores computed based on the DeepLift Rescale rule with respect to each input feature. Attributions always have the same size as the provided inputs, with each value providing the attribution of the corresponding input index. If a single tensor is provided as inputs, a single tensor is returned. If a tuple is provided for inputs, a tuple of correspondingly sized tensors is returned.

  • delta (Tensor, returned if return_convergence_delta=True):

    This is computed using the property that the total sum of model(inputs) - model(baselines) should equal the total sum of the attributions computed based on DeepLift's Rescale rule; delta is the sum of the attributions minus (model(inputs) - model(baselines)) for the selected target. Delta is calculated per example, meaning that the number of elements in the returned delta tensor equals the number of examples in inputs. Note that this property is guaranteed only when the default attribution logic is used, that is, when custom_attribution_func=None; otherwise it is not guaranteed and depends on the specifics of the custom_attribution_func.

Return type:

attributions or 2-element tuple of attributions, delta

Examples:

>>> import torch
>>> from captum.attr import DeepLift
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
>>> # and returns an Nx10 tensor of class probabilities.
>>> net = ImageClassifier()
>>> dl = DeepLift(net)
>>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
>>> # Computes DeepLift attribution scores for class 3.
>>> attribution = dl.attribute(input, target=3)
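
The example above leaves ImageClassifier undefined. To run it end to end, a minimal stand-in can be used; the architecture below is a hypothetical sketch that only matches the stated shapes (Nx3x32x32 images in, Nx10 class probabilities out) and is not a model shipped with Captum.

```python
import torch
import torch.nn as nn
from captum.attr import DeepLift


class ImageClassifier(nn.Module):
    # Hypothetical stand-in: Nx3x32x32 images -> Nx10 class probabilities.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # keeps 32x32
        self.relu = nn.ReLU()
        self.fc = nn.Linear(8 * 32 * 32, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.conv(x))
        return self.softmax(self.fc(x.flatten(start_dim=1)))


net = ImageClassifier().eval()
dl = DeepLift(net)
input = torch.randn(2, 3, 32, 32, requires_grad=True)
# DeepLift attribution scores for class 3, one score per input element.
attribution = dl.attribute(input, target=3)
print(attribution.shape)  # torch.Size([2, 3, 32, 32])
```
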

attribute_future()

This method is not implemented for DeepLift.

Return type:

Callable

has_convergence_delta()

This method informs the user whether the attribution algorithm provides a convergence delta (aka an approximation error) or not. The convergence delta may serve as a proxy for the correctness of the attribution algorithm's approximation. If a derived attribution class provides a compute_convergence_delta method, it should override both the compute_convergence_delta and has_convergence_delta methods.

Returns:

Returns whether the attribution algorithm provides a convergence delta (aka approximation error) or not.

Return type:

bool