Source code for captum.influence._core.influence

#!/usr/bin/env python3

# pyre-strict

from abc import ABC, abstractmethod
from typing import Any, Type

from torch.nn import Module
from torch.utils.data import Dataset


class DataInfluence(ABC):
    r"""
    An abstract class to define model data influence skeleton.

    Subclasses must implement `influence`; `get_name` is provided here.
    """

    def __init__(self, model: Module, train_dataset: Dataset, **kwargs: Any) -> None:
        r"""
        Args:
            model (torch.nn.Module): An instance of pytorch model.
            train_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
                    used to create a PyTorch Dataloader to iterate over the
                    dataset and its labels. This is the dataset for which we
                    will be seeking for influential instances. In most cases
                    this is the training dataset.
            **kwargs: Additional key-value arguments that are necessary for
                    specific implementation of `DataInfluence` abstract class.
                    They are accepted but not used by this base initializer.
        """
        self.model = model
        self.train_dataset = train_dataset

    # The initializer used to be spelled `__init_` (a typo), so Python never
    # invoked it on construction. Keep that spelling as an alias so any
    # existing reference to the old (name-mangled) attribute still resolves.
    __init_ = __init__

    @abstractmethod
    # pyre-fixme[3]: Return annotation cannot be `Any`.
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
        r"""
        Args:
            inputs (Any): Batch of examples for which influential
                    instances are computed. They are passed to the
                    forward_func. If `inputs` is a tensor or tuple of tensors,
                    the first dimension of a tensor corresponds to the batch
                    dimension.
            **kwargs: Additional key-value arguments that are necessary for
                    specific implementation of `DataInfluence` abstract class.

        Returns:
            influences (Any): We do not add restrictions on the return type
                    for now, though this may change in the future.
        """

    @classmethod
    def get_name(cls: Type["DataInfluence"]) -> str:
        r"""
        Create readable class name. Due to the nature of the names of
        `TracInCPBase` subclasses, simply returns the class name. For example,
        for a class called TracInCP, we return the string TracInCP.

        Returns:
            name (str): a readable class name
        """
        return cls.__name__