# Source code for captum.attr._models.base

#!/usr/bin/env python3

import warnings
from functools import reduce

import torch
from torch.nn import Module

class InterpretableEmbeddingBase(Module):
    r"""
    Since some embedding vectors, e.g. word embeddings, are created and
    assigned in the embedding layers of PyTorch models, we need a way to
    access those layers, generate the embeddings and subtract the baseline.
    To do so, we separate embedding layers from the model, compute the
    embeddings separately and do all operations needed outside of the model.
    The original embedding layer is replaced by an
    InterpretableEmbeddingBase layer which passes already precomputed
    embedding vectors through, unchanged, to the subsequent layers of the
    model.
    """

    def __init__(self, embedding, full_name) -> None:
        r"""
        Args:

            embedding (torch.nn.Module): The original embedding layer that is
                being wrapped. Its ``num_embeddings`` and ``embedding_dim``
                attributes are mirrored on this wrapper when present, and set
                to ``None`` otherwise.
            full_name (str): The name under which the wrapped layer lives in
                the model (configure_interpretable_embedding_layer passes the
                dot-separated attribute path here).
        """
        super().__init__()
        self.num_embeddings = getattr(embedding, "num_embeddings", None)
        self.embedding_dim = getattr(embedding, "embedding_dim", None)

        self.embedding = embedding
        self.full_name = full_name

    def forward(self, *inputs, **kwargs):
        r"""
        The forward function of a wrapper embedding layer that takes and
        returns an embedding tensor. It allows embeddings to be created
        outside of the model and passes them seamlessly to the subsequent
        layers of the model.

        Args:

            *inputs (Any, optional): A sequence of input arguments that the
                        forward function takes. Since forward functions can take
                        any type and number of arguments, this ensures that we
                        can execute the forward pass using the interpretable
                        embedding layer. If inputs are specified, the first
                        argument is assumed to be the embedding tensor generated
                        by the self.embedding layer from all input arguments
                        provided in inputs and kwargs.
            **kwargs (Any, optional): Similar to inputs, we want to make sure
                        that our forward pass supports an arbitrary number and
                        type of key-value arguments. If inputs is not provided,
                        kwargs must be provided and its first value (in the
                        order the keyword arguments were passed) is taken as
                        the embedding tensor generated using self.embedding.
                        Keyword-argument order is guaranteed to be preserved
                        since Python 3.6 (PEP 468). For special use cases that
                        this implementation does not cover, override
                        InterpretableEmbeddingBase and address those specifics
                        in descendant classes.

        Returns:

            embedding_tensor (Tensor):
                    The first argument passed to the forward function, returned
                    as-is (the same object). Pre-computed embedding tensors are
                    passed to lower layers without any modification.

        Raises:

            AssertionError: If neither positional nor keyword arguments are
                    provided.
        """
        if not inputs and not kwargs:
            # Raised explicitly instead of via ``assert`` so the check survives
            # ``python -O``; AssertionError is kept so existing callers that
            # catch it keep working.
            raise AssertionError(
                "No input arguments are provided to InterpretableEmbeddingBase. "
                "Input embedding tensor has to be provided as first argument to "
                "forward function either through inputs argument or kwargs."
            )
        if inputs:
            return inputs[0]
        # **kwargs preserves call order, so this is the first keyword argument.
        return next(iter(kwargs.values()))

    def indices_to_embeddings(self, *input, **kwargs):
        r"""
        Maps indices to corresponding embedding vectors, e.g. word embeddings.

        Args:

            *input (Any, optional): This can be tensor(s) of input indices or
                        any other variable necessary to compute the embeddings.
                        A typical example of input indices are word or token
                        indices.
            **kwargs (Any, optional): Similar to input, this can be any
                        sequence of key-value arguments necessary to compute
                        the final embedding tensor.

        Returns:

            tensor:
                    Whatever the wrapped ``self.embedding`` layer returns when
                    called with all provided arguments, e.g. a tensor of word
                    embeddings corresponding to the given indices.
        """
        return self.embedding(*input, **kwargs)

class TokenReferenceBase:
    r"""
    A base class for creating a reference (aka baseline) tensor for a
    sequence of tokens. A typical example of such a token is PAD. Users need
    to provide the index of the reference token in the vocabulary as an
    argument to the TokenReferenceBase class.
    """

    def __init__(self, reference_token_idx=0) -> None:
        r"""
        Args:

            reference_token_idx (int, optional): Vocabulary index of the
                        reference token. Default: 0
        """
        self.reference_token_idx = reference_token_idx

    def generate_reference(self, sequence_length, device):
        r"""
        Generates a reference tensor of the given sequence_length filled with
        reference_token_idx.

        Args:

            sequence_length (int): The length of the reference sequence.
            device (torch.device): The device on which the reference tensor
                        will be created.

        Returns:

            tensor:
                    A 1-D ``torch.long`` tensor of shape [sequence_length]
                    where every element is reference_token_idx.
        """
        # torch.full with an explicit integer dtype keeps the result usable as
        # embedding indices even when sequence_length is 0 (building it from an
        # empty Python list would yield a float tensor).
        return torch.full(
            (sequence_length,),
            self.reference_token_idx,
            dtype=torch.long,
            device=device,
        )

