Influential Examples

DataInfluence

class captum.influence.DataInfluence[source]

An abstract class to define model data influence skeleton.

classmethod get_name()[source]

Create readable class name. Due to the nature of the names of TracInCPBase subclasses, simply returns the class name. For example, for a class called TracInCP, we return the string TracInCP.

Returns:

a readable class name

Return type:

name (str)

abstract influence(inputs=None, **kwargs)[source]
Parameters:
  • inputs (Any) – Batch of examples for which influential instances are computed. They are passed to the forward_func. If inputs is a tensor or tuple of tensors, the first dimension of a tensor corresponds to the batch dimension.

  • **kwargs (Any) – Additional key-value arguments that are necessary for specific implementation of DataInfluence abstract class.

Returns:

We do not add restrictions on the return type for now, though this may change in the future.

Return type:

influences (Any)

SimilarityInfluence

class captum.influence.SimilarityInfluence(module, layers, influence_src_dataset, activation_dir, model_id='', similarity_metric=cosine_similarity, similarity_direction='max', batch_size=1, **kwargs)[source]
Parameters:
  • module (torch.nn.Module) – An instance of pytorch model. This model should define all of its layers as attributes of the model.

  • layers (str or list[str]) – The fully qualified layer(s) for which the activation vectors are computed.

  • influence_src_dataset (torch.utils.data.Dataset) – PyTorch Dataset that is used to create a PyTorch Dataloader to iterate over the dataset and its labels. This is the dataset in which we will be seeking influential instances. In most cases this is the training dataset.

  • activation_dir (str) – The path to the directory used to store and retrieve activation computations. Best practice would be to use an absolute path.

  • model_id (str) – The name/version of the model for which layer activations are being computed. Activations will be stored and loaded under the subdirectory with this name if provided.

  • similarity_metric (Callable) –

    This is a callable function that computes a similarity metric between two representations. For example, the representations pair could be from the training and test sets.

    This function must adhere to certain standards. The inputs should be torch Tensors with shape (batch_size_i/j, feature dimensions). The output Tensor should have shape (batch_size_i, batch_size_j) with scalar values corresponding to the similarity metric used for each pairwise combination from the two batches.

    For example, suppose we use batch_size_1 = 16 for iterating through influence_src_dataset, and for the inputs argument we pass in a Tensor with 3 examples, i.e. batch_size_2 = 3. Also, suppose that our inputs and intermediate activations throughout the model will have dimension (N, C, H, W). Then, the feature dimensions should be flattened within this function. For example:

    >>> av_test.shape
    torch.Size([3, N, C, H, W])
    >>> av_src.shape
    torch.Size([16, N, C, H, W])
    >>> av_test = av_test.view(av_test.shape[0], -1)
    >>> av_test.shape
    torch.Size([3, N x C x H x W])
    

    and similarly for av_src. The similarity_metric should then use these flattened tensors to return the pairwise similarity matrix. For example, similarity_metric(av_test, av_src) should return a tensor of shape (3, 16). A minimal sketch of such a metric appears after this parameter list.

  • batch_size (int) – Batch size for iterating through influence_src_dataset.

  • **kwargs (Any) – Additional key-value arguments that are necessary for specific implementation of DataInfluence abstract class.
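
As a concrete illustration, below is a minimal sketch of a similarity_metric satisfying the contract above, together with a SimilarityInfluence construction. The TinyNet model, the random TensorDataset, the activation directory, and the layer name are all hypothetical placeholders chosen for the example, not part of the Captum API.

    import torch
    import torch.nn as nn
    from torch.utils.data import TensorDataset
    from captum.influence import SimilarityInfluence

    # Hypothetical toy model; layers are defined as attributes, as required.
    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 4, kernel_size=3)
            self.flatten = nn.Flatten()
            self.fc = nn.Linear(4 * 6 * 6, 2)

        def forward(self, x):
            return self.fc(self.flatten(self.conv(x)))

    def dot_product_similarity(av_test, av_src):
        # Flatten feature dimensions so each example is a 1D vector, then return
        # the (batch_size_test, batch_size_src) matrix of pairwise dot products.
        av_test = av_test.view(av_test.shape[0], -1)
        av_src = av_src.view(av_src.shape[0], -1)
        return torch.mm(av_test, av_src.t())

    # Hypothetical training data: 32 RGB images of size 8x8 with binary labels.
    train_dataset = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 2, (32,)))

    sim_influence = SimilarityInfluence(
        module=TinyNet(),
        layers=["conv"],                      # fully qualified layer name(s)
        influence_src_dataset=train_dataset,
        activation_dir="/tmp/activations",    # hypothetical path
        model_id="tiny_net_v1",
        similarity_metric=dot_product_similarity,
        batch_size=16,
    )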

influence(inputs, top_k=1, additional_forward_args=None, load_src_from_disk=True, **kwargs)[source]
Parameters:
  • inputs (Tensor or tuple[Tensor, ...]) – Batch of examples for which influential instances are computed. They are passed to the forward_func. The first dimension in inputs tensor or tuple of tensors corresponds to the batch size. A tuple of tensors is only passed in if this is the input form that module accepts.

  • top_k (int) – The number of top-matching activations to return.

  • additional_forward_args (Any, optional) – Additional arguments that will be passed to forward_func after inputs.

  • load_src_from_disk (bool) – Loads activations for influence_src_dataset where possible. Setting to False would force regeneration of activations.

  • load_input_from_disk (bool) – Regenerates activations for inputs by default and removes previous inputs activations that are flagged with inputs_id. Setting to True will load prior matching inputs activations. Note that this could lead to unexpected behavior if inputs_id is not configured properly and activations are loaded for a different, prior inputs.

  • inputs_id (str) – Used to identify inputs for loading activations.

  • **kwargs (Any) – Additional key-value arguments that are necessary for specific implementation of DataInfluence abstract class.

Returns:

Returns the influential instances retrieved from influence_src_dataset for each test example represented through a tensor or a tuple of tensors in inputs. Returned influential examples are represented as a dict, with keys corresponding to the layer names passed in layers. Each value in the dict is a tuple containing the indices and values for the top k similarities from influence_src_dataset by the chosen metric. The first element of the tuple contains the indices of the top k most similar examples, and the second element contains the similarity scores. The batch dimension corresponds to the batch dimension of inputs. If inputs.shape[0] == 5, then dict[layer_name][0].shape[0] == 5. These tensors will be of shape (inputs.shape[0], top_k).

Return type:

influences (dict)
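
Continuing the sketch above, and assuming the return contract just described, a call to influence might look as follows; the shapes follow from top_k=3 and a hypothetical test batch of 5 examples.

    # Hypothetical test batch of 5 examples.
    test_batch = torch.randn(5, 3, 8, 8)
    influences = sim_influence.influence(test_batch, top_k=3)

    # Keys are the layer names passed in `layers`; each value is (indices, similarities).
    indices, similarities = influences["conv"]
    print(indices.shape, similarities.shape)   # both torch.Size([5, 3])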

TracInCPBase

class captum.influence.TracInCPBase(model, train_dataset, checkpoints, checkpoints_load_func=_load_flexible_state_dict, loss_fn=None, batch_size=1, test_loss_fn=None)[source]

To implement the influence method, classes inheriting from TracInCPBase will separately implement the private _self_influence, _get_k_most_influential, and _influence methods. The public influence method is a wrapper for these private methods.

Parameters:
  • model (torch.nn.Module) – An instance of pytorch model. This model should define all of its layers as attributes of the model.

  • train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader) – In the influence method, we compute the influence score of training examples on examples in a test batch. This argument represents the training dataset containing those training examples. In order to compute those influence scores, we will create a Pytorch DataLoader yielding batches of training examples that is then used for processing. If this argument is already a Pytorch Dataloader, that DataLoader can be directly used for processing. If it is instead a Pytorch Dataset, we will create a DataLoader using it, with batch size specified by batch_size. For efficiency purposes, the batch size of the DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if train_dataset is a Dataset, batch_size should be large. If train_dataset was already a DataLoader to begin with, it should have been constructed to have a large batch size. It is assumed that the Dataloader (regardless of whether it is created from a Pytorch Dataset or not) yields tuples. For a batch that is yielded, of length L, it is assumed that the forward function of model accepts L-1 arguments, and the last element of batch is the label. In other words, model(*batch[:-1]) gives the output of model, and batch[-1] are the labels for the batch.

  • checkpoints (str, list[str], or Iterator) – Either the path to the directory used to store and retrieve model checkpoints, a list of checkpoint filepaths from which to load, or an iterator which returns objects from which to load checkpoints.

  • checkpoints_load_func (Callable, optional) – The function to load a saved checkpoint into a model to update its parameters, and get the learning rate if it is saved. By default uses a utility to load a model saved as a state dict. Default: _load_flexible_state_dict

  • loss_fn (Callable, optional) – The loss function applied to model. Default: None

  • batch_size (int or None, optional) – Batch size of the DataLoader created to iterate through train_dataset, if it is a Dataset. batch_size should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of TracInCPBase will detail the size of the intermediate quantities. batch_size must be an int if train_dataset is a Dataset. If train_dataset is a DataLoader, then batch_size is ignored as an argument. Default: 1

  • test_loss_fn (Callable, optional) – In some cases, one may want to use a separate loss function for training examples, i.e. those in train_dataset, and for test examples, i.e. those represented by the inputs and targets arguments to the influence method. For example, if one wants to calculate the influence score of a training example on a test example’s prediction for a fixed class, test_loss_fn could map from the logits for all classes to the logits for a fixed class. test_loss_fn needs to satisfy the same constraints as loss_fn. If not provided, the loss function for test examples is assumed to be the same as the loss function for training examples, i.e. loss_fn. Default: None
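
For illustration, the following hedged sketch shows one possible test_loss_fn of the kind described above: it scores a test example's prediction for a single fixed class. The (outputs, targets) calling convention and the fixed class index are assumptions made for the example, not requirements stated by the API.

    import torch

    FIXED_CLASS = 2  # hypothetical class of interest

    def fixed_class_test_loss_fn(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Map logits for all classes to the (summed) logit of one fixed class, so
        # gradients taken through this "loss" reflect that class's prediction.
        # Summing over the batch mimics a "reduction"-style loss function.
        return outputs[:, FIXED_CLASS].sum()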

classmethod get_name()[source]

Create readable class name. Due to the nature of the names of TracInCPBase subclasses, simply returns the class name. For example, for a class called TracInCP, we return the string TracInCP.

Returns:

a readable class name

Return type:

name (str)

abstract influence(inputs, k=None, proponents=True, unpack_inputs=True, show_progress=False)[source]

This is the key method of this class, and can be run in 2 different modes, where the mode that is run depends on the arguments passed to this method:

  • influence score mode: This mode is used if k is None. This mode computes the influence score of every example in training dataset train_dataset on every example in the test batch represented by inputs.

  • k-most influential mode: This mode is used if k is not None, and an int. This mode computes the proponents or opponents of every example in the test batch represented by inputs. In particular, for each test example in the test batch, this mode computes its proponents (resp. opponents), which are the indices in the training dataset train_dataset of the training examples with the k highest (resp. lowest) influence scores on the test example. Proponents are computed if proponents is True. Otherwise, opponents are computed. For each test example, this method also returns the actual influence score of each proponent (resp. opponent) on the test example.

Parameters:
  • inputs (tuple) – inputs is the test batch and is a tuple of any, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. This is the same assumption made for each batch yielded by training dataset train_dataset - please see its documentation in __init__ for more details on the assumed structure of a batch.

  • k (int, optional) – If not provided or None, the influence score mode will be run. Otherwise, the k-most influential mode will be run, and k is the number of proponents / opponents to return per example in the test batch. Default: None

  • proponents (bool, optional) – Whether seeking proponents (proponents=True) or opponents (proponents=False), if running in k-most influential mode. Default: True

  • show_progress (bool, optional) – For all modes, computation of results requires “training dataset computations”: computations for each batch in the training dataset train_dataset, which may take a long time. If show_progress is True, the progress of “training dataset computations” will be displayed. In particular, the number of batches for which computations have been performed will be displayed. It will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fall back to a simple output of progress. Default: False

Return type:

Union[Tensor, KMostInfluentialResults]

Returns:

The return value of this method depends on which mode is run.

  • influence score mode: if this mode is run (k is None), returns a 2D tensor influence_scores of shape (input_size, train_dataset_size), where input_size is the number of examples in the test batch, and train_dataset_size is the number of examples in training dataset train_dataset. In other words, influence_scores[i][j] is the influence score of the j-th example in train_dataset on the i-th example in the test batch.

  • k-most influential mode: if this mode is run (k is an int), returns a namedtuple (indices, influence_scores). indices is a 2D tensor of shape (input_size, k), where input_size is the number of examples in the test batch. If computing proponents (resp. opponents), indices[i][j] is the index in training dataset train_dataset of the example with the j-th highest (resp. lowest) influence score (out of the examples in train_dataset) on the i-th example in the test dataset. influence_scores contains the corresponding influence scores. In particular, influence_scores[i][j] is the influence score of example indices[i][j] in train_dataset on example i in the test batch represented by inputs.

abstract self_influence(inputs=None, show_progress=False)[source]

If inputs is not specified, calculates the self influence scores for the training dataset train_dataset. Otherwise, computes self influence scores for the examples in inputs, which is either a single batch or a Pytorch DataLoader that yields batches. Therefore, in this case, the computed self influence scores are not for the examples in training dataset train_dataset. Note that if inputs is a single batch, this will call model on that single batch, and if inputs yields batches, this will call model on each batch that is yielded. Therefore, please ensure that for both cases, the batch(es) that model is called with are not too large, so that there will not be an out-of-memory error.

Parameters:
  • inputs (tuple or DataLoader, optional) – This specifies the dataset for which self influence scores will be computed. Either a single tuple of any, or a DataLoader, where each batch yielded is a tuple of type any. In either case, the tuple represents a single batch, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. This is the same assumption made for each batch yielded by training dataset train_dataset. Please see documentation for the train_dataset argument to TracInCP.__init__ for more details on the assumed structure of a batch. If not provided or None, self influence scores will be computed for training dataset train_dataset, which yields batches satisfying the above assumptions. Default: None.

  • show_progress (bool, optional) – Computation of self influence scores can take a long time if inputs represents many examples. If show_progress is True, the progress of this computation will be displayed. In more detail, this computation will iterate over all checkpoints (provided as the checkpoints initialization argument) in an outer loop, and iterate over all batches that inputs represents in an inner loop. Therefore, the total number of (checkpoint, batch) combinations that need to be iterated over is (# of checkpoints x # of batches that inputs represents). If show_progress is True, the total progress of both the outer iteration over checkpoints and the inner iteration over batches is displayed. It will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fall back to a simple output of progress. Default: False

Returns:

This is a 1D tensor containing the self influence scores of all examples in inputs, regardless of whether it represents a single batch or a DataLoader that yields batches.

Return type:

self_influence_scores (Tensor)

TracInCP

class captum.influence.TracInCP(model, train_dataset, checkpoints, checkpoints_load_func=_load_flexible_state_dict, layers=None, loss_fn=None, batch_size=1, test_loss_fn=None, sample_wise_grads_per_batch=False)[source]
Parameters:
  • model (torch.nn.Module) – An instance of pytorch model. This model should define all of its layers as attributes of the model.

  • train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader) – In the influence method, we compute the influence score of training examples on examples in a test batch. This argument represents the training dataset containing those training examples. In order to compute those influence scores, we will create a Pytorch DataLoader yielding batches of training examples that is then used for processing. If this argument is already a Pytorch Dataloader, that DataLoader can be directly used for processing. If it is instead a Pytorch Dataset, we will create a DataLoader using it, with batch size specified by batch_size. For efficiency purposes, the batch size of the DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if train_dataset is a Dataset, batch_size should be large. If train_dataset was already a DataLoader to begin with, it should have been constructed to have a large batch size. It is assumed that the Dataloader (regardless of whether it is created from a Pytorch Dataset or not) yields tuples. For a batch that is yielded, of length L, it is assumed that the forward function of model accepts L-1 arguments, and the last element of batch is the label. In other words, model(*batch[:-1]) gives the output of model, and batch[-1] are the labels for the batch.

  • checkpoints (str, list[str], or Iterator) – Either the path to the directory used to store and retrieve model checkpoints, a list of checkpoint filepaths from which to load, or an iterator which returns objects from which to load checkpoints.

  • checkpoints_load_func (Callable, optional) – The function to load a saved checkpoint into a model to update its parameters, and get the learning rate if it is saved. By default uses a utility to load a model saved as a state dict. Default: _load_flexible_state_dict

  • layers (list[str] or None, optional) – A list of layer names for which gradients should be computed. If layers is None, gradients will be computed for all layers. Otherwise, they will only be computed for the layers specified in layers. Default: None

  • loss_fn (Callable, optional) – The loss function applied to model. There are two options for the return type of loss_fn. First, loss_fn can be a “per-example” loss function that returns a 1D Tensor of losses for each example in a batch. nn.BCELoss(reduction=”none”) would be a “per-example” loss function. Second, loss_fn can be a “reduction” loss function that reduces the per-example losses in a batch and returns a single scalar Tensor. For this option, the reduction must be the sum or the mean of the per-example losses. For instance, nn.BCELoss(reduction=”sum”) is acceptable. Note for the first option, the sample_wise_grads_per_batch argument must be False, and for the second option, sample_wise_grads_per_batch must be True. Also note that for the second option, if loss_fn has no “reduction” attribute, the implementation assumes that the reduction is the sum of the per-example losses. If this is not the case, i.e. the reduction is the mean, please set the “reduction” attribute of loss_fn to “mean”, i.e. loss_fn.reduction = “mean”. Default: None

  • batch_size (int or None, optional) – Batch size of the DataLoader created to iterate through train_dataset, if it is a Dataset. batch_size should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of TracInCPBase will detail the size of the intermediate quantities. batch_size must be an int if train_dataset is a Dataset. If train_dataset is a DataLoader, then batch_size is ignored as an argument. Default: 1

  • test_loss_fn (Callable, optional) – In some cases, one may want to use a separate loss function for training examples, i.e. those in train_dataset, and for test examples, i.e. those represented by the inputs and targets arguments to the influence method. For example, if one wants to calculate the influence score of a training example on a test example’s prediction for a fixed class, test_loss_fn could map from the logits for all classes to the logits for a fixed class. test_loss_fn needs to satisfy the same constraints as loss_fn. Thus, the same checks that we apply to loss_fn are also applied to test_loss_fn, if the latter is provided. Note that the constraints on both loss_fn and test_loss_fn both depend on sample_wise_grads_per_batch. This means loss_fn and test_loss_fn must either both be “per-example” loss functions, or both be “reduction” loss functions. If not provided, the loss function for test examples is assumed to be the same as the loss function for training examples, i.e. loss_fn. Default: None

  • sample_wise_grads_per_batch (bool, optional) – PyTorch’s native gradient computations w.r.t. model parameters aggregate the results for a batch and do not allow access to sample-wise gradients w.r.t. model parameters. This forces us to iterate over each sample in the batch if we want sample-wise gradients, which is computationally inefficient. We offer an implementation of batch-wise gradient computations w.r.t. model parameters which is computationally more efficient. This implementation can be enabled by setting the sample_wise_grads_per_batch argument to True, and should be enabled if and only if the loss_fn argument is a “reduction” loss function. For example, nn.BCELoss(reduction=”sum”) would be a valid loss_fn if this implementation is enabled (see documentation for loss_fn for more details). Note that our current implementation enables batch-wise gradient computations only for a limited number of PyTorch nn.Modules: Conv2D and Linear. This list will be expanded in the near future. Therefore, please do not enable this implementation if gradients will be computed for other kinds of layers. Default: False
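
Below is a hedged construction sketch for TracInCP, reusing the hypothetical TinyNet and train_dataset from the SimilarityInfluence sketch earlier; the checkpoint directory is likewise a placeholder. Since the loss uses a "sum" reduction, sample_wise_grads_per_batch is enabled.

    import torch.nn as nn
    from captum.influence import TracInCP

    tracin = TracInCP(
        model=TinyNet(),                     # hypothetical model from the earlier sketch
        train_dataset=train_dataset,         # hypothetical Dataset of (input, label) tuples
        checkpoints="/tmp/checkpoints",      # hypothetical directory of saved state dicts
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        batch_size=16,
        sample_wise_grads_per_batch=True,    # paired with the "reduction" loss_fn above
    )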

compute_intermediate_quantities(inputs, aggregate=False)[source]

Computes “embedding” vectors for all examples in a single batch, or a Dataloader that yields batches. These embedding vectors are constructed so that the influence score of a training example on a test example is simply the dot-product of their corresponding vectors. Allowing a DataLoader yielding batches to be passed in (as opposed to a single batch) gives the potential to improve efficiency, because we load each checkpoint only once in this method call. Thus if a DataLoader yielding batches is passed in, this reduces the total number of times each checkpoint is loaded for a dataset, compared to if a single batch is passed in. The reason we do not just increase the batch size is that for large models, large batches do not fit in memory.

If aggregate is True, the sum of the vectors for all examples is returned, instead of the vectors for each example. This can be useful for computing the influence of a given training example on the total loss over a validation dataset, because due to properties of the dot-product, this influence is the dot-product of the training example’s vector with the sum of the vectors in the validation dataset. Also, doing the sum aggregation within this method, as opposed to outside of it (by computing all vectors for the validation dataset, then taking the sum), allows memory usage to be reduced.

Parameters:
  • inputs (Tuple, or DataLoader) – Either a single tuple of any, or a DataLoader, where each batch yielded is a tuple of any. In either case, the tuple represents a single batch, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. Here, model is model provided in initialization. This is the same assumption made for each batch yielded by training dataset train_dataset.

  • aggregate (bool) – Whether to return the sum of the vectors for all examples, as opposed to vectors for each example.

Returns:

A tensor of dimension (N, D * C). Here, N is the total number of examples in inputs if aggregate is False, and 1, otherwise (so that a 2D tensor is always returned). C is the number of checkpoints passed as the checkpoints argument of TracInCP.__init__, and each row represents the vector for an example. Regarding D: Let I be the dimension of the output of the last fully-connected layer times the dimension of the input of the last fully-connected layer. If self.projection_dim is specified in initialization, D = min(I * C, self.projection_dim * C). Otherwise, D = I * C. In summary, if self.projection_dim is None, the dimension of each vector will be determined by the size of the input and output of the last fully-connected layer of model. Otherwise, self.projection_dim must be an int, and random projection will be performed to ensure that the vector is of dimension no more than self.projection_dim * C. self.projection_dim corresponds to the variable d in the top of page 15 of the TracIn paper: https://arxiv.org/pdf/2002.08484.pdf.

Return type:

intermediate_quantities (Tensor)
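
A hedged usage sketch for compute_intermediate_quantities, reusing tracin and train_dataset from the sketches above: per-example vectors for a small batch, an aggregated vector for a DataLoader standing in for a validation set, and the dot-product that yields aggregate influence.

    import torch
    from torch.utils.data import DataLoader

    # A hypothetical batch of 5 examples (inputs, labels).
    batch = (torch.randn(5, 3, 8, 8), torch.randint(0, 2, (5,)))
    vectors = tracin.compute_intermediate_quantities(batch)          # shape (5, D * C)

    # Stand-in "validation" loader; aggregate=True returns the summed vector.
    val_loader = DataLoader(train_dataset, batch_size=16)
    val_sum = tracin.compute_intermediate_quantities(val_loader, aggregate=True)  # shape (1, D * C)

    # Influence of each example in `batch` on the total validation loss.
    influence_on_val = vectors @ val_sum.t()                          # shape (5, 1)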

influence(inputs, k=None, proponents=True, show_progress=False, aggregate=False)[source]

This is the key method of this class, and can be run in 2 different modes, where the mode that is run depends on the arguments passed to this method. Below, we describe the 2 modes, when aggregate is False:

  • influence score mode: This mode is used if k is None. This mode computes the influence score of every example in training dataset train_dataset on every example in the test dataset represented by inputs.

  • k-most influential mode: This mode is used if k is not None, and an int. This mode computes the proponents or opponents of every example in the test dataset represented by inputs. In particular, for each test example in the test dataset, this mode computes its proponents (resp. opponents), which are the indices in the training dataset train_dataset of the training examples with the k highest (resp. lowest) influence scores on the test example. Proponents are computed if proponents is True. Otherwise, opponents are computed. For each test example, this method also returns the actual influence score of each proponent (resp. opponent) on the test example.

When aggregate is True, this method computes “aggregate” influence scores, which for a given training example, is the sum of its influence scores over all examples in the test dataset. Below, we describe the 2 modes, when aggregate is True:

  • influence score mode: This mode is used if k is None. This mode computes the aggregate influence score of each example in training dataset train_dataset on the test dataset.

  • k-most influential mode: This mode is used if k is not None, and an int. This mode computes the “aggregate” proponents (resp. opponents), which are the indices in the training dataset train_dataset of the examples with the k highest (resp. lowest) aggregate influence scores on the test dataset. Proponents are computed if proponents is True. Otherwise, opponents are computed. This method also returns the actual aggregate influence scores of each proponent (resp. opponent) on the test dataset.

Parameters:
  • inputs (Tuple, or DataLoader) – Either a single tuple of any, or a DataLoader, where each batch yielded is a tuple of any. In either case, the tuple represents a single batch, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. Here, model is model provided in initialization. This is the same assumption made for each batch yielded by training dataset train_dataset. Please see documentation for the train_dataset argument to TracInCPFastRandProj.__init__ for more details on the assumed structure of a batch.

  • k (int, optional) – If not provided or None, the influence score mode will be run. Otherwise, the k-most influential mode will be run, and k is the number of proponents / opponents to return per example in the test batch. Default: None

  • proponents (bool, optional) – Whether seeking proponents (proponents=True) or opponents (proponents=False), if running in k-most influential mode. Default: True

  • show_progress (bool, optional) – For all modes, computation of results requires “training dataset computations”: computations for each batch in the training dataset train_dataset, which may take a long time. If show_progress is True, the progress of “training dataset computations” will be displayed. In particular, the number of batches for which computations have been performed will be displayed. It will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fall back to a simple output of progress. Default: False

  • aggregate (bool, optional) – If True, returns “aggregate” influence scores or examples with the highest / lowest aggregate influence scores on the test dataset, depending on the mode. Default: False

Return type:

Union[Tensor, KMostInfluentialResults]

Returns:

The return value of this method depends on which mode is run, and whether aggregate is True or False.

Below are the return values for the 2 modes, when aggregate is False:

  • influence score mode: if this mode is run (k is None), returns a 2D tensor influence_scores of shape (input_size, train_dataset_size), where input_size is the number of examples in the test dataset, and train_dataset_size is the number of examples in training dataset train_dataset. In other words, influence_scores[i][j] is the influence score of the j-th example in train_dataset on the i-th example in the test dataset.

  • k-most influential mode: if this mode is run (k is an int), returns a namedtuple (indices, influence_scores). indices is a 2D tensor of shape (input_size, k), where input_size is the number of examples in the test dataset. If computing proponents (resp. opponents), indices[i][j] is the index in training dataset train_dataset of the example with the j-th highest (resp. lowest) influence score (out of the examples in train_dataset) on the i-th example in the test dataset. influence_scores contains the corresponding influence scores. In particular, influence_scores[i][j] is the influence score of example indices[i][j] in train_dataset on example i in the test dataset represented by inputs.

Below are the return values for the 2 modes, when aggregate is True:

  • influence score mode: if this mode is run (k is None), returns a 2D tensor influence_scores of shape (1, train_dataset_size), where influence_scores[0][j] is the aggregate influence score of the j-th example in train_dataset on the test dataset.

  • k-most influential mode: if this mode is run (k is an int), returns a namedtuple (indices, influence_scores). indices is a 2D tensor of shape (1, k). If computing proponents (resp. opponents), indices[0][j] is the index in training dataset train_dataset of the example with the j-th highest (resp. lowest) aggregate influence score on the test dataset. influence_scores contains the corresponding aggregate influence scores. In particular, influence_scores[0][j] is the aggregate influence score of example indices[0][j] on the test dataset.
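
A hedged sketch of the two modes (with aggregate left at its default of False), reusing tracin and batch from the sketches above.

    # Influence score mode: k is None, so scores for every training example are returned.
    scores = tracin.influence(batch)                      # shape (5, len(train_dataset))

    # k-most influential mode: top-3 proponents per test example.
    top = tracin.influence(batch, k=3, proponents=True)
    print(top.indices.shape, top.influence_scores.shape)  # both torch.Size([5, 3])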

self_influence(inputs=None, show_progress=False, outer_loop_by_checkpoints=False)[source]

Computes self influence scores for the examples in inputs, which is either a single batch or a Pytorch DataLoader that yields batches. If inputs is not specified or is None, calculates self influence scores for the training dataset train_dataset. Note that if inputs is a single batch, this will call model on that single batch, and if inputs yields batches, this will call model on each batch that is yielded. Therefore, please ensure that for both cases, the batch(es) that model is called with are not too large, so that there will not be an out-of-memory error. Internally, this computation requires iterating both over the batches in inputs, as well as different model checkpoints. There are two ways this iteration can be done. If outer_loop_by_checkpoints is False, the outer iteration will be over batches, and the inner iteration will be over checkpoints. This has the advantage that displaying the progress of the computation is more intuitive, involving displaying the number of batches for which self influence scores have been computed. If outer_loop_by_checkpoints is True, the outer iteration will be over checkpoints, and the inner iteration will be over batches. This has the advantage that the checkpoints do not need to be loaded for each batch. For large models, loading checkpoints can be time-intensive.

Parameters:
  • inputs (tuple or DataLoader, optional) – This specifies the dataset for which self influence scores will be computed. Either a single tuple of any, or a DataLoader, where each batch yielded is a tuple of type any. In either case, the tuple represents a single batch, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. This is the same assumption made for each batch yielded by training dataset train_dataset. Please see documentation for the train_dataset argument to TracInCP.__init__ for more details on the assumed structure of a batch. If not provided or None, self influence scores will be computed for training dataset train_dataset, which yields batches satisfying the above assumptions. Default: None.

  • show_progress (bool, optional) – Computation of self influence scores can take a long time if inputs represents many examples. If show_progress is True, the progress of this computation will be displayed. In more detail, if outer_loop_by_checkpoints is False, this computation will iterate over all batches in an outer loop. Thus if show_progress is True, the number of batches for which self influence scores have been computed will be displayed. If outer_loop_by_checkpoints is True, this computation will iterate over all checkpoints (provided as the checkpoints initialization argument) in an outer loop, and iterate over all batches that inputs represents in an inner loop. Thus if show_progress is True, the progress of both the outer iteration and the inner iterations will be displayed. To show progress, it will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fall back to a simple output of progress. Default: False

  • outer_loop_by_checkpoints (bool, optional) – If performing an outer iteration over checkpoints; see method description for more details. Default: False

Return type:

Tensor

TracInCPFast

class captum.influence.TracInCPFast(model, final_fc_layer, train_dataset, checkpoints, checkpoints_load_func=_load_flexible_state_dict, loss_fn=None, batch_size=1, test_loss_fn=None, vectorize=False)[source]

In Appendix F, Page 14 of the TracIn paper, they show that the influence score between a test example x’ and a training example x can be computed much more quickly than with naive back-propagation in the special case when considering only gradients in the last fully-connected layer. This class computes influence scores for that special case. Note that the computed influence scores are exactly the same as when naive back-propagation is used - there is no loss in accuracy.

In more detail regarding the influence score computation: let $x$ and $\nabla_y f(y)$ be the input and output-gradient of the last fully-connected layer, respectively, for a training example. Similarly, let $x'$ and $\nabla_{y'} f(y')$ be the corresponding quantities for a test example. Then, the influence score of the training example on the test example is the sum of the contributions from each checkpoint. The contribution from a given checkpoint is $(x^\top x')(\nabla_y f(y)^\top \nabla_{y'} f(y'))$.
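
As a hedged numerical illustration of the formula above (with made-up dimensions and random stand-ins for the last-layer inputs and output-gradients), the score is accumulated checkpoint by checkpoint:

    import torch

    d_in, d_out, n_checkpoints = 4, 3, 2   # hypothetical sizes

    influence_score = torch.tensor(0.0)
    for _ in range(n_checkpoints):
        x, x_test = torch.randn(d_in), torch.randn(d_in)              # last-layer inputs
        grad_y, grad_y_test = torch.randn(d_out), torch.randn(d_out)  # output-gradients
        influence_score += torch.dot(x, x_test) * torch.dot(grad_y, grad_y_test)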

Parameters:
  • model (torch.nn.Module) – An instance of pytorch model. This model should define all of its layers as attributes of the model.

  • final_fc_layer (torch.nn.Module) – The last fully connected layer in the network for which gradients will be approximated via fast random projection method.

  • train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader) – In the influence method, we compute the influence score of training examples on examples in a test batch. This argument represents the training dataset containing those training examples. In order to compute those influence scores, we will create a Pytorch DataLoader yielding batches of training examples that is then used for processing. If this argument is already a Pytorch Dataloader, that DataLoader can be directly used for processing. If it is instead a Pytorch Dataset, we will create a DataLoader using it, with batch size specified by batch_size. For efficiency purposes, the batch size of the DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if train_dataset is a Dataset, batch_size should be large. If train_dataset was already a DataLoader to begin with, it should have been constructed to have a large batch size. It is assumed that the Dataloader (regardless of whether it is created from a Pytorch Dataset or not) yields tuples. For a batch that is yielded, of length L, it is assumed that the forward function of model accepts L-1 arguments, and the last element of batch is the label. In other words, model(*batch[:-1]) gives the output of model, and batch[-1] are the labels for the batch.

  • checkpoints (str, list[str], or Iterator) – Either the path to the directory used to store and retrieve model checkpoints, a list of checkpoint filepaths from which to load, or an iterator which returns objects from which to load checkpoints.

  • checkpoints_load_func (Callable, optional) – The function to load a saved checkpoint into a model to update its parameters, and get the learning rate if it is saved. By default uses a utility to load a model saved as a state dict. Default: _load_flexible_state_dict

  • loss_fn (Callable, optional) – The loss function applied to model. loss_fn must be a “reduction” loss function that reduces the per-example losses in a batch, and returns a single scalar Tensor. Furthermore, the reduction must be the sum or the mean of the per-example losses. For instance, nn.BCELoss(reduction=”sum”) is acceptable. Also note that if loss_fn has no “reduction” attribute, the implementation assumes that the reduction is the sum of the per-example losses. If this is not the case, i.e. the reduction is the mean, please set the “reduction” attribute of loss_fn to “mean”, i.e. loss_fn.reduction = “mean”. Default: None

  • batch_size (int or None, optional) – Batch size of the DataLoader created to iterate through train_dataset, if it is a Dataset. batch_size should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of TracInCPBase will detail the size of the intermediate quantities. batch_size must be an int if train_dataset is a Dataset. If train_dataset is a DataLoader, then batch_size is ignored as an argument. Default: 1

  • test_loss_fn (Callable, optional) – In some cases, one may want to use a separate loss function for training examples, i.e. those in train_dataset, and for test examples, i.e. those represented by the inputs and targets arguments to the influence method. For example, if one wants to calculate the influence score of a training example on a test example’s prediction for a fixed class, test_loss_fn could map from the logits for all classes to the logits for a fixed class. test_loss_fn needs to satisfy the same constraints as loss_fn. Thus, the same checks that we apply to loss_fn are also applied to test_loss_fn, if the latter is provided. If not provided, the loss function for test examples is assumed to be the same as the loss function for training examples, i.e. loss_fn. Default: None

  • vectorize (bool, optional) – Flag to use experimental vectorize functionality for torch.autograd.functional.jacobian. Default: False
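
A hedged construction sketch for TracInCPFast, again reusing the hypothetical TinyNet and train_dataset placeholders; the key difference from TracInCP is that the last fully-connected layer is passed explicitly.

    import torch.nn as nn
    from captum.influence import TracInCPFast

    model = TinyNet()
    tracin_fast = TracInCPFast(
        model=model,
        final_fc_layer=model.fc,             # the last fully-connected layer
        train_dataset=train_dataset,
        checkpoints="/tmp/checkpoints",      # hypothetical checkpoint directory
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        batch_size=16,
    )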

influence(inputs, k=None, proponents=True, show_progress=False)[source]

This is the key method of this class, and can be run in 2 different modes, where the mode that is run depends on the arguments passed to this method:

  • influence score mode: This mode is used if k is None. This mode computes the influence score of every example in training dataset train_dataset on every example in the test batch represented by inputs.

  • k-most influential mode: This mode is used if k is not None, and an int. This mode computes the proponents or opponents of every example in the test batch represented by inputs. In particular, for each test example in the test batch, this mode computes its proponents (resp. opponents), which are the indices in the training dataset train_dataset of the training examples with the k highest (resp. lowest) influence scores on the test example. Proponents are computed if proponents is True. Otherwise, opponents are computed. For each test example, this method also returns the actual influence score of each proponent (resp. opponent) on the test example.

Parameters:
  • inputs (tuple or DataLoader) – inputs is the test batch and is a tuple of any, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. This is the same assumption made for each batch yielded by training dataset train_dataset - please see its documentation in __init__ for more details on the assumed structure of a batch.

  • k (int, optional) – If not provided or None, the influence score mode will be run. Otherwise, the k-most influential mode will be run, and k is the number of proponents / opponents to return per example in the test batch. Default: None

  • proponents (bool, optional) – Whether seeking proponents (proponents=True) or opponents (proponents=False), if running in k-most influential mode. Default: True

  • show_progress (bool, optional) – For all modes, computation of results requires “training dataset computations”: computations for each batch in the training dataset train_dataset, which may take a long time. If show_progress is True, the progress of “training dataset computations” will be displayed. In particular, the number of batches for which computations have been performed will be displayed. It will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fall back to a simple output of progress. Default: False

Return type:

Union[Tensor, KMostInfluentialResults]

Returns:

The return value of this method depends on which mode is run.

  • influence score mode: if this mode is run (k is None), returns a 2D tensor influence_scores of shape (input_size, train_dataset_size), where input_size is the number of examples in the test batch, and train_dataset_size is the number of examples in training dataset train_dataset. In other words, influence_scores[i][j] is the influence score of the j-th example in train_dataset on the i-th example in the test batch.

  • k-most influential mode: if this mode is run (k is an int), returns a namedtuple (indices, influence_scores). indices is a 2D tensor of shape (input_size, k), where input_size is the number of examples in the test batch. If computing proponents (resp. opponents), indices[i][j] is the index in training dataset train_dataset of the example with the j-th highest (resp. lowest) influence score (out of the examples in train_dataset) on the i-th example in the test batch. influence_scores contains the corresponding influence scores. In particular, influence_scores[i][j] is the influence score of example indices[i][j] in train_dataset on example i in the test batch represented by inputs.

self_influence(inputs=None, show_progress=False, outer_loop_by_checkpoints=False)[source]

Computes self influence scores for the examples in inputs, which is either a single batch or a Pytorch DataLoader that yields batches. If inputs is not specified or is None, calculates self influence scores for the training dataset train_dataset. Note that if inputs is a single batch, this will call model on that single batch, and if inputs yields batches, this will call model on each batch that is yielded. Therefore, please ensure that for both cases, the batch(es) that model is called with are not too large, so that there will not be an out-of-memory error. Internally, this computation requires iterating both over the batches in inputs, as well as different model checkpoints. There are two ways this iteration can be done. If outer_loop_by_checkpoints is False, the outer iteration will be over batches, and the inner iteration will be over checkpoints. This has the advantage that displaying the progress of the computation is more intuitive, involving displaying the number of batches for which self influence scores have been computed. If outer_loop_by_checkpoints is True, the outer iteration will be over checkpoints, and the inner iteration will be over batches. This has the advantage that the checkpoints do not need to be loaded for each batch. For large models, loading checkpoints can be time-intensive.

Parameters:
  • inputs (tuple or DataLoader, optional) – This specifies the dataset for which self influence scores will be computed. Either a single tuple of any, or a DataLoader, where each batch yielded is a tuple of type any. In either case, the tuple represents a single batch, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. This is the same assumption made for each batch yielded by training dataset train_dataset. Please see documentation for the train_dataset argument to TracInCP.__init__ for more details on the assumed structure of a batch. If not provided or None, self influence scores will be computed for training dataset train_dataset, which yields batches satisfying the above assumptions. Default: None.

  • show_progress (bool, optional) – Computation of self influence scores can take a long time if inputs represents many examples. If show_progress is True, the progress of this computation will be displayed. In more detail, if outer_loop_by_checkpoints is False, this computation will iterate over all batches in an outer loop. Thus if show_progress is True, the number of batches for which self influence scores have been computed will be displayed. If outer_loop_by_checkpoints is True, this computation will iterate over all checkpoints (provided as the checkpoints initialization argument) in an outer loop, and iterate over all batches that inputs represents in an inner loop. Thus if show_progress is True, the progress of both the outer iteration and the inner iterations will be displayed. To show progress, it will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fall back to a simple output of progress. Default: False

  • outer_loop_by_checkpoints (bool, optional) – If performing an outer iteration over checkpoints; see method description for more details. Default: False

Return type:

Tensor

TracInCPFastRandProj

class captum.influence.TracInCPFastRandProj(model, final_fc_layer, train_dataset, checkpoints, checkpoints_load_func=_load_flexible_state_dict, loss_fn=None, batch_size=1, test_loss_fn=None, vectorize=False, nearest_neighbors=None, projection_dim=None, seed=0)[source]

A version of TracInCPFast which is optimized for “interactive” calls to influence for the purpose of calculating proponents / opponents, or influence scores. “Interactive” means there will be multiple calls to influence, with each call for a different batch of test examples, and subsequent calls rely on the results of previous calls. The implementation in this class has been optimized so that each call to influence is fast, so that it can be used for interactive analysis. This class should only be used for interactive use cases. It should not be used if influence will only be called once, because to enable fast calls to influence, time and memory intensive preprocessing is required in __init__. Furthermore, it should not be used to calculate self influence scores - TracInCPFast should be used instead for that purpose. To enable interactive analysis, this implementation computes and saves “embedding” vectors for all training examples in train_dataset. Crucially, the influence score of a training example on a test example is simply the dot-product of their corresponding vectors, and proponents / opponents can be found by first storing vectors for training examples in a nearest-neighbor data structure, and then finding the nearest-neighbors for a test example in terms of dot-product (see appendix F of the TracIn paper). This class should only be used if calls to influence to obtain proponents / opponents or influence scores will be made in an “interactive” manner, and there is sufficient memory to store vectors for the entire train_dataset. This is because in order to enable interactive analysis, this implementation incurs overhead in __init__ to set up the nearest-neighbors data structure, which is both time and memory intensive, as vectors corresponding to all training examples need to be stored. To reduce memory usage, this implementation enables random projections of those vectors. Note that the influence scores computed with random projections are less accurate, though correct in expectation.

In more detail regarding the “embedding” vectors - the influence of a training example on a test example, when only considering gradients in the last fully-connected layer, is the sum of the contributions from each checkpoint. The contribution from a given checkpoint is $(x^\top x')(\nabla_y f(y)^\top \nabla_{y'} f(y'))$, using the notation in the description of TracInCPFast. As is, this is not a dot-product of 2 vectors. However, we can rewrite that contribution as $(x \nabla_y f(y)^\top) \cdot (x' \nabla_{y'} f(y')^\top)$. Both terms in this product are 2D matrices, as they are outer products, and the “product” is actually a dot-product, treating both matrices as vectors. Therefore, for a given checkpoint, its contribution to the “embedding” of an example is just the outer-product $x \nabla_y f(y)^\top$, flattened. Furthermore, to reduce the dimension of this contribution, we can right-multiply and left-multiply the outer-product with two separate projection matrices. These transform $\nabla_y f(y)$ and $x$ to lower dimensional vectors. While the dimensions of these two lower dimensional vectors do not necessarily need to be the same, in our implementation, we let them be the same, both equal to the square root of the desired projection dimension. Finally, the embedding of an example is the concatenation of the contributions from each checkpoint.
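
The rewriting above can be checked numerically; the sketch below (made-up dimensions, no projection) verifies that the per-checkpoint contribution equals the dot product of the two flattened outer products.

    import torch

    d_in, d_out = 4, 3
    x, x_test = torch.randn(d_in), torch.randn(d_in)
    grad_y, grad_y_test = torch.randn(d_out), torch.randn(d_out)

    direct = torch.dot(x, x_test) * torch.dot(grad_y, grad_y_test)

    emb_train = torch.outer(x, grad_y).flatten()            # flattened outer product, training example
    emb_test = torch.outer(x_test, grad_y_test).flatten()   # flattened outer product, test example
    assert torch.allclose(direct, torch.dot(emb_train, emb_test))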

Parameters:
  • model (torch.nn.Module) – An instance of pytorch model. This model should define all of its layers as attributes of the model.

  • final_fc_layer (torch.nn.Module) – The last fully connected layer in the network for which gradients will be approximated via fast random projection method.

  • train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader) – In the influence method, we compute the influence score of training examples on examples in a test batch. This argument represents the training dataset containing those training examples. In order to compute those influence scores, we will create a Pytorch DataLoader yielding batches of training examples that is then used for processing. If this argument is already a Pytorch Dataloader, that DataLoader can be directly used for processing. If it is instead a Pytorch Dataset, we will create a DataLoader using it, with batch size specified by batch_size. For efficiency purposes, the batch size of the DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if train_dataset is a Dataset, batch_size should be large. If train_dataset was already a DataLoader to begin with, it should have been constructed to have a large batch size. It is assumed that the Dataloader (regardless of whether it is created from a Pytorch Dataset or not) yields tuples. For a batch that is yielded, of length L, it is assumed that the forward function of model accepts L-1 arguments, and the last element of batch is the label. In other words, model(*batch[:-1]) gives the output of model, and batch[-1] are the labels for the batch.

  • checkpoints (str, list[str], or Iterator) – Either the path to the directory used to store and retrieve model checkpoints, a list of checkpoint filepaths from which to load, or an iterator which returns objects from which to load checkpoints.

  • checkpoints_load_func (Callable, optional) – The function to load a saved checkpoint into a model to update its parameters, and get the learning rate if it is saved. By default uses a utility to load a model saved as a state dict. Default: _load_flexible_state_dict

  • loss_fn (Callable, optional) – The loss function applied to model. loss_fn must be a “reduction” loss function that reduces the per-example losses in a batch, and returns a single scalar Tensor. Furthermore, the reduction must be the sum of the per-example losses. For instance, nn.BCELoss(reduction=”sum”) is acceptable, but nn.BCELoss(reduction=”mean”) is not acceptable. Default: None

  • batch_size (int or None, optional) – Batch size of the DataLoader created to iterate through train_dataset, if it is a Dataset. batch_size should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of TracInCPBase will detail the size of the intermediate quantities. batch_size must be an int if train_dataset is a Dataset. If train_dataset is a DataLoader, then batch_size is ignored as an argument. Default: 1

  • test_loss_fn (Callable, optional) – In some cases, one may want to use a separate loss function for training examples, i.e. those in train_dataset, and for test examples, i.e. those represented by the inputs and targets arguments to the influence method. For example, if one wants to calculate the influence score of a training example on a test example’s prediction for a fixed class, test_loss_fn could map from the logits for all classes to the logits for a fixed class. test_loss_fn needs to satisfy the same constraints as loss_fn. Thus, the same checks that we apply to loss_fn are also applied to test_loss_fn, if the latter is provided. If not provided, the loss function for test examples is assumed to be the same as the loss function for training examples, i.e. loss_fn. Default: None

  • vectorize (bool) – Flag to use experimental vectorize functionality for torch.autograd.functional.jacobian. Default: False

  • nearest_neighbors (NearestNeighbors, optional) – The NearestNeighbors instance for finding nearest neighbors. If None, defaults to AnnoyNearestNeighbors(n_trees=10). Default: None

  • projection_dim (int, optional) – Each example will be represented in the nearest neighbors data structure with a vector. This vector is the concatenation of several “checkpoint vectors”, each of which is computed using a different checkpoint in the checkpoints argument. If projection_dim is an int, it represents the dimension we will project each “checkpoint vector” to, so that the vector for each example will be of dimension at most projection_dim * C, where C is the number of checkpoints. Regarding the dimension of each vector, D: Let I be the dimension of the output of the last fully-connected layer times the dimension of the input of the last fully-connected layer. If projection_dim is not None, then D = min(I * C, projection_dim * C). Otherwise, D = I * C. In summary, if projection_dim is None, the dimension of this vector will be determined by the size of the input and output of the last fully-connected layer of model, and the number of checkpoints. Otherwise, projection_dim must be an int, and random projection will be performed to ensure that the vector is of dimension no more than projection_dim * C. projection_dim corresponds to the variable d in the top of page 15 of the TracIn paper: https://arxiv.org/abs/2002.08484. Default: None

  • seed (int, optional) – Because this implementation chooses a random projection, its output is random. Setting this seed specifies the random seed when choosing the random projection. Default: 0
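
As a rough illustration of how these arguments fit together, the sketch below constructs a TracInCPFastRandProj instance under some assumptions: the positional arguments model, final_fc_layer, and train_dataset are taken from the class signature rather than from this excerpt, and the model, dataset, and checkpoint file are all hypothetical.

>>> import os
>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.data import TensorDataset
>>> from captum.influence import TracInCPFastRandProj
>>> # a small hypothetical model whose last layer is fully-connected
>>> net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
>>> train_dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))
>>> # save one checkpoint as a state dict, assuming the default
>>> # checkpoints_load_func (_load_flexible_state_dict) can load it
>>> os.makedirs("checkpoints", exist_ok=True)
>>> torch.save(net.state_dict(), "checkpoints/checkpoint-0.pt")
>>> tracin = TracInCPFastRandProj(
...     model=net,
...     final_fc_layer=net[-1],     # assumed argument: the last fully-connected layer
...     train_dataset=train_dataset,
...     checkpoints="checkpoints",  # directory containing the saved state dict(s)
...     loss_fn=nn.CrossEntropyLoss(reduction="sum"),  # must sum per-example losses
...     batch_size=16,
...     projection_dim=100,         # each "checkpoint vector" projected to at most 100 dims
...     seed=0,
... )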

compute_intermediate_quantities(inputs)[source]

Computes “embedding” vectors for all examples in a single batch, or a Dataloader that yields batches. These embedding vectors are constructed so that the influence score of a training example on a test example is simply the dot-product of their corresponding vectors. Please see the documentation for TracInCPFastRandProj.__init__ for more details. Allowing a DataLoader yielding batches to be passed in (as opposed to a single batch) gives the potential to improve efficiency, because we load each checkpoint only once in this method call. Thus if a DataLoader yielding batches is passed in, this reduces the total number of times each checkpoint is loaded for a dataset, compared to if a single batch is passed in. The reason we do not just increase the batch size is that for large models, large batches do not fit in memory.

Parameters:

inputs (Tuple, or DataLoader) – Either a single tuple of any, or a DataLoader, where each batch yielded is a tuple of any. In either case, the tuple represents a single batch, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. Here, model is the model provided in initialization. This is the same assumption made for each batch yielded by training dataset train_dataset. Please see documentation for the train_dataset argument to TracInCPFastRandProj.__init__ for more details on the assumed structure of a batch.

Returns:

A tensor of dimension (N, D * C), where N is the total number of examples in inputs, C is the number of checkpoints passed as the checkpoints argument of TracInCPFastRandProj.__init__, and each row represents the vector for an example. Regarding D: Let I be the dimension of the output of the last fully-connected layer times the dimension of the input of the last fully-connected layer. If self.projection_dim is specified in initialization, D = min(I * C, self.projection_dim * C). Otherwise, D = I * C. In summary, if self.projection_dim is None, the dimension of each vector will be determined by the size of the input and output of the last fully-connected layer of model. Otherwise, self.projection_dim must be an int, and random projection will be performed to ensure that the vector is of dimension no more than self.projection_dim * C. self.projection_dim corresponds to the variable d at the top of page 15 of the TracIn paper: https://arxiv.org/pdf/2002.08484.pdf.

Return type:

intermediate_quantities (Tensor)
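
Continuing the hypothetical tracin, net, and train_dataset from the construction sketch above, a minimal sketch of this method might look as follows; the shapes in the comments follow the description of N, D, and C in the Returns section.

>>> from torch.utils.data import DataLoader
>>> # a hypothetical test batch of 3 examples: (features, labels)
>>> test_batch = (torch.randn(3, 10), torch.randint(0, 2, (3,)))
>>> test_vectors = tracin.compute_intermediate_quantities(test_batch)
>>> test_vectors.shape   # (3, D * C): one row per test example
>>> # a DataLoader can be passed instead, so each checkpoint is loaded only once
>>> train_vectors = tracin.compute_intermediate_quantities(
...     DataLoader(train_dataset, batch_size=16)
... )
>>> # influence scores are dot-products of the corresponding vectors
>>> scores = test_vectors @ train_vectors.T   # shape (3, 100)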

influence(inputs=None, k=5, proponents=True)[source]

This is the key method of this class, and can be run in 2 different modes, where the mode that is run depends on the arguments passed to this method:

  • influence score mode: This mode is used if k is None. This mode computes the influence score of every example in training dataset train_dataset on every example in the test batch represented by inputs.

  • k-most influential mode: This mode is used if k is not None and is an int. This mode computes the proponents or opponents of every example in the test batch represented by inputs. In particular, for each test example in the test batch, this mode computes its proponents (resp. opponents), which are the indices in the training dataset train_dataset of the training examples with the k highest (resp. lowest) influence scores on the test example. Proponents are computed if proponents is True. Otherwise, opponents are computed. For each test example, this method also returns the actual influence score of each proponent (resp. opponent) on the test example.

Parameters:
  • inputs (tuple) – inputs is the test batch and is a tuple of any, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. This is the same assumption made for each batch yielded by training dataset train_dataset - please see its documentation in __init__ for more details on the assumed structure of a batch.

  • k (int, optional) – If not provided or None, the influence score mode will be run. Otherwise, the k-most influential mode will be run, and k is the number of proponents / opponents to return per example in the test batch. Default: None

  • proponents (bool, optional) – Whether seeking proponents (proponents=True) or opponents (proponents=False), if running in k-most influential mode. Default: True

Return type:

Union[Tensor, KMostInfluentialResults]

Returns:

The return value of this method depends on which mode is run.

  • influence score mode: if this mode is run (k is None), returns a 2D tensor influence_scores of shape (input_size, train_dataset_size), where input_size is the number of examples in the test batch, and train_dataset_size is the number of examples in training dataset train_dataset. In other words, influence_scores[i][j] is the influence score of the j-th example in train_dataset on the i-th example in the test batch.

  • k-most influential mode: if this mode is run (k is an int), returns a namedtuple (indices, influence_scores). indices is a 2D tensor of shape (input_size, k), where input_size is the number of examples in the test batch. If computing proponents (resp. opponents), indices[i][j] is the index in training dataset train_dataset of the example with the j-th highest (resp. lowest) influence score (out of the examples in train_dataset) on the i-th example in the test batch. influence_scores contains the corresponding influence scores. In particular, influence_scores[i][j] is the influence score of example indices[i][j] in train_dataset on example i in the test batch represented by inputs.
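
A minimal sketch of both modes, reusing the hypothetical tracin and test_batch from the earlier sketches (the shapes in the comments assume a test batch of 3 examples and 100 training examples):

>>> # influence score mode: k=None returns the full matrix of influence scores
>>> scores = tracin.influence(test_batch, k=None)
>>> scores.shape   # (3, 100): (test batch size, train_dataset size)
>>> # k-most influential mode: top-5 proponents per test example
>>> results = tracin.influence(test_batch, k=5, proponents=True)
>>> results.indices.shape            # (3, 5): indices into train_dataset
>>> results.influence_scores.shape   # (3, 5): scores of those proponents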

self_influence(inputs=None, show_progress=False, outer_loop_by_checkpoints=False)[source]

NOT IMPLEMENTED - there is no need to implement TracInCPFastRandProj.self_influence, as TracInCPFast.self_influence is sufficient - the latter does not benefit from random projections, since no quantities associated with a training example are stored (other than its self influence score).

Computes self influence scores for a single batch or a PyTorch DataLoader that yields batches. Note that if inputs is a single batch, this will call model on that single batch, and if inputs yields batches, this will call model on each batch that is yielded. Therefore, please ensure that for both cases, the batch(es) that model is called with are not too large, so that there will not be an out-of-memory error.

Parameters:
  • inputs (tuple or DataLoader) – Either a single tuple of any, or a DataLoader, where each batch yielded is a tuple of any. In either case, the tuple represents a single batch, where the last element is assumed to be the labels for the batch. That is, model(*batch[0:-1]) produces the output for model, and batch[-1] are the labels, if any. This is the same assumption made for each batch yielded by training dataset train_dataset. Please see documentation for the train_dataset argument to TracInCP.__init__ for more details on the assumed structure of a batch.

  • show_progress (bool, optional) – Computation of self influence scores can take a long time if inputs represents many examples. If show_progress is True, the progress of this computation will be displayed. In more detail, this computation will iterate over all checkpoints (provided as the checkpoints initialization argument) and all batches that inputs represents. Therefore, the total number of (checkpoint, batch) combinations that need to be iterated over is (# of checkpoints x # of batches that inputs represents). If show_progress is True, the total number of such combinations that have been iterated over is displayed. It will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fall back to a simple output of progress. Default: False

  • outer_loop_by_checkpoints (bool, optional) – If performing an outer iteration over checkpoints; see method description for more details. Default: False

Returns:

This is a 1D tensor containing the self influence scores of all examples in inputs, regardless of whether it represents a single batch or a DataLoader that yields batches.

Return type:

self_influence_scores (Tensor)
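
Because this subclass leaves self_influence unimplemented, a sketch would call it on a TracInCPFast instance instead; the TracInCPFast constructor arguments below are assumed from its signature rather than taken from this excerpt, and they reuse the hypothetical net, train_dataset, and checkpoint directory from the construction sketch above.

>>> from captum.influence import TracInCPFast
>>> tracin_fast = TracInCPFast(
...     model=net,
...     final_fc_layer=net[-1],     # assumed argument: the last fully-connected layer
...     train_dataset=train_dataset,
...     checkpoints="checkpoints",
...     loss_fn=nn.CrossEntropyLoss(reduction="sum"),
...     batch_size=16,
... )
>>> # one self influence score per example, whether inputs is a batch or a DataLoader
>>> self_scores = tracin_fast.self_influence(
...     DataLoader(train_dataset, batch_size=16), show_progress=True
... )
>>> self_scores.shape   # (100,)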