Source code for captum.attr._core.input_x_gradient

#!/usr/bin/env python3
from typing import Any, Callable

from captum.log import log_usage

from ..._utils.common import _format_input, _format_output, _is_tuple
from ..._utils.gradient import apply_gradient_requirements, undo_gradient_requirements
from ..._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from .._utils.attribution import GradientAttribution


class InputXGradient(GradientAttribution):
    r"""
    A baseline approach for computing the attribution. It multiplies input with
    the gradient with respect to input.
    https://arxiv.org/abs/1605.01713
    """

    def __init__(self, forward_func: Callable) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it
        """
        GradientAttribution.__init__(self, forward_func)

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Args:

            inputs (tensor or tuple of tensors): Input for which
                        attributions are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples (aka batch size), and if
                        multiple input tensors are provided, the examples must
                        be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None

        Returns:
            *tensor* or tuple of *tensors* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors*):
                        The input x gradient with
                        respect to each input feature. Attributions will always be
                        the same size as the provided inputs, with each value
                        providing the attribution of the corresponding input index.
                        If a single tensor is provided as inputs, a single tensor is
                        returned. If a tuple is provided for inputs, a tuple of
                        corresponding sized tensors is returned.

        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities.
            >>> net = ImageClassifier()
            >>> # Generating random input with size 2x3x32x32
            >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
            >>> # Defining InputXGradient interpreter
            >>> input_x_gradient = InputXGradient(net)
            >>> # Computes inputXgradient for class 4.
            >>> attribution = input_x_gradient.attribute(input, target=4)
        """
        # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = _is_tuple(inputs)

        inputs = _format_input(inputs)
        gradient_mask = apply_gradient_requirements(inputs)

        try:
            gradients = self.gradient_func(
                self.forward_func, inputs, target, additional_forward_args
            )

            # Element-wise product of every input with its own gradient.
            attributions = tuple(
                inp * gradient for inp, gradient in zip(inputs, gradients)
            )
        finally:
            # Run the undo step even when the gradient computation raises, so
            # that whatever apply_gradient_requirements changed on the caller's
            # tensors is reverted on the error path as well.
            undo_gradient_requirements(inputs, gradient_mask)

        return _format_output(is_inputs_tuple, attributions)

    @property
    def multiplies_by_inputs(self) -> bool:
        r"""Always ``True``: the returned attributions are gradients multiplied
        by the inputs."""
        return True