Source code for captum.attr._core.llm_attr

# pyre-strict

import warnings

from abc import ABC

from copy import copy

from textwrap import shorten

from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

import matplotlib.colors as mcolors

import matplotlib.pyplot as plt
import numpy as np

import torch
from captum._utils.typing import TokenizerLike
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
from captum.attr._core.lime import Lime
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import (
    Attribution,
    GradientAttribution,
    PerturbationAttribution,
)
from captum.attr._utils.interpretable_input import (
    InterpretableInput,
    TextTemplateInput,
    TextTokenInput,
)
from torch import nn, Tensor

# Keyword arguments passed to ``model.generate`` when the target has to be
# generated by the model and the caller did not supply any ``gen_args``.
DEFAULT_GEN_ARGS: Dict[str, Any] = {
    "max_new_tokens": 25,
    "do_sample": False,
    "temperature": None,
    "top_p": None,
}


class LLMAttributionResult:
    """
    Container for the outcome of an LLM attribution run.

    Besides holding the raw attribution tensors and the token strings they
    refer to, it offers helpers to view the result as a dict or as
    matplotlib figures.
    """

    def __init__(
        self,
        seq_attr: Tensor,
        token_attr: Optional[Tensor],
        input_tokens: List[str],
        output_tokens: List[str],
    ) -> None:
        # attribution of the whole output sequence, one score per input feature
        self.seq_attr = seq_attr
        # per-output-token attribution, or None when it was not computed
        self.token_attr = token_attr
        # labels of the input features / of the output tokens
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens

    @property
    def seq_attr_dict(self) -> Dict[str, float]:
        """Sequence-level attribution keyed by input token."""
        scores = self.seq_attr.cpu().tolist()
        # duplicate input tokens collapse onto one key (the last score wins)
        return dict(zip(self.input_tokens, scores))

    def plot_token_attr(
        self, show: bool = False
    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
        """
        Draw a heatmap of the per-token attribution (output tokens as rows,
        input tokens as columns).

        Args:

            show (bool): whether to show the plot directly or return the figure
                    and axis
                    Default: False

        Returns:
            None when ``show`` is True, otherwise the ``(figure, axes)`` pair.

        Raises:
            ValueError: if no token-level attribution is available.
        """
        if self.token_attr is None:
            raise ValueError(
                "token_attr is None (no token-level attribution was performed), please "
                "use plot_seq_attr instead for the sequence-level attribution plot"
            )

        scores = self.token_attr.cpu()

        # The largest magnitude bounds the colour scale symmetrically, so that
        # zero always sits at the midpoint between positive and negative.
        bound = scores.abs().max().item()

        fig, ax = plt.subplots()
        ax.grid(False)  # no grid lines over the heatmap

        heat = scores.numpy()
        fig_width = max(heat.shape[1] * 1.3, 6.4)
        fig_height = max(heat.shape[0] / 2.5, 4.8)
        fig.set_size_inches(fig_width, fig_height)

        # diverging palette: reds for negative, white at zero, blues for positive
        colors = [
            "#93003a",
            "#d0365b",
            "#f57789",
            "#ffbdc3",
            "#ffffff",
            "#a4d6e1",
            "#73a3ca",
            "#4772b3",
            "#00429d",
        ]
        palette = mcolors.LinearSegmentedColormap.from_list(
            name="colors", colors=colors
        )
        im = ax.imshow(
            heat,
            vmax=bound,
            vmin=-bound,
            cmap=palette,
            aspect="auto",
        )

        fig.set_facecolor("white")

        colorbar = fig.colorbar(im, ax=ax)  # type: ignore
        colorbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")

        # One tick per row/column, labelled with the (shortened) token text.
        x_labels = [
            shorten(tok, width=50, placeholder="...") for tok in self.input_tokens
        ]
        ax.set_xticks(np.arange(heat.shape[1]), labels=x_labels)
        ax.set_yticks(np.arange(heat.shape[0]), labels=self.output_tokens)

        # Horizontal axis labels go on top, rotated for readability.
        ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
        plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

        # Annotate every cell with its value; dark text on pale cells, light
        # text on the saturated ends of the colour scale.
        for i, row in enumerate(heat):
            for j, val in enumerate(row):
                if 0.2 < im.norm(val) < 0.8:
                    color = "black"
                else:
                    color = "white"
                im.axes.text(
                    j,
                    i,
                    "%.4f" % val,
                    horizontalalignment="center",
                    verticalalignment="center",
                    color=color,
                )

        if not show:
            return fig, ax
        plt.show()
        return None  # mypy wants this

    def plot_seq_attr(
        self, show: bool = False
    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
        """
        Draw a bar chart of the sequence-level attribution, one bar per input
        token.

        Args:

            show (bool): whether to show the plot directly or return the figure
                    and axis
                    Default: False

        Returns:
            None when ``show`` is True, otherwise the ``(figure, axes)`` pair.
        """
        fig, ax = plt.subplots()

        data = self.seq_attr.cpu().numpy()
        n_features = data.shape[0]

        fig.set_size_inches(max(n_features / 2, 6.4), max(n_features / 4, 4.8))

        x_labels = [
            shorten(tok, width=50, placeholder="...") for tok in self.input_tokens
        ]
        ax.set_xticks(range(n_features), labels=x_labels)
        ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
        plt.setp(
            ax.get_xticklabels(),
            rotation=-30,
            ha="right",
            rotation_mode="anchor",
        )

        fig.set_facecolor("white")

        # Two bar series share the same x positions: the positive part of each
        # score first (blue), then the negative part (red).
        for clip, bar_color in ((max, "#4772b3"), (min, "#d0365b")):
            ax.bar(
                range(n_features),
                [clip(v, 0) for v in data],
                align="center",
                color=bar_color,
            )

        ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom")

        if not show:
            return fig, ax
        plt.show()
        return None  # mypy wants this
def _clean_up_pretty_token(token: str) -> str:
    """Escape newlines as a literal ``\\n`` and strip surrounding whitespace."""
    return token.replace("\n", "\\n").strip()


def _encode_with_offsets(
    txt: str,
    tokenizer: TokenizerLike,
    add_special_tokens: bool = True,
    **kwargs: Any,
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """
    Tokenize ``txt`` and return the token ids with their character offsets.

    Args:
        txt (str): text to tokenize
        tokenizer (TokenizerLike): tokenizer called with
                ``return_offsets_mapping=True``
        add_special_tokens (bool): forwarded to the tokenizer
        **kwargs (Any): extra keyword arguments forwarded to the tokenizer

    Returns:
        The list of token ids and a list of ``(start, end)`` offsets of the
        same length. Empty spans (``start == end``) are widened to end at the
        start of the next span, or at ``len(txt)`` for the last token.
    """
    enc = tokenizer(
        txt,
        return_offsets_mapping=True,
        add_special_tokens=add_special_tokens,
        **kwargs,
    )
    input_ids = cast(List[int], enc["input_ids"])
    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
    assert len(input_ids) == len(offset_mapping), (
        f"{len(input_ids)} != {len(offset_mapping)}: {txt} -> "
        f"{input_ids}, {offset_mapping}"
    )
    # For the case where offsets are not set properly (the end and start are
    # equal for all tokens - fall back on the start of the next span in the
    # offset mapping)
    offset_mapping_corrected = []
    for i, (start, end) in enumerate(offset_mapping):
        if start == end:
            if (i + 1) < len(offset_mapping):
                end = offset_mapping[i + 1][0]
            else:
                end = len(txt)
        offset_mapping_corrected.append((start, end))
    return input_ids, offset_mapping_corrected


def _convert_ids_to_pretty_tokens(
    ids: Tensor,
    tokenizer: TokenizerLike,
) -> List[str]:
    """
    Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
    https://github.com/huggingface/transformers/issues/4786 and
    https://discuss.huggingface.co/t/bpe-tokenizers-and-spaces-before-words/475/2

    This is the preferred function over tokenizer.convert_ids_to_tokens() for
    user-facing data.

    Quote from links:
    > Spaces are converted in a special character (the Ġ) in the tokenizer prior to
    > BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
    > used spaces in its process

    The ids are decoded to text, the text is re-encoded with offsets, and each
    id is labelled with the slice of text its offsets cover. A token whose span
    starts before the previous span ended is suffixed with " [OVERLAP]". If the
    result does not contain exactly one label per id, the naive per-id decoding
    of ``_convert_ids_to_pretty_tokens_fallback`` is used instead.

    Args:
        ids (Tensor): token ids to convert
        tokenizer (TokenizerLike): tokenizer that produced the ids

    Returns:
        One cleaned-up string per id.
    """
    txt = tokenizer.decode(ids)
    # Don't add special tokens (they're either already there, or we don't want them)
    input_ids, offset_mapping = _encode_with_offsets(
        txt, tokenizer, add_special_tokens=False
    )

    n_ids = len(ids)
    pretty_tokens = []
    end_prev = -1
    idx = 0
    for i, (start, end) in enumerate(offset_mapping):
        if idx >= n_ids:
            # Every original id has been matched already; the re-encoding only
            # produced extra trailing tokens, which have no id to label.
            break
        if input_ids[i] != ids[idx]:
            # When the re-encoded string doesn't match the original encoding we skip
            # this token and hope for the best, falling back on a naive method. This
            # can happen when a tokenizer might add a token that corresponds to
            # a space only when add_special_tokens=False.
            warnings.warn(
                f"(i={i}, idx={idx}) input_ids[i] {input_ids[i]} != ids[idx] "
                f"{ids[idx]} (corresponding to text: {repr(txt[start:end])}). "
                "Skipping this token.",
                stacklevel=2,
            )
            continue
        pretty_tokens.append(
            _clean_up_pretty_token(txt[start:end])
            + (" [OVERLAP]" if end_prev > start else "")
        )
        end_prev = end
        idx += 1
    if len(pretty_tokens) != n_ids:
        warnings.warn(
            f"Pretty tokens length {len(pretty_tokens)} != ids length {len(ids)}! "
            "Falling back to naive decoding logic.",
            stacklevel=2,
        )
        return _convert_ids_to_pretty_tokens_fallback(ids, tokenizer)
    return pretty_tokens


def _convert_ids_to_pretty_tokens_fallback(
    ids: Tensor, tokenizer: TokenizerLike
) -> List[str]:
    """
    Fallback function that naively handles logic when multiple ids map to one
    string.

    Each id is decoded on its own. When two consecutive ids both decode to the
    replacement character and neither round-trips through ``tokenizer.encode``,
    they are treated as two halves of one token: both positions get the text
    decoded from the pair, the second one suffixed with " [OVERLAP]".

    Args:
        ids (Tensor): token ids to convert
        tokenizer (TokenizerLike): tokenizer that produced the ids

    Returns:
        One cleaned-up string per id.
    """
    n_ids = len(ids)
    pretty_tokens = []
    idx = 0
    while idx < n_ids:
        decoded = tokenizer.decode(ids[idx])
        decoded_pretty = _clean_up_pretty_token(decoded)
        # Handle case where single token (e.g. unicode) is split into multiple IDs
        # NOTE: This logic will fail if a tokenizer splits a token into 3+ IDs
        looks_split = decoded.strip() == "�" and tokenizer.encode(decoded) != [
            ids[idx]
        ]
        # A split needs a following id to pair with; a trailing lone id is
        # handled as a normal token instead of reading past the end.
        if looks_split and idx + 1 < n_ids:
            # ID at idx is split, ensure next token is also from a split
            decoded_next = tokenizer.decode(ids[idx + 1])
            if decoded_next.strip() == "�" and tokenizer.encode(decoded_next) != [
                ids[idx + 1]
            ]:
                # Both tokens are from a split: label both with the text of
                # the combined pair rather than the lone replacement character
                combined_pretty = _clean_up_pretty_token(
                    tokenizer.decode(ids[idx : idx + 2])
                )
                pretty_tokens.append(combined_pretty)
                pretty_tokens.append(combined_pretty + " [OVERLAP]")
            else:
                # Treat tokens as separate
                pretty_tokens.append(decoded_pretty)
                pretty_tokens.append(_clean_up_pretty_token(decoded_next))
            idx += 2
        else:
            # Just a normal token
            idx += 1
            pretty_tokens.append(decoded_pretty)
    return pretty_tokens


class BaseLLMAttribution(Attribution, ABC):
    """
    Base class for LLM Attribution methods.

    Subclasses declare which attribution algorithms (``SUPPORTED_METHODS``)
    and which input types (``SUPPORTED_INPUTS``) they accept.
    """

    SUPPORTED_INPUTS: Tuple[Type[InterpretableInput], ...]
    SUPPORTED_METHODS: Tuple[Type[Attribution], ...]

    model: nn.Module
    tokenizer: TokenizerLike
    device: torch.device

    def __init__(
        self,
        attr_method: Attribution,
        tokenizer: TokenizerLike,
    ) -> None:
        """
        Args:
            attr_method (Attribution): instance of one of ``SUPPORTED_METHODS``
                    whose ``forward_func`` is the model
            tokenizer (TokenizerLike): tokenizer of the model
        """
        assert isinstance(
            attr_method, self.SUPPORTED_METHODS
        ), f"{self.__class__.__name__} does not support {type(attr_method)}"

        super().__init__(attr_method.forward_func)

        # alias, we really need a model and don't support wrapper functions
        # coz we need call model.forward, model.generate, etc.
        self.model: nn.Module = cast(nn.Module, self.forward_func)

        self.tokenizer: TokenizerLike = tokenizer
        # prefer the model's own ``device`` attribute, otherwise use the device
        # of its first parameter
        self.device: torch.device = (
            cast(torch.device, self.model.device)
            if hasattr(self.model, "device")
            else next(self.model.parameters()).device
        )

    def _get_target_tokens(
        self,
        inp: InterpretableInput,
        target: Union[str, torch.Tensor, None] = None,
        skip_tokens: Union[List[int], List[str], None] = None,
        gen_args: Optional[Dict[str, Any]] = None,
    ) -> Tensor:
        """
        Resolve the target token ids to attribute towards.

        Args:
            inp (InterpretableInput): input prompt; must be one of
                    ``SUPPORTED_INPUTS``
            target (str or Tensor, optional): target response. When None, the
                    model generates it from ``inp`` and the tokens after the
                    prompt are used.
            skip_tokens (List[int] or List[str], optional): tokens (strings or
                    ids) dropped from a given ``target``. Not applied to a
                    generated target.
            gen_args (dict, optional): arguments for ``model.generate``; only
                    allowed when ``target`` is None. ``DEFAULT_GEN_ARGS`` is
                    used when empty or None.

        Returns:
            Tensor of target token ids.

        Raises:
            TypeError: if ``target`` is neither None, a str nor a Tensor.
        """
        assert isinstance(
            inp, self.SUPPORTED_INPUTS
        ), f"LLMAttribution does not support input type {type(inp)}"

        if target is None:
            # generate when None
            assert hasattr(self.model, "generate") and callable(self.model.generate), (
                "The model does not have recognizable generate function. "
                "Target must be given for attribution"
            )

            if not gen_args:
                gen_args = DEFAULT_GEN_ARGS

            model_inp = self._format_model_input(inp.to_model_input())
            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
            output_tokens = self.model.generate(model_inp, **gen_args)
            # keep only what comes after the prompt in the first sequence
            target_tokens = output_tokens[0][model_inp.size(1) :]
        else:
            assert gen_args is None, "gen_args must be None when target is given"
            # Encode skip tokens
            if skip_tokens:
                if isinstance(skip_tokens[0], str):
                    skip_tokens = cast(List[str], skip_tokens)
                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
            else:
                skip_tokens = []
            skip_tokens = cast(List[int], skip_tokens)

            if isinstance(target, str):
                encoded = self.tokenizer.encode(target)
                target_tokens = torch.tensor(
                    [token for token in encoded if token not in skip_tokens]
                )
            elif isinstance(target, torch.Tensor):
                target_tokens = target[
                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
                ]
            else:
                raise TypeError(
                    "target must either be str or Tensor, but the type of target is "
                    "{}".format(type(target))
                )
        return target_tokens

    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
        """
        Convert str to tokenized tensor
        to make LLMAttribution work with model inputs of both
        raw text and text token tensors

        The result is moved to ``self.device``.
        """
        # return tensor(1, n_tokens)
        if isinstance(model_input, str):
            return self.tokenizer.encode(model_input, return_tensors="pt").to(
                self.device
            )
        return model_input.to(self.device)
class LLMAttribution(BaseLLMAttribution):
    """
    Attribution class for large language models. It wraps a perturbation-based
    attribution algorithm to produce commonly interested attribution
    results for the use case of text generation.
    The wrapped instance will calculate attribution in the
    same way as configured in the original attribution algorithm, but it will provide a
    new "attribute" function which accepts text-based inputs
    and returns LLMAttributionResult
    """

    SUPPORTED_METHODS = (
        FeatureAblation,
        ShapleyValueSampling,
        ShapleyValues,
        Lime,
        KernelShap,
    )
    SUPPORTED_PER_TOKEN_ATTR_METHODS = (
        FeatureAblation,
        ShapleyValueSampling,
        ShapleyValues,
    )
    SUPPORTED_INPUTS = (TextTemplateInput, TextTokenInput)

    def __init__(
        self,
        attr_method: PerturbationAttribution,
        tokenizer: TokenizerLike,
        attr_target: str = "log_prob",  # TODO: support callable attr_target
    ) -> None:
        """
        Args:
            attr_method (Attribution): Instance of a supported perturbation
                    attribution class created with the llm model that follows
                    huggingface style interface convention.
                    Supported methods include FeatureAblation,
                    ShapleyValueSampling, ShapleyValues, Lime, and KernelShap.
                    Lime and KernelShap do not support per-token attribution
                    and will only return attribution for the full target
                    sequence.
            tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
            attr_target (str): attribute towards log probability or probability.
                    Available values ["log_prob", "prob"]
                    Default: "log_prob"
        """

        super().__init__(attr_method, tokenizer)

        # shallow copy is enough to avoid modifying original instance
        self.attr_method: PerturbationAttribution = copy(attr_method)
        self.include_per_token_attr: bool = isinstance(
            attr_method, self.SUPPORTED_PER_TOKEN_ATTR_METHODS
        )

        # the copied method perturbs the interpretable tensor and scores each
        # perturbation through our own forward function
        self.attr_method.forward_func = self._forward_func

        assert attr_target in (
            "log_prob",
            "prob",
        ), "attr_target should be either 'log_prob' or 'prob'"
        self.attr_target = attr_target

    def _forward_func(
        self,
        perturbed_tensor: Union[None, Tensor],
        inp: InterpretableInput,
        target_tokens: Tensor,
        use_cached_outputs: bool = False,
        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
    ) -> Tensor:
        """
        Score the target tokens given a perturbed version of the input.

        The model is run once per target token; after each step the target
        token is appended to the model input before scoring the next one.

        Args:
            perturbed_tensor (Tensor or None): perturbed interpretable tensor
                    handed to ``inp.to_model_input``
            inp (InterpretableInput): the input being attributed
            target_tokens (Tensor): target token ids to score
            use_cached_outputs (bool): when True, each step goes through the
                    model's ``prepare_inputs_for_generation`` /
                    ``_update_model_kwargs_for_generation`` so earlier outputs
                    can be reused; otherwise the full sequence is forwarded
                    with a fresh attention mask at every step
            _inspect_forward (Callable, optional): callback receiving the
                    decoded prompt, the decoded response and the list of
                    probabilities

        Returns:
            Log probabilities when ``attr_target`` is "log_prob", otherwise
            probabilities. With per-token attribution the tensor has a leading
            batch dim of 1 followed by the sequence score and one score per
            target token; without it, it is the scalar sequence score.
        """
        # Lazily import transformers_typing to avoid importing transformers package if
        # it isn't needed
        from captum._utils.transformers_typing import (
            Cache,
            DynamicCache,
            supports_caching,
            update_model_kwargs,
        )

        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
        init_model_inp = perturbed_input

        model_inp = init_model_inp
        attention_mask = torch.ones(
            [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
        )
        model_kwargs = {"attention_mask": attention_mask}

        # If applicable, update model kwargs for transformers models
        update_model_kwargs(
            model_kwargs=model_kwargs,
            model=self.model,
            input_ids=model_inp,
            caching=use_cached_outputs,
        )

        log_prob_list: List[Tensor] = []
        outputs = None
        for target_token in target_tokens:
            if use_cached_outputs:
                if outputs is not None:
                    # If applicable, convert past_key_values to DynamicCache for
                    # transformers models
                    if (
                        Cache is not None
                        and DynamicCache is not None
                        and supports_caching(self.model)
                        and not isinstance(outputs.past_key_values, Cache)
                    ):
                        outputs.past_key_values = DynamicCache.from_legacy_cache(
                            outputs.past_key_values
                        )
                    # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                    model_kwargs = self.model._update_model_kwargs_for_generation(
                        outputs, model_kwargs
                    )
                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                model_inputs = self.model.prepare_inputs_for_generation(
                    model_inp, **model_kwargs
                )
                outputs = self.model.forward(**model_inputs)
            else:
                # Update attention mask to adapt to input size change
                attention_mask = torch.ones(
                    [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
                )
                model_kwargs["attention_mask"] = attention_mask
                outputs = self.model.forward(model_inp, **model_kwargs)
            new_token_logits = outputs.logits[:, -1]
            log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

            log_prob_list.append(log_probs[0][target_token].detach())

            model_inp = torch.cat(
                (model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
            )

        total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
        # 1st element is the total prob, rest are the target tokens
        # add a leading dim for batch even we only support single instance for now
        if self.include_per_token_attr:
            target_log_probs = torch.stack(
                [total_log_prob, *log_prob_list], dim=0
            ).unsqueeze(0)
        else:
            target_log_probs = total_log_prob
        target_probs = torch.exp(target_log_probs)

        if _inspect_forward:
            prompt = self.tokenizer.decode(init_model_inp[0])
            response = self.tokenizer.decode(target_tokens)

            # callback for externals to inspect (prompt, response, seq_prob)
            # Flatten instead of indexing [0]: without per-token attribution
            # target_probs is a 0-dim tensor, which cannot be indexed.
            _inspect_forward(prompt, response, target_probs.reshape(-1).tolist())

        return target_probs if self.attr_target != "log_prob" else target_log_probs

    def attribute(
        self,
        inp: InterpretableInput,
        target: Union[str, torch.Tensor, None] = None,
        skip_tokens: Union[List[int], List[str], None] = None,
        num_trials: int = 1,
        gen_args: Optional[Dict[str, Any]] = None,
        use_cached_outputs: bool = True,
        # internal callback hook can be used for logging
        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
        **kwargs: Any,
    ) -> LLMAttributionResult:
        """
        Args:
            inp (InterpretableInput): input prompt for which attributions are computed
            target (str or Tensor, optional): target response with respect to
                    which attributions are computed. If None, it uses the model
                    to generate the target based on the input and gen_args.
                    Default: None
            skip_tokens (List[int] or List[str], optional): the tokens to skip in the
                    the output's interpretable representation. Use this argument to
                    define uninterested tokens, commonly like special tokens, e.g.,
                    sos, and unk. It can be a list of strings of the tokens or a list
                    of integers of the token ids.
                    Default: None
            num_trials (int, optional): number of trials to run. Return is the average
                    attributions over all the trials. Must be at least 1.
                    Defaults: 1.
            gen_args (dict, optional): arguments for generating the target. Only used if
                    target is not given. When None, the default arguments are used,
                    {"max_new_tokens": 25, "do_sample": False,
                    "temperature": None, "top_p": None}
                    Defaults: None
            use_cached_outputs (bool, optional): forwarded to the internal forward
                    function; when True it steps through the model's generation
                    helpers so outputs of earlier target tokens can be reused.
                    Defaults: True
            _inspect_forward (Callable, optional): internal hook called on every
                    forward with (prompt, response, probabilities).
                    Defaults: None
            **kwargs (Any): any extra keyword arguments passed to the call of the
                    underlying attribute function of the given attribution instance

        Returns:

            attr (LLMAttributionResult): Attribution result. token_attr will be None
                    if attr method is Lime or KernelShap.

        Raises:
            ValueError: if ``num_trials`` is smaller than 1.
        """
        if num_trials < 1:
            # averaging over zero trials would silently divide by zero
            raise ValueError(f"num_trials must be at least 1, got {num_trials}")

        target_tokens = self._get_target_tokens(
            inp,
            target,
            skip_tokens=skip_tokens,
            gen_args=gen_args,
        )

        # one row for the whole sequence, plus one row per target token when
        # the wrapped method supports per-token attribution
        attr = torch.zeros(
            [
                1 + len(target_tokens) if self.include_per_token_attr else 1,
                inp.n_itp_features,
            ],
            dtype=torch.float,
            device=self.device,
        )

        for _ in range(num_trials):
            attr_input = inp.to_tensor().to(self.device)

            cur_attr = self.attr_method.attribute(
                attr_input,
                additional_forward_args=(
                    inp,
                    target_tokens,
                    use_cached_outputs,
                    _inspect_forward,
                ),
                **kwargs,
            )

            # temp necessary due to FA & Shapley's different return shape of multi-task
            # FA will flatten output shape internally (n_output_token, n_itp_features)
            # Shapley will keep output shape (batch, n_output_token, n_input_features)
            cur_attr = cur_attr.reshape(attr.shape)

            attr += cur_attr

        attr = attr / num_trials

        attr = inp.format_attr(attr)

        return LLMAttributionResult(
            attr[0],
            (
                attr[1:] if self.include_per_token_attr else None
            ),  # shape(n_output_token, n_input_features)
            inp.values,
            _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
        )

    def attribute_future(self) -> Callable[[], LLMAttributionResult]:
        r"""
        This method is not implemented for LLMAttribution.
        """
        raise NotImplementedError(
            "attribute_future is not implemented for LLMAttribution"
        )
class LLMGradientAttribution(BaseLLMAttribution):
    """
    Attribution wrapper for large language models built on a gradient-based
    layer attribution algorithm.

    The wrapped algorithm keeps its own configuration; it is run against the
    log probability of each target token in turn, and the per-token results
    are summed into the attribution of the whole sequence. The "attribute"
    function accepts text-based inputs and returns an LLMAttributionResult.
    """

    SUPPORTED_METHODS = (
        LayerGradientShap,
        LayerGradientXActivation,
        LayerIntegratedGradients,
    )
    SUPPORTED_INPUTS = (TextTokenInput,)

    def __init__(
        self,
        attr_method: GradientAttribution,
        tokenizer: TokenizerLike,
    ) -> None:
        """
        Args:
            attr_method (Attribution): instance of one of the supported
                    gradient attribution classes, created with the llm model
                    that follows huggingface style interface convention
            tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
        """
        super().__init__(attr_method, tokenizer)

        # Work on a shallow copy so the caller's instance keeps its own
        # forward function.
        wrapped: GradientAttribution = copy(attr_method)
        self.attr_method: GradientAttribution = wrapped
        self.attr_method.forward_func = GradientForwardFunc(self)

    def attribute(
        self,
        inp: InterpretableInput,
        target: Union[str, torch.Tensor, None] = None,
        skip_tokens: Union[List[int], List[str], None] = None,
        gen_args: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> LLMAttributionResult:
        """
        Args:
            inp (InterpretableInput): input prompt for which attributions are computed
            target (str or Tensor, optional): target response with respect to
                    which attributions are computed. If None, it uses the model
                    to generate the target based on the input and gen_args.
                    Default: None
            skip_tokens (List[int] or List[str], optional): the tokens to skip in the
                    the output's interpretable representation. Use this argument to
                    define uninterested tokens, commonly like special tokens, e.g.,
                    sos, and unk. It can be a list of strings of the tokens or a list
                    of integers of the token ids.
                    Default: None
            gen_args (dict, optional): arguments for generating the target. Only used if
                    target is not given. When None, the default arguments are used,
                    {"max_new_tokens": 25, "do_sample": False,
                    "temperature": None, "top_p": None}
                    Defaults: None
            **kwargs (Any): any extra keyword arguments passed to the call of the
                    underlying attribute function of the given attribution instance

        Returns:

            attr (LLMAttributionResult): attribution result
        """
        target_tokens = self._get_target_tokens(
            inp,
            target,
            skip_tokens=skip_tokens,
            gen_args=gen_args,
        )

        attr_inp = inp.to_tensor().to(self.device)

        per_token_attrs = []
        for step in range(len(target_tokens)):
            # raw result in shape(batch_size, input+output_len, emb_dim)
            raw_attr = self.attr_method.attribute(
                attr_inp,
                additional_forward_args=(
                    inp,
                    target_tokens,
                    step,
                ),
                **kwargs,
            ).detach()
            step_attr = cast(Tensor, raw_attr)

            # From the second target token on, the result also covers the
            # previously appended output tokens; trim them off to get
            # shape(batch_size, inp_len, emb_dim).
            if step > 0:
                step_attr = step_attr[:, :-step]

            # collapse the embedding dim by summation, as the author of IG does
            # https://github.com/ankurtaly/Integrated-Gradients/blob/master/BertModel/bert_model_utils.py#L350
            per_token_attrs.append(step_attr.sum(-1))

        # the input batch is assumed to hold a single instance, so stacking
        # along dim 0 yields shape(n_output_token, ...)
        token_attr = torch.cat(per_token_attrs, dim=0)

        # Gradient methods attribute to every element of the chosen layer's
        # output, regardless of which features are interpretable; when the
        # input restricts the interpretable features, keep only those columns.
        if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
            mask = inp.itp_mask.to(token_attr.device).expand_as(token_attr)
            token_attr = token_attr[mask].view(token_attr.size(0), -1)

        # With log probability as the target, the sequence attribution of all
        # supported gradient methods is the sum over the token attributions,
        # shape(n_input_features)
        seq_attr = token_attr.sum(0)
        pretty_output_tokens = _convert_ids_to_pretty_tokens(
            target_tokens, self.tokenizer
        )

        return LLMAttributionResult(
            seq_attr,
            token_attr,  # shape(n_output_token, n_input_features)
            inp.values,
            pretty_output_tokens,
        )

    def attribute_future(self) -> Callable[[], LLMAttributionResult]:
        r"""
        This method is not implemented for LLMGradientAttribution.
        """
        raise NotImplementedError(
            "attribute_future is not implemented for LLMGradientAttribution"
        )
class GradientForwardFunc(nn.Module):
    """
    Forward-function wrapper used by LLMGradientAttribution: it returns the
    log probability of one target token given the (perturbed) input followed
    by the target tokens that precede it.
    """

    def __init__(self, attr: LLMGradientAttribution) -> None:
        super().__init__()
        self.attr = attr
        self.model: nn.Module = attr.model

    def forward(
        self,
        perturbed_tensor: Tensor,
        inp: InterpretableInput,
        target_tokens: Tensor,  # 1D tensor of target token ids
        cur_target_idx: int,  # current target index
    ) -> Tensor:
        """
        Args:
            perturbed_tensor (Tensor): perturbed input handed to
                    ``inp.to_model_input``
            inp (InterpretableInput): the input being attributed
            target_tokens (Tensor): 1D tensor of target token ids
            cur_target_idx (int): index of the target token to score

        Returns:
            Log probability of ``target_tokens[cur_target_idx]`` at the last
            position, one value per row of the model input.
        """
        model_input = self.attr._format_model_input(
            inp.to_model_input(perturbed_tensor)
        )

        if cur_target_idx:
            # Append the already-scored target tokens; the attribution method
            # may have expanded the batch, so repeat them for every row.
            preceding = target_tokens[:cur_target_idx].unsqueeze(0)
            preceding = preceding.expand(model_input.size(0), -1)
            preceding = preceding.to(self.attr.device)
            model_input = torch.cat([model_input, preceding], dim=1)

        last_logits = self.model(model_input).logits[:, -1]
        log_probs = torch.nn.functional.log_softmax(last_logits, dim=1)

        # the attribution target is limited to the log probability
        return log_probs[..., target_tokens[cur_target_idx]]