def _get_deep_layer_name(obj, layer_names):
r"""
Traverses through the layer names that are separated by
dot in order to access the embedding layer.
"""
return reduce(getattr, layer_names.split("."), obj)

def _set_deep_layer_value(obj, layer_names, value):
r"""
Traverses through the layer names that are separated by
dot in order to access the embedding layer and update its value.
"""
layer_names = layer_names.split(".")
setattr(reduce(getattr, layer_names[:-1], obj), layer_names[-1], value)

def configure_interpretable_embedding_layer(model, embedding_layer_name="embedding"):
    r"""
    This method wraps the model's embedding layer with an interpretable
    embedding layer that allows us to access the embeddings through their
    indices. The wrapper is installed in the model in place of the original
    layer.

    Args:

        model (torch.nn.Module): An instance of a PyTorch model that contains
                    embeddings.
        embedding_layer_name (str, optional): The (dot-separated) attribute
                    name of the embedding layer in the model that we would
                    like to make interpretable. Default: "embedding"

    Returns:

        interpretable_emb (InterpretableEmbeddingBase): An instance of
                    InterpretableEmbeddingBase that wraps the model's
                    embedding layer found at embedding_layer_name.

    Raises:

        AssertionError: If the layer at embedding_layer_name is already an
                    InterpretableEmbeddingBase (or a subclass of it).

    Examples::

        >>> # Let's assume that we have a DocumentClassifier model that
        >>> # has a word embedding layer named 'embedding'.
        >>> # To make that layer interpretable we need to execute the
        >>> # following command:
        >>> net = DocumentClassifier()
        >>> interpretable_emb = configure_interpretable_embedding_layer(net,
        ...     'embedding')
        >>> # then we can use interpretable embedding to convert our
        >>> # word indices into embeddings.
        >>> # Let's assume that we have the following word indices
        >>> input_indices = torch.tensor([1, 0, 2])
        >>> # we can access word embeddings for those indices with the command
        >>> # line stated below.
        >>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
        >>> # Let's assume that we want to apply integrated gradients to
        >>> # our model and that target attribution class is 3
        >>> ig = IntegratedGradients(net)
        >>> attribution = ig.attribute(input_emb, target=3)
        >>> # after we finish the interpretation we need to remove
        >>> # interpretable embedding layer with the following command:
        >>> remove_interpretable_embedding_layer(net, interpretable_emb)

    """
    embedding_layer = _get_deep_layer_name(model, embedding_layer_name)
    # isinstance (rather than an exact class comparison) also rejects
    # subclasses of InterpretableEmbeddingBase, which must not be wrapped twice.
    if isinstance(embedding_layer, InterpretableEmbeddingBase):
        # Explicit raise so the check is not stripped under ``python -O``;
        # AssertionError is kept for callers that already catch it.
        raise AssertionError(
            "InterpretableEmbeddingBase has already been configured for layer {}".format(
                embedding_layer_name
            )
        )
    warnings.warn(
        "In order to make embedding layers more interpretable they will "
        "be replaced with an interpretable embedding layer which wraps the "
        "original embedding layer and takes word embedding vectors as inputs of "
        "the forward function. This allows us to generate baselines for word "
        "embeddings and compute attributions for each embedding dimension. "
        "The original embedding layer must be set "
        "back by calling remove_interpretable_embedding_layer function "
        "after model interpretation is finished. ",
        # Attribute the warning to the caller rather than to this module.
        stacklevel=2,
    )
    interpretable_emb = InterpretableEmbeddingBase(
        embedding_layer, embedding_layer_name
    )
    _set_deep_layer_value(model, embedding_layer_name, interpretable_emb)
    return interpretable_emb

[docs]def remove_interpretable_embedding_layer(model, interpretable_emb):
r"""
Removes interpretable embedding layer and sets back original
embedding layer in the model.

Args:

model (torch.nn.Module): An instance of PyTorch model that contains embeddings
interpretable_emb (tensor): An instance of InterpretableEmbeddingBase
that was originally created in
configure_interpretable_embedding_layer function and has
to be removed after interpretation is finished.

Examples::

>>> # Let's assume that we have a DocumentClassifier model that
>>> # has a word embedding layer named 'embedding'.
>>> # To make that layer interpretable we need to execute the
>>> # following command:
>>> net = DocumentClassifier()
>>> interpretable_emb = configure_interpretable_embedding_layer(net,
>>>    'embedding')
>>> # then we can use interpretable embedding to convert our
>>> # word indices into embeddings.
>>> # Let's assume that we have the following word indices
>>> input_indices = torch.tensor([1, 0, 2])
>>> # we can access word embeddings for those indices with the command
>>> # line stated below.
>>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
>>> # Let's assume that we want to apply integrated gradients to
>>> # our model and that target attribution class is 3