Source code for captum.influence._core.similarity_influence

#!/usr/bin/env python3

# pyre-strict

import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import captum._utils.common as common
import torch
from captum._utils.av import AV
from captum.attr import LayerActivation
from captum.influence._core.influence import DataInfluence
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset

r"""
Additional helper functions to calculate similarity metrics.
"""


def euclidean_distance(test: Tensor, train: Tensor) -> Tensor:
    r"""
    Calculates the pairwise euclidean distance for batches of feature vectors.

    Every dimension after the first is flattened into a single feature
    dimension before the distances are computed.

    Args:
        test (Tensor): Tensor of shape (batch_size_1, *).
        train (Tensor): Tensor of shape (batch_size_2, *). Must flatten to the
            same number of features per example as `test`.

    Returns:
        Tensor: Pairwise euclidean distances of shape
        (batch_size_1, batch_size_2).
    """
    # `reshape` instead of `view`: `view` raises for inputs whose trailing
    # dimensions are not contiguous in memory (e.g. permuted activations),
    # while `reshape` only copies when it has to.
    test_flat = test.reshape(test.shape[0], -1)
    train_flat = train.reshape(train.shape[0], -1)

    # A leading batch dimension of size 1 is added for `torch.cdist` and
    # removed again from the result.
    similarity = torch.cdist(
        test_flat.unsqueeze(0),
        train_flat.unsqueeze(0),
    ).squeeze(0)
    return similarity


def cosine_similarity(test: Tensor, train: Tensor, replace_nan: int = 0) -> Tensor:
    r"""
    Calculates the pairwise cosine similarity for batches of feature vectors.

    Every dimension after the first is flattened into a single feature
    dimension. Each example is then divided by its L2 norm and the pairwise
    dot products are returned.

    Args:
        test (Tensor): Tensor of shape (batch_size_1, *).
        train (Tensor): Tensor of shape (batch_size_2, *). Must flatten to the
            same number of features per example as `test`.
        replace_nan (int): Value written into every feature of an example
            whose L2 norm is zero, in place of the NaN that dividing by the
            zero norm would produce. Default: 0

    Returns:
        Tensor: Pairwise cosine similarities of shape
        (batch_size_1, batch_size_2).
    """
    # `reshape` instead of `view`: `view` raises for inputs whose trailing
    # dimensions are not contiguous in memory.
    test = test.reshape(test.shape[0], -1)
    train = train.reshape(train.shape[0], -1)

    test_norm = torch.linalg.norm(test, ord=2, dim=1, keepdim=True)
    train_norm = torch.linalg.norm(train, ord=2, dim=1, keepdim=True)

    # Build the fill value on the same device and with the same dtype as the
    # tensor it is combined with, so `torch.where` never mixes devices/dtypes.
    test_fill = test.new_full((1,), replace_nan)
    train_fill = train.new_full((1,), replace_nan)

    test = torch.where(test_norm != 0.0, test / test_norm, test_fill)
    train = torch.where(train_norm != 0.0, train / train_norm, train_fill).T

    similarity = torch.mm(test, train)
    return similarity


r"""
Implements abstract DataInfluence class and provides implementation details for
similarity metric-based influence computation. Similarity metrics can be used to compare
intermediate or final activation vectors of a model for different sets of input. Then,
these can be used to draw conclusions about influential instances.

Some standard similarity metrics such as dot product similarity or euclidean distance
are provided, but the user can provide any custom similarity metric as well.
"""


