Source code for captum._utils.models.model
from abc import ABC, abstractmethod
from typing import Dict, Optional, Union
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from torch import Tensor
from torch.utils.data import DataLoader


class Model(ABC):
    r"""
    Abstract class describing the interface of a trainable model to be used
    within Captum's algorithms. Note that this is an experimental feature.
    """

    @abstractmethod
    def fit(
        self, train_data: DataLoader, **kwargs
    ) -> Optional[Dict[str, Union[int, float, Tensor]]]:
        r"""
        Override this method to train your model. The specification of the
        dataloader is supplied by the Captum algorithm you are using. This will
        likely be a supervised learning task, so expect batched (x, y) pairs or
        (x, y, w) triples.

        Args:
            train_data (DataLoader): The data to train on.

        Returns:
            Optional statistics about training, e.g. the number of iterations
            it took to train, the training loss, etc.
        """

    @abstractmethod
    def representation(self) -> Tensor:
        r"""
        Returns the underlying representation of the interpretable model. For a
        linear model this is simply a tensor (the concatenation of weights and
        bias). For something slightly more complicated, such as a decision
        tree, this could be the nodes of the tree.

        Returns:
            A Tensor describing the representation of the model.
        """

    @abstractmethod
    def __call__(
        self, x: TensorOrTupleOfTensorsGeneric
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Predicts with the interpretable model.

        Args:
            x (TensorOrTupleOfTensorsGeneric): A batched input of tensor(s) to predict.

        Returns:
            The prediction of the input as a TensorOrTupleOfTensorsGeneric.
        """
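The interface above can be made concrete with a small subclass. What follows is a minimal sketch, assuming PyTorch is installed, of a linear model fit in closed form by (weighted) least squares with `torch.linalg.lstsq`. It accepts the batched (x, y) pairs or (x, y, w) triples the `fit` docstring mentions, and `representation` returns the concatenation of weights and bias. The class name `LeastSquaresLinearModel`, the returned statistic keys (`num_batches`, `train_loss`), and the least-squares training procedure are illustrative assumptions, not Captum's own implementation. The sketch handles only a single input tensor, not a tuple, and it falls back to `object` as the base class if Captum is not importable.

```python
from typing import Dict, Optional, Union

import torch
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

try:
    # The abstract interface documented above (requires captum).
    from captum._utils.models.model import Model
except ImportError:
    # Fallback so this sketch still runs without captum installed.
    Model = object


class LeastSquaresLinearModel(Model):
    """Hypothetical example: a linear model fit by (weighted) least squares."""

    def __init__(self) -> None:
        self.weight: Optional[Tensor] = None  # shape (num_features, num_outputs)
        self.bias: Optional[Tensor] = None  # shape (num_outputs,)

    def fit(
        self, train_data: DataLoader, **kwargs
    ) -> Optional[Dict[str, Union[int, float, Tensor]]]:
        xs, ys, ws = [], [], []
        num_batches = 0
        for batch in train_data:
            # Accept batched (x, y) pairs or (x, y, w) triples.
            x, y = batch[0], batch[1]
            w = batch[2] if len(batch) > 2 else torch.ones(x.shape[0])
            xs.append(x.reshape(x.shape[0], -1))
            ys.append(y.reshape(y.shape[0], -1))
            ws.append(w.reshape(-1))
            num_batches += 1
        x = torch.cat(xs).double()
        y = torch.cat(ys).double()
        w = torch.cat(ws).double()
        # Append a column of ones so the last coefficient row acts as the bias.
        x1 = torch.cat([x, torch.ones(x.shape[0], 1, dtype=x.dtype)], dim=1)
        # Weighted least squares: scale each row by sqrt(sample weight).
        sw = w.sqrt().unsqueeze(1)
        coef = torch.linalg.lstsq(x1 * sw, y * sw).solution
        self.weight, self.bias = coef[:-1], coef[-1]
        # Weighted mean of the per-sample squared error (summed over outputs).
        residual = y - x1 @ coef
        loss = (w * residual.pow(2).sum(dim=1)).sum() / w.sum()
        return {"num_batches": num_batches, "train_loss": loss.item()}

    def representation(self) -> Tensor:
        # Concatenation of weights and bias, as the docstring describes.
        return torch.cat([self.weight.flatten(), self.bias.flatten()])

    def __call__(self, x: Tensor) -> Tensor:
        x = x.reshape(x.shape[0], -1).to(self.weight.dtype)
        return x @ self.weight + self.bias


# Usage: recover a known linear function from noise-free data.
torch.manual_seed(0)
x = torch.randn(64, 3)
true_w = torch.tensor([[1.0], [-2.0], [0.5]])
y = x @ true_w + 3.0
loader = DataLoader(TensorDataset(x, y), batch_size=16)
model = LeastSquaresLinearModel()
stats = model.fit(loader)
print(model.representation())  # approximately [1.0, -2.0, 0.5, 3.0]
```

A closed-form solve keeps the sketch short and deterministic; a real algorithm might instead train iteratively (for example with SGD), in which case `fit` would be a natural place to report the iteration count.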