#!/usr/bin/env python3
# pyre-strict
from typing import Callable, cast, Generator, Optional, Tuple, Union
import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import construct_feature_mask, Lime
from captum.attr._utils.common import _format_input_baseline
from captum.log import log_usage
from torch import Tensor
from torch.distributions.categorical import Categorical
[docs]
class KernelShap(Lime):
r"""
Kernel SHAP is a method that uses the LIME framework to compute
Shapley Values. Setting the loss function, weighting kernel and
regularization terms appropriately in the LIME framework allows
theoretically obtaining Shapley Values more efficiently than
directly computing Shapley Values.
More information regarding this method and proof of equivalence
can be found in the original paper here:
https://arxiv.org/abs/1705.07874
"""
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
r"""
Args:
forward_func (Callable): The forward function of the model or
any modification of it.
"""
Lime.__init__(
self,
forward_func,
interpretable_model=SkLearnLinearRegression(),
similarity_func=self.kernel_shap_similarity_kernel,
perturb_func=self.kernel_shap_perturb_generator,
)
self.inf_weight = 1000000.0
[docs]
@log_usage()
def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Optional[object] = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
r"""
This method attributes the output of the model with given target index
(in case it is provided, otherwise it assumes that output is a
scalar) to the inputs of the model using the approach described above,
training an interpretable model based on KernelSHAP and returning a
representation of the interpretable model.
It is recommended to only provide a single example as input (tensors
with first dimension or batch size = 1). This is because LIME / KernelShap
is generally used for sample-based interpretability, training a separate
interpretable model to explain a model's prediction on each individual example.
A batch of inputs can also be provided as inputs, similar to
other perturbation-based attribution methods. In this case, if forward_fn
returns a scalar per example, attributions will be computed for each
example independently, with a separate interpretable model trained for each
example. Note that provided similarity and perturbation functions will be
provided each example separately (first dimension = 1) in this case.
If forward_fn returns a scalar per batch (e.g. loss), attributions will
still be computed using a single interpretable model for the full batch.
In this case, similarity and perturbation functions will be provided the
same original input containing the full batch.
The number of interpretable features is determined from the provided
feature mask, or if none is provided, from the default feature mask,
which considers each scalar input as a separate feature. It is
generally recommended to provide a feature mask which groups features
into a small number of interpretable features / components (e.g.
superpixels in images).
Args:
inputs (Tensor or tuple[Tensor, ...]): Input for which KernelShap
is computed. If forward_func takes a single
tensor as input, a single input tensor should be provided.
If forward_func takes multiple tensors as input, a tuple
of the input tensors should be provided. It is assumed
that for all given input tensors, dimension 0 corresponds
to the number of examples, and if multiple input tensors
are provided, the examples must be aligned appropriately.
baselines (scalar, Tensor, tuple of scalar, or Tensor, optional):
Baselines define the reference value which replaces each
feature when the corresponding interpretable feature
is set to 0.
Baselines can be provided as:
- a single tensor, if inputs is a single tensor, with
exactly the same dimensions as inputs or the first
dimension is one and the remaining dimensions match
with inputs.
- a single scalar, if inputs is a single tensor, which will
be broadcasted for each input value in input tensor.
- a tuple of tensors or scalars, the baseline corresponding
to each tensor in the inputs' tuple can be:
- either a tensor with matching dimensions to
corresponding tensor in the inputs' tuple
or the first dimension is one and the remaining
dimensions match with the corresponding
input tensor.
- or a scalar, corresponding to a tensor in the
inputs' tuple. This scalar value is broadcasted
for corresponding input tensor.
In the cases when `baselines` is not provided, we internally
use zero scalar corresponding to each input tensor.
Default: None
target (int, tuple, Tensor, or list, optional): Output indices for
which surrogate model is trained
(for classification cases,
this is usually the target class).
If the network returns a scalar value per example,
no target index is necessary.
For general 2D outputs, targets can be either:
- a single integer or a tensor containing a single
integer, which is applied to all input examples
- a list of integers or a 1D tensor, with length matching
the number of examples in inputs (dim 0). Each integer
is applied as the target for the corresponding example.
For outputs with > 2 dimensions, targets can be either:
- A single tuple, which contains #output_dims - 1
elements. This target index is applied to all examples.
- A list of tuples with length equal to the number of
examples in inputs (dim 0), and each tuple containing
#output_dims - 1 elements. Each tuple is applied as the
target for the corresponding example.
Default: None
additional_forward_args (Any, optional): If the forward function
requires additional arguments other than the inputs for
which attributions should not be computed, this argument
can be provided. It must be either a single additional
argument of a Tensor or arbitrary (non-tuple) type or a
tuple containing multiple additional arguments including
tensors or any arbitrary python types. These arguments
are provided to forward_func in order following the
arguments in inputs.
For a tensor, the first dimension of the tensor must
correspond to the number of examples. It will be
repeated for each of `n_steps` along the integrated
path. For all other types, the given argument is used
for all forward evaluations.
Note that attributions are not computed with respect
to these arguments.
Default: None
feature_mask (Tensor or tuple[Tensor, ...], optional):
feature_mask defines a mask for the input, grouping
features which correspond to the same
interpretable feature. feature_mask
should contain the same number of tensors as inputs.
Each tensor should
be the same size as the corresponding input or
broadcastable to match the input tensor. Values across
all tensors should be integers in the range 0 to
num_interp_features - 1, and indices corresponding to the
same feature should have the same value.
Note that features are grouped across tensors
(unlike feature ablation and occlusion), so
if the same index is used in different tensors, those
features are still grouped and added simultaneously.
If None, then a feature mask is constructed which assigns
each scalar within a tensor as a separate feature.
Default: None
n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
Default: `50` if `n_samples` is not provided.
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
perturbations_per_eval * #examples samples.
For DataParallel models, each batch is split among the
available devices, so evaluations on each available
device contain at most
(perturbations_per_eval * #examples) / num_devices
samples.
If the forward function returns a single scalar per batch,
perturbations_per_eval must be set to 1.
Default: 1
return_input_shape (bool, optional): Determines whether the returned
tensor(s) only contain the coefficients for each interp-
retable feature from the trained surrogate model, or
whether the returned attributions match the input shape.
When return_input_shape is True, the return type of attribute
matches the input shape, with each element containing the
coefficient of the corresponding interpretable feature.
All elements with the same value in the feature mask
will contain the same coefficient in the returned
attributions. If return_input_shape is False, a 1D
tensor is returned, containing only the coefficients
of the trained interpretable model, with length
num_interp_features.
show_progress (bool, optional): Displays the progress of computation.
It will try to use tqdm if available for advanced features
(e.g. time estimation). Otherwise, it will fallback to
a simple output of progress.
Default: False
Returns:
*Tensor* or *tuple[Tensor, ...]* of **attributions**:
- **attributions** (*Tensor* or *tuple[Tensor, ...]*):
The attributions with respect to each input feature.
If return_input_shape = True, attributions will be
the same size as the provided inputs, with each value
providing the coefficient of the corresponding
interpretale feature.
If return_input_shape is False, a 1D
tensor is returned, containing only the coefficients
of the trained interpreatable models, with length
num_interp_features.
Examples::
>>> # SimpleClassifier takes a single input tensor of size Nx4x4,
>>> # and returns an Nx3 tensor of class probabilities.
>>> net = SimpleClassifier()
>>> # Generating random input with size 1 x 4 x 4
>>> input = torch.randn(1, 4, 4)
>>> # Defining KernelShap interpreter
>>> ks = KernelShap(net)
>>> # Computes attribution, with each of the 4 x 4 = 16
>>> # features as a separate interpretable feature
>>> attr = ks.attribute(input, target=1, n_samples=200)
>>> # Alternatively, we can group each 2x2 square of the inputs
>>> # as one 'interpretable' feature and perturb them together.
>>> # This can be done by creating a feature mask as follows, which
>>> # defines the feature groups, e.g.:
>>> # +---+---+---+---+
>>> # | 0 | 0 | 1 | 1 |
>>> # +---+---+---+---+
>>> # | 0 | 0 | 1 | 1 |
>>> # +---+---+---+---+
>>> # | 2 | 2 | 3 | 3 |
>>> # +---+---+---+---+
>>> # | 2 | 2 | 3 | 3 |
>>> # +---+---+---+---+
>>> # With this mask, all inputs with the same value are set to their
>>> # baseline value, when the corresponding binary interpretable
>>> # feature is set to 0.
>>> # The attributions can be calculated as follows:
>>> # feature mask has dimensions 1 x 4 x 4
>>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
>>> [2,2,3,3],[2,2,3,3]]])
>>> # Computes KernelSHAP attributions with feature mask.
>>> attr = ks.attribute(input, target=1, feature_mask=feature_mask)
"""
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
feature_mask, num_interp_features = construct_feature_mask(
feature_mask, formatted_inputs
)
num_features_list = torch.arange(num_interp_features, dtype=torch.float)
denom = num_features_list * (num_interp_features - num_features_list)
probs = torch.tensor((num_interp_features - 1)) / denom
probs[0] = 0.0
return self._attribute_kwargs(
inputs=inputs,
baselines=baselines,
target=target,
additional_forward_args=additional_forward_args,
feature_mask=feature_mask,
n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
return_input_shape=return_input_shape,
num_select_distribution=Categorical(probs),
show_progress=show_progress,
)
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
[docs]
def attribute_future(self) -> Callable:
r"""
This method is not implemented for KernelShap.
"""
raise NotImplementedError("attribute_future is not implemented for KernelShap")
def kernel_shap_similarity_kernel(
self,
_,
__,
interpretable_sample: Tensor,
**kwargs: object,
) -> Tensor:
assert (
"num_interp_features" in kwargs
), "Must provide num_interp_features to use default similarity kernel"
num_selected_features = int(interpretable_sample.sum(dim=1).item())
num_features = kwargs["num_interp_features"]
if num_selected_features == 0 or num_selected_features == num_features:
# weight should be theoretically infinite when
# num_selected_features = 0 or num_features
# enforcing that trained linear model must satisfy
# end-point criteria. In practice, it is sufficient to
# make this weight substantially larger so setting this
# weight to 1000000 (all other weights are 1).
similarities = self.inf_weight
else:
similarities = 1.0
return torch.tensor([similarities])
[docs]
def kernel_shap_perturb_generator(
self,
original_inp: Union[Tensor, Tuple[Tensor, ...]],
**kwargs: object,
) -> Generator[Tensor, None, None]:
r"""
Perturbations are sampled by the following process:
- Choose k (number of selected features), based on the distribution
p(k) = (M - 1) / (k * (M - k))
where M is the total number of features in the interpretable space
- Randomly select a binary vector with k ones, each sample is equally
likely. This is done by generating a random vector of normal
values and thresholding based on the top k elements.
Since there are M choose k vectors with k ones, this weighted sampling
is equivalent to applying the Shapley kernel for the sample weight,
defined as:
k(M, k) = (M - 1) / (k * (M - k) * (M choose k))
"""
assert (
"num_select_distribution" in kwargs and "num_interp_features" in kwargs
), (
"num_select_distribution and num_interp_features are necessary"
" to use kernel_shap_perturb_func"
)
if isinstance(original_inp, Tensor):
device = original_inp.device
else:
device = original_inp[0].device
num_features = cast(int, kwargs["num_interp_features"])
yield torch.ones(1, num_features, device=device, dtype=torch.long)
yield torch.zeros(1, num_features, device=device, dtype=torch.long)
while True:
num_selected_features = cast(
Categorical, kwargs["num_select_distribution"]
).sample()
rand_vals = torch.randn(1, num_features)
threshold = torch.kthvalue(
rand_vals, num_features - num_selected_features
).values.item()
yield (rand_vals > threshold).to(device=device).long()