In this tutorial, we demonstrate how to use Captum's ImageMaskInput to perform image segment saliency analysis on Multi-Modal Large Language Models (MM-LLMs). This allows us to understand which regions of an input image contribute most to the model's generated text response.
In this tutorial we will use:
- ImageMaskInput to define interpretable image segments
- FeatureAblation with LLMAttribution for MM-LLMs

Feature Ablation works by systematically "masking out" (ablating) different image segments and measuring how the model's output changes. Segments that cause large changes when removed are considered more important for the model's prediction.
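The ablation loop just described can be sketched in a few lines of plain PyTorch. The "model" and the zero baseline below are toy stand-ins purely for illustration; Captum's FeatureAblation automates this loop for real models:

```python
import torch

def toy_model(img):
    # Stand-in "model": its score depends only on the left half of the image.
    return img[:, : img.shape[1] // 2].mean()

img = torch.rand(8, 8)
mask = torch.zeros(8, 8, dtype=torch.int64)
mask[:, 4:] = 1  # two segments: left half (id 0), right half (id 1)

full_score = toy_model(img)
importance = {}
for seg_id in mask.unique().tolist():
    ablated = img.clone()
    ablated[mask == seg_id] = 0.0  # "mask out" this segment with a baseline value
    importance[seg_id] = (full_score - toy_model(ablated)).item()

# Ablating the left half changes the score; ablating the right half does not,
# so segment 0 receives a large importance and segment 1 receives zero.
print(importance)
```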
ImageMaskInput is Captum's adapter that wraps an image together with a segmentation mask, exposing each segment as an interpretable feature that attribution algorithms can perturb.
In this tutorial, we use Google Gemma-4 as our multi-modal LLM, but the approach works with any vision-language model that follows the standard processor pattern.
First, let's import the necessary packages. We need:
- torch, numpy, and transformers for the model
- requests and PIL for loading the image
- Captum's attribution tools (FeatureAblation, LLMAttribution, ImageMaskInput)

import warnings
warnings.filterwarnings("ignore")
from io import BytesIO
import numpy as np
import requests
import torch
import transformers
# Captum imports for attribution
from captum.attr import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._utils.interpretable_input import ImageMaskInput
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Transformers version: {transformers.__version__}")
We use Google Gemma-4-31B-it, a powerful multi-modal instruction-tuned model. To fit the model in GPU memory, we use 4-bit quantization via BitsAndBytesConfig.
The model can process both text and images, making it ideal for demonstrating image attribution in multi-modal contexts.
# Model configuration
model_id = "google/gemma-4-31B-it"
# Load the model and processor
print("Loading processor...")
processor = AutoProcessor.from_pretrained(model_id)
print("Loading model with 4-bit quantization...")
quantization_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print("\nModel loaded successfully!")
print(f"Model device: {model.device}")
Let's load a sample image that we'll use to demonstrate image attribution. We'll use a classic image of a Volkswagen Beetle car.
def load_image_from_url(url):
"""Load an image from a URL."""
response = requests.get(url)
image = Image.open(BytesIO(response.content)).convert("RGB")
return image
# Load a sample image from HuggingFace
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image_from_url(image_url)
print("Image loaded successfully!")
print(f"Image size: {image.size} (width x height)")
# Display the image
image
Before performing attribution, let's first see how the model responds to a simple image description prompt. This will help us understand what output we're trying to explain.
# Define the conversation with image and text
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Briefly describe the image in one sentence."},
],
}
]
# Apply the chat template to format the prompt
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# Process the inputs (image + text) for the model
inputs = processor(
text=prompt,
images=image,
return_tensors="pt",
).to(model.device)
print("Input prepared successfully!")
print(f"Input token shape: {inputs['input_ids'].shape}")
# Generate the model's response
print("Generating response...")
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=512,
do_sample=False, # Greedy decoding for reproducibility
)
# Decode only the generated tokens (exclude the input)
input_token_len = inputs["input_ids"].shape[1]
response = processor.decode(output[0][input_token_len:], skip_special_tokens=True)
print("\n" + "=" * 60)
print("Model Response:")
print("=" * 60)
print(response)
To use ImageMaskInput for attribution, we need to define how to segment the image into interpretable features. We'll demonstrate two approaches:
The simplest approach is to divide the image into a uniform m×n grid. This is fast and requires no external models, but segments may not align with semantic boundaries.
def create_grid_mask(image, rows, cols):
"""
Create a grid mask that divides the image into rows x cols segments.
Args:
image: PIL Image
rows: Number of rows in the grid
cols: Number of columns in the grid
Returns:
Tensor of shape (height, width) with integer segment IDs
"""
width, height = image.size
mask = torch.zeros((height, width), dtype=torch.int32)
h_step = height // rows
w_step = width // cols
for row in range(rows):
for col in range(cols):
# Calculate segment boundaries
y_start = row * h_step
y_end = height if row == rows - 1 else (row + 1) * h_step
x_start = col * w_step
x_end = width if col == cols - 1 else (col + 1) * w_step
# Assign unique segment ID
segment_id = row * cols + col
mask[y_start:y_end, x_start:x_end] = segment_id
return mask
# Create a 4x5 grid mask (20 segments)
grid_mask = create_grid_mask(image, rows=4, cols=5)
print(f"Grid mask shape: {grid_mask.shape}")
print(f"Number of segments: {len(torch.unique(grid_mask))}")
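One detail worth noting in create_grid_mask: when the image dimensions are not evenly divisible by the grid, the last row and column absorb the remainder, so every pixel is still covered. A standalone re-implementation taking raw dimensions (instead of a PIL image) makes this easy to verify:

```python
import torch

def grid_mask(height, width, rows, cols):
    # Same scheme as create_grid_mask above, but takes raw dimensions.
    mask = torch.zeros((height, width), dtype=torch.int32)
    h_step, w_step = height // rows, width // cols
    for row in range(rows):
        for col in range(cols):
            # The last row/column extends to the image edge.
            y_end = height if row == rows - 1 else (row + 1) * h_step
            x_end = width if col == cols - 1 else (col + 1) * w_step
            mask[row * h_step : y_end, col * w_step : x_end] = row * cols + col
    return mask

# 10 and 7 are not divisible by 3 and 2; the mask still has exactly 3 * 2 segments.
m = grid_mask(height=10, width=7, rows=3, cols=2)
print(len(torch.unique(m)))  # 6
```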
ImageMaskInput provides a convenient plot_mask_overlay() method to visualize segmentations. Here we create an ImageMaskInput object with a dummy processor function just to leverage this visualization utility; the actual attribution will use a proper processor function defined later.
Note: The processor_fn parameter is required by ImageMaskInput, but for visualization purposes only, we can use a simple identity function (lambda x: x).
# Create ImageMaskInput with grid mask for VISUALIZATION ONLY
# We use a dummy processor (identity function) since we just want to preview the segmentation
# The actual ImageMaskInput for attribution will be created later with a proper processor_fn
grid_input_preview = ImageMaskInput(
image=image,
mask=grid_mask,
processor_fn=lambda x: x, # Dummy processor - just for visualization
)
print("Grid segmentation (4 rows × 5 columns = 20 segments):")
grid_input_preview.plot_mask_overlay(show=True)
For more semantically meaningful segments, we can use SAM-2 (Segment Anything Model 2) to automatically detect objects and regions in the image. SAM-2 produces masks that follow object boundaries, making the attribution results more interpretable.
from transformers import pipeline
# Load SAM-2 for automatic mask generation
sam_generator = pipeline(
"mask-generation", model="facebook/sam2-hiera-large", device=0
)
print("SAM-2 model loaded successfully!")
# Generate semantic masks using SAM-2
# points_per_batch controls memory usage during mask generation
sam_outputs = sam_generator(image, points_per_batch=16)
# Extract the list of binary masks
sam_masks = sam_outputs["masks"]
print(f"SAM-2 generated {len(sam_masks)} semantic segments")
print(f"Each mask shape: {sam_masks[0].shape}")
# Similarly, create ImageMaskInput with SAM-2 masks for VISUALIZATION ONLY
# Notice how SAM-2 segments follow semantic boundaries (car body, windows, wheels, etc.)
sam_input_preview = ImageMaskInput(
image=image,
mask_list=sam_masks,
processor_fn=lambda x: x, # Dummy processor - just for visualization
)
print(f"SAM-2 segmentation ({len(sam_masks)} semantic segments):")
sam_input_preview.plot_mask_overlay(show=True)
Now we'll set up Captum's attribution tools:
# Initialize FeatureAblation with the model
fa = FeatureAblation(model)
# Wrap it with LLMAttribution for handling LLM-specific complexities
llm_attr = LLMAttribution(fa, processor.tokenizer)
print("Attribution tools initialized!")
The key component is the processor_fn - a function that converts an image into the model's expected input format. This function is called during attribution for each perturbed version of the image.
def processor_fn(img):
"""
Convert an image to model inputs.
This function is called by ImageMaskInput for each perturbed image
during the attribution process.
Args:
img: PIL Image (potentially with masked regions)
Returns:
Model inputs (tokenized text + processed image)
"""
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Briefly describe the image in one sentence."},
],
}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
return processor(
text=prompt,
images=img,
return_tensors="pt",
).to(model.device)
print("Processor function defined!")
Now we create an ImageMaskInput using the SAM-2 generated masks and run attribution. The mask_list parameter accepts a list of binary masks, which is exactly what SAM-2 provides.
# Create ImageMaskInput with SAM-2 masks
sam_input = ImageMaskInput(
image=image,
mask_list=sam_masks, # List of binary masks from SAM-2
processor_fn=processor_fn,
)
print(f"Created ImageMaskInput with {sam_input.n_itp_features} interpretable features (segments)")
# Run attribution - this will ablate each segment and measure impact on output
print("\nRunning attribution (this may take a few minutes)...")
sam_attr_result = llm_attr.attribute(sam_input, forward_in_tokens=False)
print("Attribution complete!")
Now that attribution is complete, let's visualize the results to understand which image regions contributed to the model's response.
The plot_token_attr method shows the attribution for each generated token. This helps us understand which image segments contributed to each word in the model's response:
# Plot token-level attribution
# This shows which segments are important for each generated token
sam_attr_result.plot_token_attr(show=True)
The most intuitive visualization is the image heatmap, which shows the attribution overlaid directly on the image. Brighter/warmer colors indicate regions that are more important for the model's response:
# Plot overall image attribution heatmap
# This aggregates attribution across all generated tokens
sam_attr_result.plot_image_heatmap(show=True)
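Conceptually, a heatmap like this projects each segment's attribution score back onto that segment's pixels via the mask. A tiny sketch with hypothetical scores illustrates the mapping (plot_image_heatmap handles this mapping, the aggregation, and the rendering for you):

```python
import torch

mask = torch.tensor([[0, 0, 1],
                     [2, 2, 1]])             # 3 segments over a tiny 2x3 "image"
seg_scores = torch.tensor([0.1, 0.8, -0.3])  # hypothetical attribution per segment

# Indexing by segment id broadcasts each segment's score to all of its pixels.
heatmap = seg_scores[mask]
print(heatmap)
```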
We can also visualize which image regions contributed to specific tokens in the response. Use target_token_pos to specify which token(s) to focus on.
This is useful for understanding why the model mentioned specific concepts in its description.
# First, let's see which tokens were generated
print("Generated tokens:")
for i, token in enumerate(sam_attr_result.output_tokens):
print(f" Position {i}: '{token}'")
# Example: Attribution heatmap for specific tokens
# Adjust the token positions based on the output above
# Heatmap for "Volkswagen Beetle" tokens
target_token_pos = (4, 6)
print(
f'Showing heatmap for tokens at positions {target_token_pos} ["Volkswagen Beetle"]'
)
sam_attr_result.plot_image_heatmap(show=True, target_token_pos=target_token_pos)
# Heatmap for "two brown wooden doors" tokens (adjust positions as needed)
target_token_pos = (15, 19)
print(
f'Showing heatmap for tokens at positions {target_token_pos} ["two brown wooden doors"]'
)
sam_attr_result.plot_image_heatmap(show=True, target_token_pos=target_token_pos)
For comparison, let's also run attribution using the simple grid-based segmentation. This is faster but may produce less interpretable results since grid cells don't follow object boundaries.
# Create ImageMaskInput with grid mask for ATTRIBUTION (with proper processor_fn)
grid_input = ImageMaskInput(
image=image,
mask=grid_mask,
processor_fn=processor_fn, # Now using the real processor function
)
# Run attribution with grid-based segmentation
print(f"Running grid-based attribution with {grid_input.n_itp_features} segments...")
grid_attr_result = llm_attr.attribute(grid_input, forward_in_tokens=False)
print("Grid attribution complete!")
# Compare: visualize grid-based heatmap
print("\nGrid-based attribution heatmap:")
grid_attr_result.plot_image_heatmap(show=True)
In this tutorial, we demonstrated how to use Captum's ImageMaskInput for image segment saliency analysis on multi-modal LLMs.
ImageMaskInput provides a bridge between image segmentation and Captum's attribution algorithms, allowing you to understand which parts of an image influence a model's text output.
Two segmentation approaches:
- Grid-based: fast and dependency-free, but cells may cut across object boundaries
- SAM-2-based: semantically meaningful segments that follow object boundaries

Visualization options:
- plot_mask_overlay(): Visualize the segmentation
- plot_token_attr(): See per-token attribution
- plot_image_heatmap(): Overlay attribution on the image
- plot_image_heatmap(target_token_pos=...): Focus on specific tokens