class SimilarityInfluence(DataInfluence):
    r"""
    `DataInfluence` implementation that ranks the examples of a source dataset
    by the similarity of their layer activations to the activations of the
    given inputs.
    """

    def __init__(
        self,
        module: Module,
        layers: Union[str, List[str]],
        influence_src_dataset: Dataset,
        activation_dir: str,
        model_id: str = "",
        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
        similarity_metric: Callable = cosine_similarity,
        similarity_direction: str = "max",
        batch_size: int = 1,
        **kwargs: Any,
    ) -> None:
        r"""
        Args:
            module (torch.nn.Module): An instance of pytorch model. This model
                    should define all of its layers as attributes of the model.
            layers (str or list[str]): The fully qualified layer(s) for which the
                    activation vectors are computed.
            influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that
                    is used to create a PyTorch Dataloader to iterate over the
                    dataset and its labels. This is the dataset in which we
                    search for influential instances. In most cases this is the
                    training dataset.
            activation_dir (str): The directory of the path to store
                    and retrieve activation computations. Best practice would be
                    to use an absolute path.
            model_id (str): The name/version of the model for which layer
                    activations are being computed. Activations will be stored
                    and loaded under the subdirectory with this name if provided.
            similarity_metric (Callable): This is a callable function that computes
                    a similarity metric between two representations. For example,
                    the representations pair could be from the training and test
                    sets.

                    This function must adhere to certain standards. The inputs
                    should be torch Tensors with shape
                    (batch_size_i/j, feature dimensions). The output Tensor should
                    have shape (batch_size_i, batch_size_j) with scalar values
                    corresponding to the similarity metric used for each pairwise
                    combination from the two batches.

                    For example, suppose we use `batch_size_1 = 16` for iterating
                    through `influence_src_dataset`, and for the `inputs` argument
                    we pass in a Tensor with 3 examples, i.e. batch_size_2 = 3.
                    Also, suppose that our inputs and intermediate activations
                    throughout the model will have dimension (N, C, H, W). Then,
                    the feature dimensions should be flattened within this
                    function. For example::

                        >>> av_test.shape
                        torch.Size([3, N, C, H, W])
                        >>> av_src.shape
                        torch.Size([16, N, C, H, W])
                        >>> av_test = av_test.view(av_test.shape[0], -1)
                        >>> av_test.shape
                        torch.Size([3, N x C x H x W])

                    and similarly for av_src. The similarity_metric should then
                    use these flattened tensors to return the pairwise similarity
                    matrix. For example, `similarity_metric(av_test, av_src)`
                    should return a tensor of shape (3, 16).
            similarity_direction (str): Either "max" or "min". With "max" the
                    source examples with the largest metric values are returned
                    (appropriate for similarities); with "min" the ones with the
                    smallest values are returned (appropriate for distances).
                    Default: "max"
            batch_size (int): Batch size for iterating through
                    `influence_src_dataset`.
            **kwargs: Additional key-value arguments that are necessary for
                    specific implementation of `DataInfluence` abstract class.
                    When `similarity_metric` is the default `cosine_similarity`,
                    the key `replace_nan` overrides the value substituted for
                    zero-norm activations (default: -2 for "max", 2 for "min").

        Raises:
            ValueError: If `similarity_direction` is neither "max" nor "min".
        """
        self.module = module
        # pyre-fixme[4]: Attribute must be annotated.
        self.layers = [layers] if isinstance(layers, str) else layers
        self.influence_src_dataset = influence_src_dataset
        self.activation_dir = activation_dir
        self.model_id = model_id
        self.batch_size = batch_size

        if similarity_direction == "max" or similarity_direction == "min":
            # pyre-fixme[4]: Attribute must be annotated.
            self.similarity_direction = similarity_direction
        else:
            raise ValueError(
                f"{similarity_direction} is not a valid value. "
                "Must be either 'max' or 'min'"
            )

        if similarity_metric is cosine_similarity:
            # `replace_nan` is only set for the built-in cosine metric;
            # `influence` uses its presence to decide whether to look for
            # zero-vector activations.
            if "replace_nan" in kwargs:
                # pyre-fixme[4]: Attribute must be annotated.
                self.replace_nan = kwargs["replace_nan"]
            else:
                # Default sentinel lies outside [-1, 1] on the side that ranks
                # last for the chosen direction.
                self.replace_nan = -2 if self.similarity_direction == "max" else 2
            similarity_metric = partial(cosine_similarity, replace_nan=self.replace_nan)

        self.similarity_metric = similarity_metric

        # pyre-fixme[4]: Attribute must be annotated.
        self.influence_src_dataloader = DataLoader(
            influence_src_dataset, batch_size, shuffle=False
        )

    def influence(  # type: ignore[override]
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        top_k: int = 1,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        additional_forward_args: Optional[Any] = None,
        load_src_from_disk: bool = True,
        **kwargs: Any,
        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
        #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
    ) -> Dict:
        r"""
        Args:
            inputs (Tensor or tuple[Tensor, ...]): Batch of examples for which
                    influential instances are computed. They are passed to the
                    forward_func. The first dimension in `inputs` tensor or tuple
                    of tensors corresponds to the batch size. A tuple of tensors
                    is only passed in if this is the input form that `module`
                    accepts.
            top_k (int): The number of top-matching activations to return
            additional_forward_args (Any, optional): Additional arguments that
                    will be passed to forward_func after inputs.
            load_src_from_disk (bool): Loads activations for
                    `influence_src_dataset` where possible. Setting to False
                    would force regeneration of activations.
            **kwargs: Additional key-value arguments that are necessary for
                    specific implementation of `DataInfluence` abstract class.
                    They are not read by this implementation; in particular
                    `load_input_from_disk` and `inputs_id` have no effect here,
                    and activations for `inputs` are recomputed on every call.

        Returns:

            influences (dict): Returns the influential instances retrieved from
            `influence_src_dataset` for each test example represented through a
            tensor or a tuple of tensor in `inputs`. Returned influential
            examples are represented as dict, with keys corresponding to
            the layer names passed in `layers`. Each value in the dict is a
            tuple containing the indices and values for the top k similarities
            from `influence_src_dataset` by the chosen metric. The first value
            in the tuple corresponds to the indices corresponding to the top k
            most similar examples, and the second value is the similarity score.
            The batch dimension corresponds to the batch dimension of `inputs`.
            If inputs.shape[0] == 5, then dict[`layer_name`][0].shape[0] == 5.
            These tensors will be of shape (inputs.shape[0], top_k).

            When the default cosine similarity is used and some similarity
            entries equal the `replace_nan` sentinel, an extra key
            `zero_acts-<layer_name>` holds a tensor of
            [inputs_idx, src_dataset_idx] pairs for those entries and a
            `RuntimeWarning` is issued.
        """
        inputs_batch_size = (
            inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
        )

        influences: Dict[str, Any] = {}

        layer_AVDatasets = AV.generate_dataset_activations(
            self.activation_dir,
            self.module,
            self.model_id,
            self.layers,
            DataLoader(self.influence_src_dataset, self.batch_size, shuffle=False),
            identifier="src",
            load_from_disk=load_src_from_disk,
            return_activations=True,
        )

        assert layer_AVDatasets is not None and not isinstance(
            layer_AVDatasets, AV.AVDataset
        )

        layer_modules = [
            common._get_module_from_name(self.module, layer) for layer in self.layers
        ]
        test_activations = LayerActivation(self.module, layer_modules).attribute(
            inputs, additional_forward_args
        )

        # True -> keep the largest values, False -> keep the smallest ones.
        minmax = self.similarity_direction == "max"

        # av_inputs shape: (inputs_batch_size, *) e.g. (inputs_batch_size, N, C, H, W)
        # av_src shape: (self.batch_size, *) e.g. (self.batch_size, N, C, H, W)
        test_activations = (
            test_activations if len(self.layers) > 1 else [test_activations]
        )
        for i, (layer, layer_AVDataset) in enumerate(
            zip(self.layers, layer_AVDatasets)
        ):
            topk_val, topk_idx = torch.Tensor(), torch.Tensor().long()
            zero_acts = torch.Tensor().long()

            av_inputs = test_activations[i]
            src_loader = DataLoader(layer_AVDataset)
            for j, av_src in enumerate(src_loader):
                # Drop the size-1 batch dimension added by `src_loader`.
                av_src = av_src.squeeze(0)
                src_batch_size = av_src.shape[0]

                similarity = self.similarity_metric(av_inputs, av_src)
                msg = (
                    "Output of custom similarity does not meet required dimensions. "
                    f"Your output has shape {similarity.shape}.\nPlease ensure the "
                    "output shape matches (inputs_batch_size, src_dataset_batch_size), "
                    f"which should be {(inputs_batch_size, src_batch_size)}."
                )
                assert similarity.shape == (inputs_batch_size, src_batch_size), msg
                if hasattr(self, "replace_nan"):
                    idx = (similarity == self.replace_nan).nonzero()
                    # Turn the batch-local column index into an index into the
                    # whole source dataset, as the warning below promises.
                    idx[:, 1] += int(j * self.batch_size)
                    zero_acts = torch.cat((zero_acts, idx))

                # TODO: For models that can have tuples as activations, we should
                # allow similarity metrics to accept tuples, support topk selection.

                # Bound k by the actual width of this batch: the last batch can
                # hold fewer than `self.batch_size` examples, and `torch.topk`
                # raises when k exceeds the size of the selected dimension.
                topk_batch = min(top_k, src_batch_size)
                values, indices = torch.topk(
                    similarity, topk_batch, dim=1, largest=minmax
                )
                # Shift batch-local indices to positions in the source dataset.
                indices += int(j * self.batch_size)

                topk_val = torch.cat((topk_val, values), dim=1)
                topk_idx = torch.cat((topk_idx, indices), dim=1)

                # can modify how often to sort for efficiency? minor
                sort_idx = torch.argsort(topk_val, dim=1, descending=minmax)
                topk_val = torch.gather(topk_val, 1, sort_idx[:, :top_k])
                topk_idx = torch.gather(topk_idx, 1, sort_idx[:, :top_k])

            influences[layer] = (topk_idx, topk_val)

            if torch.numel(zero_acts) != 0:
                zero_warning = (
                    f"Layer {layer} has zero-vector activations for some inputs. This "
                    "may cause undefined behavior for cosine similarity. The indices "
                    "for the offending inputs will be included under the key "
                    f"'zero_acts-{layer}' in the output dictionary. Indices are "
                    "returned as a tensor with [inputs_idx, src_dataset_idx] pairs "
                    "which may have corrupted similarity scores."
                )
                warnings.warn(
                    zero_warning,
                    RuntimeWarning,
                    stacklevel=1,
                )
                key = "-".join(["zero_acts", layer])
                influences[key] = zero_acts

        return influences