Source code for captum.attr._core.llm_attr

# pyre-strict
from copy import copy

from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np

import torch
from captum._utils.typing import TokenizerLike
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
from captum.attr._core.lime import Lime
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import (
    Attribution,
    GradientAttribution,
    PerturbationAttribution,
)
from captum.attr._utils.interpretable_input import (
    InterpretableInput,
    TextTemplateInput,
    TextTokenInput,
)
from torch import nn, Tensor


# Fallback generation arguments: the `attribute` methods in this module use
# this dict when no target is given and the caller passes no (or empty)
# `gen_args`; it is unpacked as keyword arguments into `model.generate`.
DEFAULT_GEN_ARGS = {"max_new_tokens": 25, "do_sample": False}


class LLMAttributionResult:
    """
    Data class for the return result of LLMAttribution,
    which includes the necessary properties of the attribution.
    It also provides utilities to help present and plot the result
    in different forms.
    """

    def __init__(
        self,
        seq_attr: Tensor,
        token_attr: Optional[Tensor],
        input_tokens: List[str],
        output_tokens: List[str],
    ) -> None:
        """
        Args:

            seq_attr (Tensor): attribution of the whole output sequence,
                    one value per entry of input_tokens
            token_attr (Tensor or None): attribution of each output token
                    (rows) to each input token (columns), or None if no
                    token-level attribution was performed
            input_tokens (List[str]): labels of the input features
            output_tokens (List[str]): labels of the output tokens
        """
        self.seq_attr = seq_attr
        self.token_attr = token_attr
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens

    @property
    def seq_attr_dict(self) -> Dict[str, float]:
        """
        Sequence attribution keyed by input token. If the same token string
        appears more than once in input_tokens, the last occurrence wins
        because the result is a plain dict.
        """
        return dict(zip(self.input_tokens, self.seq_attr.cpu().tolist()))

    def plot_token_attr(
        self, show: bool = False
    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
        """
        Generate a matplotlib plot for visualising the attribution
        of the output tokens.

        Args:
            show (bool): whether to show the plot directly or return the figure and axis
                Default: False

        Returns:
            None if show is True, otherwise the tuple (figure, axes)

        Raises:
            ValueError: if token_attr is None
        """

        if self.token_attr is None:
            raise ValueError(
                "token_attr is None (no token-level attribution was performed), please "
                "use plot_seq_attr instead for the sequence-level attribution plot"
            )

        token_attr = self.token_attr.cpu()  # type: ignore

        # maximum absolute attribution value
        # used as the boundary of normalization
        # always keep 0 as the mid point to differentiate pos/neg attr
        max_abs_attr_val = token_attr.abs().max().item()

        fig, ax = plt.subplots()

        # Plot the heatmap
        data = token_attr.numpy()

        # grow the figure with the number of tokens, but never below
        # 6.4 x 4.8 inches
        fig.set_size_inches(
            max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8)
        )
        im = ax.imshow(
            data,
            vmax=max_abs_attr_val,
            vmin=-max_abs_attr_val,
            cmap="RdYlGn",
            aspect="auto",
        )

        # Create colorbar
        cbar = fig.colorbar(im, ax=ax)  # type: ignore
        cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")

        # Show all ticks and label them with the respective list entries.
        ax.set_xticks(np.arange(data.shape[1]), labels=self.input_tokens)
        ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)

        # Let the horizontal axes labeling appear on top.
        ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

        # Loop over the data and create a `Text` for each "pixel".
        # Change the text's color depending on the data.
        for i in range(data.shape[0]):
            for j in range(data.shape[1]):
                val = data[i, j]
                # black text on the light middle of the colormap,
                # white text on the saturated ends
                color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
                im.axes.text(
                    j,
                    i,
                    "%.4f" % val,
                    horizontalalignment="center",
                    verticalalignment="center",
                    color=color,
                )

        if show:
            plt.show()
            return None  # mypy wants this
        else:
            return fig, ax

    def plot_seq_attr(
        self, show: bool = False
    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
        """
        Generate a matplotlib plot for visualising the attribution
        of the output sequence.

        Args:
            show (bool): whether to show the plot directly or return the figure and axis
                Default: False

        Returns:
            None if show is True, otherwise the tuple (figure, axes)
        """

        fig, ax = plt.subplots()

        data = self.seq_attr.cpu().numpy()
        positions = range(data.shape[0])

        ax.set_xticks(positions, labels=self.input_tokens)

        ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

        plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

        # pos bar: negative values are clamped to 0 so only positives show
        ax.bar(positions, np.maximum(data, 0), align="center", color="g")
        # neg bar: positive values are clamped to 0 so only negatives show
        ax.bar(positions, np.minimum(data, 0), align="center", color="r")

        ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom")

        if show:
            plt.show()
            return None  # mypy wants this
        else:
            return fig, ax
class LLMAttribution(Attribution):
    """
    Attribution class for large language models. It wraps a perturbation-based
    attribution algorithm to produce commonly interested attribution
    results for the use case of text generation.
    The wrapped instance will calculate attribution in the
    same way as configured in the original attribution algorithm, but it will provide a
    new "attribute" function which accepts text-based inputs
    and returns LLMAttributionResult
    """

    SUPPORTED_METHODS = (
        FeatureAblation,
        ShapleyValueSampling,
        ShapleyValues,
        Lime,
        KernelShap,
    )
    SUPPORTED_PER_TOKEN_ATTR_METHODS = (
        FeatureAblation,
        ShapleyValueSampling,
        ShapleyValues,
    )
    SUPPORTED_INPUTS = (TextTemplateInput, TextTokenInput)

    def __init__(
        self,
        attr_method: PerturbationAttribution,
        tokenizer: TokenizerLike,
        attr_target: str = "log_prob",  # TODO: support callable attr_target
    ) -> None:
        """
        Args:
            attr_method (Attribution): Instance of a supported perturbation attribution
                    class created with the llm model that follows huggingface style
                    interface convention.
                    Supported methods include FeatureAblation, ShapleyValueSampling,
                    ShapleyValues, Lime, and KernelShap. Lime and KernelShap do not
                    support per-token attribution and will only return attribution
                    for the full target sequence.
            tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
            attr_target (str): attribute towards log probability or probability.
                    Available values ["log_prob", "prob"]
                    Default: "log_prob"
        """

        assert isinstance(
            attr_method, self.SUPPORTED_METHODS
        ), f"LLMAttribution does not support {type(attr_method)}"

        super().__init__(attr_method.forward_func)

        # shallow copy is enough to avoid modifying original instance
        self.attr_method: PerturbationAttribution = copy(attr_method)
        self.include_per_token_attr: bool = isinstance(
            attr_method, self.SUPPORTED_PER_TOKEN_ATTR_METHODS
        )

        # the copied method now calls our forward wrapper instead of the model
        self.attr_method.forward_func = self._forward_func

        # alias, we really need a model and don't support wrapper functions
        # coz we need call model.forward, model.generate, etc.
        self.model: nn.Module = cast(nn.Module, self.forward_func)

        self.tokenizer: TokenizerLike = tokenizer
        self.device: torch.device = (
            cast(torch.device, self.model.device)
            if hasattr(self.model, "device")
            else next(self.model.parameters()).device
        )

        assert attr_target in (
            "log_prob",
            "prob",
        ), "attr_target should be either 'log_prob' or 'prob'"
        self.attr_target = attr_target

    def _forward_func(
        self,
        perturbed_tensor: Union[None, Tensor],
        inp: InterpretableInput,
        target_tokens: Tensor,
        use_cached_outputs: bool = False,
        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
    ) -> Tensor:
        """
        Forward function handed to the wrapped attribution method. It feeds the
        perturbed prompt to the model, then appends the target tokens one by
        one (teacher forcing) and collects the log probability the model
        assigns to each of them.

        Args:
            perturbed_tensor (Tensor or None): perturbed interpretable tensor,
                    converted to the model input through inp.to_model_input
            inp (InterpretableInput): the input being attributed
            target_tokens (Tensor): ids of the target tokens to score
            use_cached_outputs (bool): if True, go through the model's
                    prepare_inputs_for_generation /
                    _update_model_kwargs_for_generation so the previous outputs
                    are reused; otherwise the full sequence is re-fed every step
            _inspect_forward (Callable, optional): callback invoked with
                    (prompt, response, probabilities) for logging

        Returns:
            Tensor: if per-token attribution is supported, shape
            (1, 1 + n_target_tokens) where the first element is for the whole
            sequence; otherwise a scalar tensor for the whole sequence. Values
            are log probabilities when attr_target is "log_prob" and
            probabilities otherwise.
        """
        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
        init_model_inp = perturbed_input

        model_inp = init_model_inp
        attention_mask = torch.ones(
            [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
        )
        model_kwargs = {"attention_mask": attention_mask}

        log_prob_list = []
        outputs = None
        for target_token in target_tokens:
            if use_cached_outputs:
                if outputs is not None:
                    model_kwargs = self.model._update_model_kwargs_for_generation(
                        outputs, model_kwargs
                    )
                model_inputs = self.model.prepare_inputs_for_generation(
                    model_inp, **model_kwargs
                )
                outputs = self.model.forward(**model_inputs)
            else:
                # model_inp grows by one token per iteration, so the mask has to
                # be rebuilt to match its current length; a mask fixed at the
                # prompt length would mismatch from the second target token on
                attention_mask = torch.ones(
                    [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
                )
                outputs = self.model.forward(model_inp, attention_mask=attention_mask)
            new_token_logits = outputs.logits[:, -1]
            log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

            log_prob_list.append(log_probs[0][target_token].detach())

            # teacher forcing: append the target token for the next step
            model_inp = torch.cat(
                (model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
            )

        # pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
        # Tensor
        total_log_prob: Tensor = sum(log_prob_list)  # type: ignore
        # 1st element is the total prob, rest are the target tokens
        # add a leading dim for batch even we only support single instance for now
        if self.include_per_token_attr:
            target_log_probs = torch.stack(
                [total_log_prob, *log_prob_list], dim=0  # type: ignore
            ).unsqueeze(0)
        else:
            target_log_probs = total_log_prob  # type: ignore
        target_probs = torch.exp(target_log_probs)

        if _inspect_forward:
            prompt = self.tokenizer.decode(init_model_inp[0])
            response = self.tokenizer.decode(target_tokens)

            # callback for externals to inspect (prompt, response, seq_prob)
            # flatten instead of indexing [0]: without per-token attribution
            # target_probs is 0-dim and cannot be indexed
            _inspect_forward(prompt, response, target_probs.reshape(-1).tolist())

        return target_probs if self.attr_target != "log_prob" else target_log_probs

    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
        """
        Convert str to tokenized tensor
        to make LLMAttribution work with model inputs of both
        raw text and text token tensors
        """
        # return tensor(1, n_tokens)
        if isinstance(model_input, str):
            # pyre-ignore[9] pyre/mypy thinks return type may be List, but it will be
            # Tensor
            return self.tokenizer.encode(  # type: ignore
                model_input, return_tensors="pt"
            ).to(self.device)
        return model_input.to(self.device)

    def attribute(
        self,
        inp: InterpretableInput,
        target: Union[str, torch.Tensor, None] = None,
        num_trials: int = 1,
        gen_args: Optional[Dict[str, Any]] = None,
        use_cached_outputs: bool = True,
        # internal callback hook can be used for logging
        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
        **kwargs: Any,
    ) -> LLMAttributionResult:
        """
        Args:
            inp (InterpretableInput): input prompt for which attributions are computed
            target (str or Tensor, optional): target response with respect to
                    which attributions are computed. If None, it uses the model
                    to generate the target based on the input and gen_args.
                    Default: None
            num_trials (int, optional): number of trials to run. Return is the average
                    attributions over all the trials.
                    Defaults: 1.
            gen_args (dict, optional): arguments for generating the target.
                    Only used if target is not given. When None, the default arguments
                    are used, {"max_new_tokens": 25, "do_sample": False}
                    Defaults: None
            use_cached_outputs (bool, optional): whether the forward pass reuses
                    the model outputs of the previous target token through the
                    model's generation helpers instead of re-feeding the whole
                    sequence.
                    Defaults: True
            **kwargs (Any): any extra keyword arguments passed to the call of the
                    underlying attribute function of the given attribution instance

        Returns:

            attr (LLMAttributionResult): Attribution result. token_attr will be None
                    if attr method is Lime or KernelShap.

        Raises:
            TypeError: if target is given but is neither str nor Tensor
        """

        assert isinstance(
            inp, self.SUPPORTED_INPUTS
        ), f"LLMAttribution does not support input type {type(inp)}"

        if target is None:
            # generate when None
            assert hasattr(self.model, "generate") and callable(self.model.generate), (
                "The model does not have recognizable generate function."
                "Target must be given for attribution"
            )

            if not gen_args:
                gen_args = DEFAULT_GEN_ARGS

            model_inp = self._format_model_input(inp.to_model_input())
            output_tokens = self.model.generate(model_inp, **gen_args)
            # keep only the newly generated tokens, dropping the prompt
            target_tokens = output_tokens[0][model_inp.size(1) :]
        else:
            assert gen_args is None, "gen_args must be None when target is given"

            if isinstance(target, str):
                # exclude sos
                # NOTE: assumes the tokenizer prepends exactly one special token
                target_tokens = self.tokenizer.encode(target)[1:]
                target_tokens = torch.tensor(target_tokens)
            elif isinstance(target, torch.Tensor):
                target_tokens = target
            else:
                raise TypeError(
                    "target must either be str or Tensor, but the type of target is "
                    "{}".format(type(target))
                )

        # accumulator over trials: row 0 is the whole sequence, the remaining
        # rows (only when per-token attribution is supported) are per token
        attr = torch.zeros(
            [
                1 + len(target_tokens) if self.include_per_token_attr else 1,
                inp.n_itp_features,
            ],
            dtype=torch.float,
            device=self.device,
        )

        for _ in range(num_trials):
            attr_input = inp.to_tensor().to(self.device)

            cur_attr = self.attr_method.attribute(
                attr_input,
                additional_forward_args=(
                    inp,
                    target_tokens,
                    use_cached_outputs,
                    _inspect_forward,
                ),
                **kwargs,
            )

            # temp necessary due to FA & Shapley's different return shape of multi-task
            # FA will flatten output shape internally (n_output_token, n_itp_features)
            # Shapley will keep output shape (batch, n_output_token, n_input_features)
            cur_attr = cur_attr.reshape(attr.shape)

            attr += cur_attr

        attr = attr / num_trials

        attr = inp.format_attr(attr)

        return LLMAttributionResult(
            attr[0],
            (
                attr[1:] if self.include_per_token_attr else None
            ),  # shape(n_output_token, n_input_features)
            inp.values,
            self.tokenizer.convert_ids_to_tokens(target_tokens),
        )

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
    def attribute_future(self) -> Callable:
        r"""
        This method is not implemented for LLMAttribution.
        """
        raise NotImplementedError(
            "attribute_future is not implemented for LLMAttribution"
        )
class LLMGradientAttribution(Attribution):
    """
    Attribution class for large language models. It wraps a gradient-based
    attribution algorithm to produce commonly interested attribution
    results for the use case of text generation.
    The wrapped instance will calculate attribution in the
    same way as configured in the original attribution algorithm,
    with respect to the log probabilities of each
    generated token and the whole sequence. It will provide a
    new "attribute" function which accepts text-based inputs
    and returns LLMAttributionResult
    """

    SUPPORTED_METHODS = (
        LayerGradientShap,
        LayerGradientXActivation,
        LayerIntegratedGradients,
    )
    SUPPORTED_INPUTS = (TextTokenInput,)

    def __init__(
        self,
        attr_method: GradientAttribution,
        tokenizer: TokenizerLike,
    ) -> None:
        """
        Args:
            attr_method (Attribution): instance of a supported gradient attribution
                    class (LayerGradientShap, LayerGradientXActivation or
                    LayerIntegratedGradients) created with the llm model that
                    follows huggingface style interface convention
            tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
        """
        assert isinstance(
            attr_method, self.SUPPORTED_METHODS
        ), f"LLMGradientAttribution does not support {type(attr_method)}"

        super().__init__(attr_method.forward_func)

        # alias, we really need a model and don't support wrapper functions
        # coz we need call model.forward, model.generate, etc.
        self.model: nn.Module = cast(nn.Module, self.forward_func)

        # shallow copy is enough to avoid modifying original instance
        self.attr_method: GradientAttribution = copy(attr_method)
        # the copied method now calls our forward wrapper instead of the model
        self.attr_method.forward_func = GradientForwardFunc(self)

        self.tokenizer: TokenizerLike = tokenizer
        self.device: torch.device = (
            cast(torch.device, self.model.device)
            if hasattr(self.model, "device")
            else next(self.model.parameters()).device
        )

    def _format_model_input(self, model_input: Tensor) -> Tensor:
        """
        Move the token tensor to the device of the model
        """
        return model_input.to(self.device)

    def attribute(
        self,
        inp: InterpretableInput,
        target: Union[str, torch.Tensor, None] = None,
        gen_args: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> LLMAttributionResult:
        """
        Args:
            inp (InterpretableInput): input prompt for which attributions are computed
            target (str or Tensor, optional): target response with respect to
                    which attributions are computed. If None, it uses the model
                    to generate the target based on the input and gen_args.
                    Default: None
            gen_args (dict, optional): arguments for generating the target.
                    Only used if target is not given. When None, the default arguments
                    are used, {"max_new_tokens": 25, "do_sample": False}
                    Defaults: None
            **kwargs (Any): any extra keyword arguments passed to the call of the
                    underlying attribute function of the given attribution instance

        Returns:

            attr (LLMAttributionResult): attribution result

        Raises:
            TypeError: if target is given but is neither str nor Tensor
        """

        assert isinstance(
            inp, self.SUPPORTED_INPUTS
        ), f"LLMGradientAttribution does not support input type {type(inp)}"

        if target is None:
            # generate when None
            assert hasattr(self.model, "generate") and callable(self.model.generate), (
                "The model does not have recognizable generate function."
                "Target must be given for attribution"
            )

            if not gen_args:
                gen_args = DEFAULT_GEN_ARGS

            model_inp = self._format_model_input(inp.to_model_input())
            output_tokens = self.model.generate(model_inp, **gen_args)
            # keep only the newly generated tokens, dropping the prompt
            target_tokens = output_tokens[0][model_inp.size(1) :]
        else:
            assert gen_args is None, "gen_args must be None when target is given"

            if isinstance(target, str):
                # exclude sos
                # NOTE: assumes the tokenizer prepends exactly one special token
                target_tokens = self.tokenizer.encode(target)[1:]
                target_tokens = torch.tensor(target_tokens)
            elif isinstance(target, torch.Tensor):
                target_tokens = target
            else:
                raise TypeError(
                    "target must either be str or Tensor, but the type of target is "
                    "{}".format(type(target))
                )

        attr_inp = inp.to_tensor().to(self.device)

        attr_list = []
        for cur_target_idx, _ in enumerate(target_tokens):
            # attr in shape(batch_size, input+output_len, emb_dim)
            attr = self.attr_method.attribute(
                attr_inp,
                additional_forward_args=(
                    inp,
                    target_tokens,
                    cur_target_idx,
                ),
                **kwargs,
            )
            attr = cast(Tensor, attr)

            # will have the attr for previous output tokens
            # cut to shape(batch_size, inp_len, emb_dim)
            if cur_target_idx:
                attr = attr[:, :-cur_target_idx]

            # the author of IG uses sum
            # https://github.com/ankurtaly/Integrated-Gradients/blob/master/BertModel/bert_model_utils.py#L350
            attr = attr.sum(-1)

            attr_list.append(attr)

        # assume inp batch only has one instance
        # to shape(n_output_token, ...)
        attr = torch.cat(attr_list, dim=0)

        # grad attr method do not care the length of features in interpretable format
        # it attributes to all the elements of the output of the specified layer
        # so we need special handling for the inp type which don't care all the elements
        if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
            itp_mask = inp.itp_mask.to(attr.device)
            itp_mask = itp_mask.expand_as(attr)
            attr = attr[itp_mask].view(attr.size(0), -1)

        # for all the gradient methods we support in this class
        # the seq attr is the sum of all the token attr if the attr_target is log_prob,
        # shape(n_input_features)
        seq_attr = attr.sum(0)

        return LLMAttributionResult(
            seq_attr,
            attr,  # shape(n_output_token, n_input_features)
            inp.values,
            self.tokenizer.convert_ids_to_tokens(target_tokens),
        )

    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
    def attribute_future(self) -> Callable:
        r"""
        This method is not implemented for LLMGradientAttribution.
        """
        raise NotImplementedError(
            "attribute_future is not implemented for LLMGradientAttribution"
        )
class GradientForwardFunc(nn.Module):
    """
    Module used as the forward function of the attribution method wrapped by
    LLMGradientAttribution. For a given target position it returns the log
    probability the model assigns to that target token.
    """

    def __init__(self, attr: LLMGradientAttribution) -> None:
        super().__init__()
        self.attr = attr
        self.model: nn.Module = attr.model

    def forward(
        self,
        perturbed_tensor: Tensor,
        inp: InterpretableInput,
        target_tokens: Tensor,  # 1D tensor of target token ids
        cur_target_idx: int,  # current target index
    ) -> Tensor:
        prompt_ids = self.attr._format_model_input(
            inp.to_model_input(perturbed_tensor)
        )

        model_input_ids = prompt_ids
        if cur_target_idx:
            # target tokens preceding the current one are appended to the
            # prompt; the attr method may have expanded the batch size, so the
            # prefix is broadcast to the same number of rows
            n_rows = prompt_ids.size(0)
            prefix_ids = target_tokens[:cur_target_idx].unsqueeze(0)
            prefix_ids = prefix_ids.expand(n_rows, -1).to(self.attr.device)
            model_input_ids = torch.cat([prompt_ids, prefix_ids], dim=1)

        model_output = self.model(model_input_ids)

        # logits at the last position predict the current target token
        last_position_logits = model_output.logits[:, -1]
        log_probs = torch.nn.functional.log_softmax(last_position_logits, dim=1)

        # the attribution target is limited to the log probability
        return log_probs[..., target_tokens[cur_target_idx]]