Image Segment Attribution for Multi-Modal LLMs with Captum

In this tutorial, we demonstrate how to use Captum's ImageMaskInput to perform image segment saliency analysis on Multi-Modal Large Language Models (MM-LLMs). This allows us to understand which regions of an input image contribute most to the model's generated text response.

What You'll Learn

  • How to use ImageMaskInput to define interpretable image segments
  • Two segmentation approaches:
    • Grid-based segmentation: Simple, uniform division of the image
    • SAM-2 segmentation: Semantically meaningful segments using Segment Anything Model 2
  • How to apply FeatureAblation with LLMAttribution for MM-LLMs
  • How to visualize attribution results as heatmaps

Key Concepts

Feature Ablation works by systematically "masking out" (ablating) different image segments and measuring how the model's output changes. Segments that cause large changes when removed are considered more important for the model's prediction.

ImageMaskInput is Captum's adapter that:

  1. Takes an image and segmentation mask
  2. Converts it to an interpretable representation (presence/absence of each segment)
  3. Handles the perturbation during attribution (replacing masked regions with a baseline color)
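To make the ablation idea concrete, here is a minimal, self-contained sketch in plain NumPy (not Captum's API): it ablates each segment of a toy image in turn and records how much a stand-in scoring function drops. The 4×4 "image", the quadrant mask, and model_score are all invented for illustration; model_score plays the role of the model's output likelihood.

```python
import numpy as np

# Toy stand-in for the model's output score; it only rewards brightness
# in the top-left quadrant of the image.
def model_score(img):
    return float(img[:2, :2].sum())

image = np.ones((4, 4))            # dummy 4x4 "image"
mask = np.zeros((4, 4), dtype=int) # four quadrant segments, IDs 0..3
mask[:2, 2:] = 1
mask[2:, :2] = 2
mask[2:, 2:] = 3

baseline_score = model_score(image)
attributions = {}
for seg_id in np.unique(mask):
    ablated = image.copy()
    ablated[mask == seg_id] = 0.0  # replace the segment with a baseline value
    # Importance = how much the score drops when this segment is removed
    attributions[int(seg_id)] = baseline_score - model_score(ablated)

print(attributions)  # only segment 0 (top-left) matters for this toy score
```

Captum's FeatureAblation automates exactly this loop, and ImageMaskInput supplies the segment-level perturbations for real images.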

In this tutorial, we use Google Gemma-4 as our multi-modal LLM, but the approach works with any vision-language model that follows the standard processor pattern.

Step 1: Setup and Imports

First, let's import the necessary packages. We need:

  • PyTorch and Transformers: For loading and running the MM-LLM model
  • Captum: For attribution algorithms (FeatureAblation, LLMAttribution, ImageMaskInput)
  • PIL and Matplotlib: For image handling and visualization
  • SAM-2: For semantic image segmentation (Segment Anything Model 2)
In [ ]:
import warnings

warnings.filterwarnings("ignore")

from io import BytesIO

import numpy as np
import requests
import torch
import transformers


# Captum imports for attribution
from captum.attr import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._utils.interpretable_input import ImageMaskInput
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Transformers version: {transformers.__version__}")

Step 2: Load the Multi-Modal LLM

We use Google Gemma-4-31B-it, a powerful multi-modal instruction-tuned model. To fit the model in GPU memory, we use 4-bit quantization via BitsAndBytesConfig.

The model can process both text and images, making it ideal for demonstrating image attribution in multi-modal contexts.

In [ ]:
# Model configuration
model_id = "google/gemma-4-31B-it"

# Load the model and processor
print("Loading processor...")
processor = AutoProcessor.from_pretrained(model_id)

# 4-bit quantization keeps the model within GPU memory
from transformers import BitsAndBytesConfig

print("Loading model with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    device_map="auto",
)

print(f"\nModel loaded successfully!")
print(f"Model device: {model.device}")

Step 3: Load a Sample Image

Let's load a sample image that we'll use to demonstrate image attribution. We'll use a classic image of a Volkswagen Beetle car.

In [ ]:
def load_image_from_url(url):
    """Load an image from a URL."""
    response = requests.get(url)
    response.raise_for_status()  # fail early on a bad download
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return image


# Load a sample image from HuggingFace
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image_from_url(image_url)

print(f"Image loaded successfully!")
print(f"Image size: {image.size} (width x height)")

# Display the image
image

Step 4: Generate a Baseline Response

Before performing attribution, let's first see how the model responds to a simple image description prompt. This will help us understand what output we're trying to explain.

In [ ]:
# Define the conversation with image and text
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Briefly describe the image in one sentence."},
        ],
    }
]

# Apply the chat template to format the prompt
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

# Process the inputs (image + text) for the model
inputs = processor(
    text=prompt,
    images=image,
    return_tensors="pt",
).to(model.device)

print("Input prepared successfully!")
print(f"Input token shape: {inputs['input_ids'].shape}")
In [ ]:
# Generate the model's response
print("Generating response...")

with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,  # Greedy decoding for reproducibility
    )

# Decode only the generated tokens (exclude the input)
input_token_len = inputs["input_ids"].shape[1]
response = processor.decode(output[0][input_token_len:], skip_special_tokens=True)

print("\n" + "=" * 60)
print("Model Response:")
print("=" * 60)
print(response)

Step 5: Image Segmentation Methods

To use ImageMaskInput for attribution, we need to define how to segment the image into interpretable features. We'll demonstrate two approaches:

  1. Grid-based segmentation: Simple, uniform division of the image into a grid
  2. SAM-2 segmentation: Semantically meaningful segments using the Segment Anything Model 2

Part A: Grid-Based Segmentation

The simplest approach is to divide the image into a uniform m×n grid. This is fast and requires no external models, but segments may not align with semantic boundaries.

In [ ]:
def create_grid_mask(image, rows, cols):
    """
    Create a grid mask that divides the image into rows x cols segments.
    
    Args:
        image: PIL Image
        rows: Number of rows in the grid
        cols: Number of columns in the grid
    
    Returns:
        Tensor of shape (height, width) with integer segment IDs
    """
    width, height = image.size
    mask = torch.zeros((height, width), dtype=torch.int32)
    
    h_step = height // rows
    w_step = width // cols
    
    for row in range(rows):
        for col in range(cols):
            # Calculate segment boundaries
            y_start = row * h_step
            y_end = height if row == rows - 1 else (row + 1) * h_step
            x_start = col * w_step
            x_end = width if col == cols - 1 else (col + 1) * w_step
            
            # Assign unique segment ID
            segment_id = row * cols + col
            mask[y_start:y_end, x_start:x_end] = segment_id
    
    return mask


# Create a 4x5 grid mask (20 segments)
grid_mask = create_grid_mask(image, rows=4, cols=5)
print(f"Grid mask shape: {grid_mask.shape}")
print(f"Number of segments: {len(torch.unique(grid_mask))}")

Visualizing the Segmentation

ImageMaskInput provides a convenient plot_mask_overlay() method to visualize segmentations. Here we create an ImageMaskInput object with a dummy processor function just to leverage this visualization utility; the actual attribution will use a proper processor function defined later.

Note: The processor_fn parameter is required by ImageMaskInput, but for visualization purposes only, we can use a simple identity function (lambda x: x).

In [ ]:
# Create ImageMaskInput with grid mask for VISUALIZATION ONLY
# We use a dummy processor (identity function) since we just want to preview the segmentation
# The actual ImageMaskInput for attribution will be created later with a proper processor_fn
grid_input_preview = ImageMaskInput(
    image=image,
    mask=grid_mask,
    processor_fn=lambda x: x,  # Dummy processor - just for visualization
)

print("Grid segmentation (4 rows × 5 columns = 20 segments):")
grid_input_preview.plot_mask_overlay(show=True)

Part B: SAM-2 Semantic Segmentation

For more semantically meaningful segments, we can use SAM-2 (Segment Anything Model 2) to automatically detect objects and regions in the image. SAM-2 produces masks that follow object boundaries, making the attribution results more interpretable.

In [ ]:
from transformers import pipeline

# Load SAM-2 for automatic mask generation
sam_generator = pipeline(
    "mask-generation", model="facebook/sam2-hiera-large", device=0
)

print("SAM-2 model loaded successfully!")
In [ ]:
# Generate semantic masks using SAM-2
# points_per_batch controls memory usage during mask generation
sam_outputs = sam_generator(image, points_per_batch=16)

# Extract the list of binary masks
sam_masks = sam_outputs["masks"]

print(f"SAM-2 generated {len(sam_masks)} semantic segments")
print(f"Each mask shape: {sam_masks[0].shape}")
In [ ]:
# Similarly, create ImageMaskInput with SAM-2 masks for VISUALIZATION ONLY
# Notice how SAM-2 segments follow semantic boundaries (car body, windows, wheels, etc.)
sam_input_preview = ImageMaskInput(
    image=image,
    mask_list=sam_masks,
    processor_fn=lambda x: x,  # Dummy processor - just for visualization
)

print(f"SAM-2 segmentation ({len(sam_masks)} semantic segments):")
sam_input_preview.plot_mask_overlay(show=True)

Step 6: Setup Captum Attribution

Now we'll set up Captum's attribution tools:

  • FeatureAblation: The underlying perturbation-based attribution algorithm
  • LLMAttribution: A wrapper that handles the complexities of LLM attribution (sequential token generation, probability aggregation, etc.)
In [ ]:
# Initialize FeatureAblation with the model
fa = FeatureAblation(model)

# Wrap it with LLMAttribution for handling LLM-specific complexities
llm_attr = LLMAttribution(fa, processor.tokenizer)

print("Attribution tools initialized!")

Step 7: Create ImageMaskInput for Attribution

The key component is the processor_fn - a function that converts an image into the model's expected input format. This function is called during attribution for each perturbed version of the image.

In [ ]:
def processor_fn(img):
    """
    Convert an image to model inputs.
    
    This function is called by ImageMaskInput for each perturbed image
    during the attribution process.
    
    Args:
        img: PIL Image (potentially with masked regions)
    
    Returns:
        Model inputs (tokenized text + processed image)
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Briefly describe the image in one sentence."},
            ],
        }
    ]
    
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    
    return processor(
        text=prompt,
        images=img,
        return_tensors="pt",
    ).to(model.device)


print("Processor function defined!")

Step 8: Run Attribution with SAM-2 Segmentation

Now we create an ImageMaskInput using the SAM-2 generated masks and run attribution. The mask_list parameter accepts a list of binary masks, which is exactly what SAM-2 provides.

In [ ]:
# Create ImageMaskInput with SAM-2 masks
sam_input = ImageMaskInput(
    image=image,
    mask_list=sam_masks,  # List of binary masks from SAM-2
    processor_fn=processor_fn,
)

print(f"Created ImageMaskInput with {sam_input.n_itp_features} interpretable features (segments)")

# Run attribution - this will ablate each segment and measure impact on output
print("\nRunning attribution (this may take a few minutes)...")
sam_attr_result = llm_attr.attribute(sam_input, forward_in_tokens=False)

print("Attribution complete!")

Step 9: Visualize Attribution Results

Now that attribution is complete, let's visualize the results to understand which image regions contributed to the model's response.

Token-Level Attribution

The plot_token_attr method shows the attribution for each generated token. This helps us understand which image segments contributed to each word in the model's response:

In [ ]:
# Plot token-level attribution
# This shows which segments are important for each generated token
sam_attr_result.plot_token_attr(show=True)

Image Heatmap Visualization

The most intuitive visualization is the image heatmap, which shows the attribution overlaid directly on the image. Brighter/warmer colors indicate regions that are more important for the model's response:

In [ ]:
# Plot overall image attribution heatmap
# This aggregates attribution across all generated tokens
sam_attr_result.plot_image_heatmap(show=True)
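Under the hood, a heatmap like this is simply the per-segment attribution scores broadcast back onto the pixels of the segment mask. A minimal sketch of that projection, using a hypothetical 2×3 mask and an invented score vector (one value per segment ID):

```python
import numpy as np

# Integer segment mask: each pixel holds its segment's ID.
mask = np.array([[0, 0, 1],
                 [2, 2, 1]])

# Hypothetical per-segment attribution scores (e.g. aggregated over tokens).
scores = np.array([0.1, 0.9, -0.2])

# Fancy indexing broadcasts each segment's score onto its pixels,
# producing a pixel-level heatmap with the same shape as the mask.
heatmap = scores[mask]
print(heatmap)
```

The same idea extends to blending the heatmap over the original image with a colormap, which is what plot_image_heatmap does for you.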

Token-Specific Heatmaps

We can also visualize which image regions contributed to specific tokens in the response. Use target_token_pos to specify which token(s) to focus on.

This is useful for understanding why the model mentioned specific concepts in its description.

In [ ]:
# First, let's see which tokens were generated
print("Generated tokens:")
for i, token in enumerate(sam_attr_result.output_tokens):
    print(f"  Position {i}: '{token}'")
In [ ]:
# Example: Attribution heatmap for specific tokens
# Adjust the token positions based on the output above

# Heatmap for "Volkswagen Beetle" tokens
target_token_pos = (4, 6)
print(
    f'Showing heatmap for tokens at positions {target_token_pos} ["Volkswagen Beetle"]'
)
sam_attr_result.plot_image_heatmap(show=True, target_token_pos=target_token_pos)
In [ ]:
# Heatmap for "two brown wooden doors" tokens (adjust positions as needed)
target_token_pos = (15, 19)
print(
    f'Showing heatmap for tokens at positions {target_token_pos} ["two brown wooden doors"]'
)
sam_attr_result.plot_image_heatmap(show=True, target_token_pos=target_token_pos)
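Conceptually, selecting a token span just slices and sums rows of a (tokens × segments) attribution matrix before projecting the result onto the image. A toy sketch with invented numbers (the matrix and positions are hypothetical, mirroring what target_token_pos selects):

```python
import numpy as np

# Hypothetical attribution matrix: rows = generated tokens, cols = segments.
attr = np.array([
    [0.2, 0.0, 0.1],   # token 0
    [0.5, 0.1, 0.0],   # token 1
    [0.0, 0.4, 0.3],   # token 2
])

# Aggregate attribution for a token span (positions 1..2 inclusive).
start, end = 1, 2
span_attr = attr[start:end + 1].sum(axis=0)
print(span_attr)  # per-segment importance for just the chosen tokens
```

The resulting per-segment vector is then broadcast onto the mask exactly as in the overall heatmap.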

Step 10: Comparison - Grid-Based Attribution

For comparison, let's also run attribution using the simple grid-based segmentation. This is faster but may produce less interpretable results since grid cells don't follow object boundaries.

In [ ]:
# Create ImageMaskInput with grid mask for ATTRIBUTION (with proper processor_fn)
grid_input = ImageMaskInput(
    image=image,
    mask=grid_mask,
    processor_fn=processor_fn,  # Now using the real processor function
)

# Run attribution with grid-based segmentation
print(f"Running grid-based attribution with {grid_input.n_itp_features} segments...")
grid_attr_result = llm_attr.attribute(grid_input, forward_in_tokens=False)

print("Grid attribution complete!")

# Compare: visualize grid-based heatmap
print("\nGrid-based attribution heatmap:")
grid_attr_result.plot_image_heatmap(show=True)

Conclusion

In this tutorial, we demonstrated how to use Captum's ImageMaskInput for image segment saliency analysis on multi-modal LLMs.

Key Takeaways

  1. ImageMaskInput provides a bridge between image segmentation and Captum's attribution algorithms, allowing you to understand which parts of an image influence a model's text output.

  2. Two segmentation approaches:

    • Grid-based: Fast and simple, but segments don't follow semantic boundaries
    • SAM-2: Produces semantically meaningful segments that follow object boundaries, leading to more interpretable results

  3. Visualization options:

    • plot_mask_overlay(): Visualize the segmentation
    • plot_token_attr(): See per-token attribution
    • plot_image_heatmap(): Overlay attribution on the image
    • plot_image_heatmap(target_token_pos=...): Focus on specific tokens

Next Steps

  • Try different segmentation granularities (more/fewer grid cells, different SAM settings)
  • Experiment with different prompts to see how image regions influence different types of responses
  • Apply this technique to your own multi-modal models and use cases

References

  • Captum Documentation
  • LLM Attribution Paper
  • Segment Anything Model 2