` at all. Therefore, we have to use `LimeBase` to customize the conversion logic through the `from_interp_rep_transform` argument.
#
# `LimeBase` is not opinionated at all, so we have to define every piece manually. Let us talk about them in order:
# - `forward_func`, the forward function of the model. Note that we cannot pass our model directly, since Captum always assumes the first dimension is the batch dimension while our embedding-bag model requires flattened indices. We will therefore add the dummy dimension later when calling `attribute`, and define a wrapper here that removes the dummy dimension before passing the input to our model.
# - `interpretable_model`, the surrogate model. This works the same way as in the image classification example above; we again use sklearn's linear Lasso.
# - `similarity_func`, the function calculating the weights for training samples. The most common distance for text is the cosine similarity in the latent embedding space. Since the text inputs are just sequences of token indices, we have to leverage the trained embedding layer from the model to encode them into their latent vectors. Because of this extra encoding step, we cannot use the utility `get_exp_kernel_similarity_function('cosine')` as in the image classification example, which directly calculates the cosine similarity of the given inputs.
# - `perturb_func`, the function to sample interpretable representations. We show another way to define this argument besides the generator used in the image classification example above: here we directly define a function that returns a randomized sample on every call. It outputs a binary vector where each token is selected independently and uniformly at random.
# - `perturb_interpretable_space`, whether perturbed samples are in interpretable space. `LimeBase` also supports sampling in the original input space, but we do not need it in our case.
# - `from_interp_rep_transform`, the function transforming the perturbed interpretable samples back to the original input space. As explained above, this argument is the main reason for us to use `LimeBase`. We pick the subset of the present tokens from the original text input according to the interpretable representation.
# - `to_interp_rep_transform`, the opposite of `from_interp_rep_transform`. It is needed only when `perturb_interpretable_space` is set to `False`.
# In[27]:
# remove the batch dimension for the embedding-bag model
def forward_func(text, offsets):
    return eb_model(text.squeeze(0), offsets)
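# As a quick sanity check on the shapes involved (using made-up token indices), the wrapper simply drops the leading batch dimension that Captum adds:

```python
import torch

# a hypothetical batched input of shape (1, 3), as Captum passes it
batched = torch.tensor([[3, 1, 4]])
flattened = batched.squeeze(0)  # shape (3,), as nn.EmbeddingBag expects
print(batched.shape, flattened.shape)  # torch.Size([1, 3]) torch.Size([3])
```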
# encode text indices into latent representations & calculate cosine similarity
def exp_embedding_cosine_distance(original_inp, perturbed_inp, _, **kwargs):
    original_emb = eb_model.embedding(original_inp, None)
    perturbed_emb = eb_model.embedding(perturbed_inp, None)
    distance = 1 - F.cosine_similarity(original_emb, perturbed_emb, dim=1)
    return torch.exp(-1 * (distance ** 2) / 2)
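# To see what this exponential kernel does, here is a quick check with toy embedding vectors (not the model's real embeddings): identical vectors get the maximum weight of 1, while orthogonal vectors get exp(-0.5) ≈ 0.61.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[1.0, 0.0]])
b = torch.tensor([[0.0, 1.0]])  # orthogonal to a, so cosine similarity = 0

distance = 1 - F.cosine_similarity(a, b, dim=1)  # tensor([1.])
weight = torch.exp(-1 * (distance ** 2) / 2)     # exp(-0.5)
print(weight.item())
```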
# binary vector where each word is selected independently and uniformly at random
def bernoulli_perturb(text, **kwargs):
    probs = torch.ones_like(text) * 0.5
    return torch.bernoulli(probs).long()
# remove absent tokens based on the interpretable representation sample
def interp_to_input(interp_sample, original_input, **kwargs):
    return original_input[interp_sample.bool()].view(original_input.size(0), -1)
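# Putting the two pieces together on a toy input (hypothetical token indices): a binary interpretable sample keeps exactly the tokens marked 1.

```python
import torch

original_input = torch.tensor([[11, 22, 33, 44]])  # made-up token indices
interp_sample = torch.tensor([[1, 0, 1, 1]])       # keep tokens 0, 2, and 3

# boolean masking selects the present tokens, then we restore the batch shape
kept = original_input[interp_sample.bool()].view(original_input.size(0), -1)
print(kept)  # tensor([[11, 33, 44]])
```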
lasso_lime_base = LimeBase(
    forward_func,
    interpretable_model=SkLearnLasso(alpha=0.08),
    similarity_func=exp_embedding_cosine_distance,
    perturb_func=bernoulli_perturb,
    perturb_interpretable_space=True,
    from_interp_rep_transform=interp_to_input,
    to_interp_rep_transform=None
)
# The attribution call is the same as for the `Lime` class. Just remember to add the dummy batch dimension to the text input and put the offsets into `additional_forward_args`, because the offsets are not a feature of the classification but metadata for the text input.
# In[28]:
attrs = lasso_lime_base.attribute(
    test_text.unsqueeze(0),  # add batch dimension for Captum
    target=test_labels,
    additional_forward_args=(test_offsets,),
    n_samples=32000,
    show_progress=True
).squeeze(0)
print('Attribution range:', attrs.min().item(), 'to', attrs.max().item())
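# Besides the raw range, a simple way to inspect attributions is to rank tokens by absolute score. A minimal sketch with made-up tokens and scores (in the real case they would come from `tokenizer(test_line)` and `attrs`):

```python
tokens = ['the', 'match', 'was', 'dull']  # hypothetical tokens
scores = [0.01, 0.85, -0.03, -0.40]       # hypothetical attributions

# sort token/score pairs by the magnitude of their attribution
ranked = sorted(zip(tokens, scores), key=lambda pair: -abs(pair[1]))
print(ranked[:2])  # [('match', 0.85), ('dull', -0.4)]
```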
# At last, let us create a simple visualization to highlight the influential words where green stands for positive correlation and red for negative.
# In[29]:
def show_text_attr(attrs):
    rgb = lambda x: '255,0,0' if x < 0 else '0,255,0'
    alpha = lambda x: abs(x) ** 0.5
    token_marks = [
        f'<mark style="background-color:rgba({rgb(attr)},{alpha(attr)})">{token}</mark>'
        for token, attr in zip(tokenizer(test_line), attrs.tolist())
    ]
    display(HTML('<p>' + ' '.join(token_marks) + '</p>'))
show_text_attr(attrs)
# The above visualization should render something like the image below, where the model links the "Sports" subject to many reasonable words, like "match" and "medals".
#
# ![Lime Text](img/lime_text_viz.png)