Interpreting CLIP’s Image Representation via Text-Based Decomposition
Authors: Yossi Gandelsman, Alexei A. Efros, Jacob Steinhardt
Abstract: We investigate the CLIP image encoder by analyzing how individual model components affect the final representation. We decompose the image representation as a sum across individual image patches, model layers, and attention heads, and use CLIP’s text representation to interpret the summands. Interpreting the attention heads, we characterize each head’s role by automatically finding text representations that span its output space, which reveals property-specific roles for many heads (e.g. location or shape). Next, interpreting the image patches, we uncover an emergent spatial localization within CLIP. Finally, we use this understanding to remove spurious features from CLIP and to create a strong zero-shot image segmenter. Our results indicate that a scalable understanding of transformer models is attainable and can be used to repair and improve models.
What, Why and How
Here is a summary of the key points from this paper:
What:
- The paper investigates the image encoder of CLIP (Contrastive Language-Image Pre-training) by analyzing how individual components of the model affect the final image representation.
- It focuses on the variant of CLIP that uses a Vision Transformer (ViT) as the image encoder.
- The authors propose methods to decompose the image representation into contributions from individual image patches, model layers, and attention heads.
Why:
- Understanding what information is encoded in CLIP’s image representations can help improve the model and enable new applications.
- Decomposing the representation allows attributing parts of it to specific model components, revealing their roles.
How:
- The residual connections in ViT allow decomposing the output into direct contributions of each layer; experiments show only the last layers contribute significantly (see the schematic equation after this list).
- An algorithm called TextSpan is proposed to find text descriptions that span the output space of each attention head. This reveals that some heads capture specific properties like color or shape.
- The attention mechanism allows decomposing the output by image patches, revealing an emergent spatial localization in CLIP’s representation.
- These decompositions are used to reduce spurious correlations in CLIP and improve its zero-shot segmentation capability.
In summary, this work provides methods to interpret CLIP’s image representation by decomposing it into interpretable elements attributed to model components and image regions. This reveals new structure that enables improving CLIP’s capabilities.
Main Contributions
Based on my reading, the main contributions of this paper are:
- Proposing methods to decompose CLIP’s image representation into contributions of individual model components like layers, heads, and image patches.
- Discovering through these decompositions that:
  - Only the last few transformer layers contribute directly to the output.
  - Some attention heads capture specific semantic properties like colors, shapes, etc.
  - There is an emergent spatial localization in the representation across image patches.
- Using these discoveries to:
  - Label directions in attention heads with descriptive text using the TextSpan algorithm.
  - Reduce reliance on spurious correlations in classification (a sketch of this use case follows at the end of this section).
  - Improve CLIP’s zero-shot segmentation over prior methods.
- Demonstrating that decomposing and interpreting model representations can reveal useful structure and enable improving model capabilities.
- Providing an analysis framework and insights into what is encoded in CLIP’s image representations.
So in summary, the key contributions are proposing representation decomposition techniques, applying them to discover interpretable structure in CLIP, and using these discoveries to improve CLIP’s performance on downstream tasks. The analysis provides new understanding of what CLIP learns and how decomposing representations can be useful more broadly.
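To make the spurious-correlation use case concrete, here is a minimal sketch, not the paper’s exact procedure: given the per-head direct contributions to the class-token representation, replace the contributions of heads whose TextSpan labels relate to the spurious property (e.g., location-related heads) with their dataset means. The helper structures, indices, and shapes below are hypothetical stand-ins.

import numpy as np

def debiased_representation(contribs, spurious_heads, mean_contribs):
    # contribs: {(layer, head): contribution vector} for one image's class token.
    # spurious_heads: heads whose TextSpan labels match the spurious property.
    # mean_contribs: the same dict averaged over a dataset (used for mean-ablation).
    rep = sum(contribs.values())
    for lh in spurious_heads:
        rep = rep - contribs[lh] + mean_contribs[lh]  # swap the head's output for its mean
    return rep

# Toy call with random stand-ins for real CLIP head contributions.
heads = [(23, h) for h in range(4)]  # hypothetical (layer, head) indices
contribs = {lh: np.random.randn(768) for lh in heads}
mean_contribs = {lh: np.zeros(768) for lh in heads}
rep = debiased_representation(contribs, spurious_heads=[(23, 1)], mean_contribs=mean_contribs)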
Method Section
Here is a summary of the methods section from the paper:
Decomposing into Layers:
- The residual connections in ViT allow expressing the output as a sum of contributions from each layer.
- They decompose the image representation into direct contributions of the MLP blocks, the multi-head self-attention (MSA) blocks, and the input patch embeddings.
- Ablation experiments show that only the last 4 MSA layers contribute significantly (a sketch of this kind of ablation appears below).
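The following is a minimal sketch of such an ablation (names and shapes are illustrative, not the paper’s code): replace the direct contributions of all but the last k MSA layers with dataset means and keep only what remains.

import numpy as np

def ablate_early_msas(layer_contribs, mean_contribs, keep_last_k):
    # layer_contribs, mean_contribs: [num_layers, d] arrays of per-layer direct MSA
    # contributions for one image and their dataset means.
    num_layers = layer_contribs.shape[0]
    out = layer_contribs.copy()
    out[: num_layers - keep_last_k] = mean_contribs[: num_layers - keep_last_k]
    return out.sum(axis=0)  # the MSA part of the image representation after ablation

# Toy example with ViT-L-like shapes (24 layers, 768-dim output).
contribs = np.random.randn(24, 768)
means = np.zeros((24, 768))
rep_last4 = ablate_early_msas(contribs, means, keep_last_k=4)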
Decomposing into Attention Heads:
- The MSA output can be decomposed as a sum over contributions of each attention head.
- They propose an algorithm called TextSpan that greedily selects, from a large pool of sentences, the text descriptions whose projections explain the most variance of each head’s outputs (the greedy step is written out below).
- This yields interpretable text labels for directions in each head’s output space.
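In schematic form (simplified notation; c_i are the head’s outputs over a set of images and r_j the candidate text representations projected onto the head’s output space), each greedy step picks
\[
j^{*} \;=\; \arg\max_{j}\; \mathrm{Var}_{i}\big(\langle c_i, r_j \rangle\big),
\]
and then projects the chosen direction r_{j*} out of both the head outputs and the remaining candidates before the next step.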
Decomposing into Image Patches:
- The attention mechanism allows decomposing the output as a sum over image patch contributions.
- This spatial decomposition highlights which image regions contribute most to the similarity with a given text description.
- It enables zero-shot segmentation by thresholding the similarity heatmap between per-patch contributions and a class description (a minimal sketch follows below).
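A minimal sketch of that thresholding step, assuming a per-patch similarity heatmap like the one returned by the pseudocode later in this summary (grid and image sizes are illustrative):

import numpy as np

def segment_from_heatmap(heatmap, patch_grid, image_size, threshold=0.0):
    # heatmap: 1D array of per-patch similarities to the class description.
    # patch_grid: (rows, cols) of ViT patches; image_size: (H, W) of the input.
    grid = np.asarray(heatmap, dtype=float).reshape(patch_grid)
    grid = (grid - grid.mean()) / (grid.std() + 1e-8)  # normalize before thresholding
    mask = grid > threshold
    # Nearest-neighbor upsample each patch cell to pixel resolution.
    scale_r, scale_c = image_size[0] // patch_grid[0], image_size[1] // patch_grid[1]
    return np.repeat(np.repeat(mask, scale_r, axis=0), scale_c, axis=1)

# Toy example: 16x16 patch grid for a 224x224 image (ViT-L/14-like shapes).
mask = segment_from_heatmap(np.random.randn(256), (16, 16), (224, 224))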
Joint Decomposition:
- They combine the head and patch decompositions to attribute the contribution along each head’s text directions back to input image regions.
- This confirms that the highlighted image regions match the semantics of the text labels found for each head.
In summary, the core method is decomposing CLIP’s representation in multiple ways using its architectural properties - into layers, heads, and patches. This reveals interpretable structure that is analyzed to understand and improve CLIP.
High-Level Pseudo Code
Here is some high-level pseudocode for the key algorithms in the paper:
# Decompose image representation into layers:
# the residual stream makes the final representation a sum of per-layer updates.
image_rep = input_embedding
for layer in transformer_layers:
    image_rep += layer(image_rep)  # layer() returns the residual update (MSA + MLP terms)

# Decompose the MSA output into heads:
# the class-token output of an MSA block is a sum of per-head contributions.
msa_output = 0
for head in msa_heads:
    attn_out = attend(query, key, value, head)  # this head's attention output
    msa_output += attn_out[cls_token]           # direct contribution of this head

# TextSpan algorithm (high level)
def textspan(head_outputs, text_pool, num_basis):
    basis = []
    text_projs = project_to_head_space(text_pool)  # text embeddings projected onto the head's output space
    for i in range(num_basis):
        sims = dot_product(head_outputs, text_projs)  # [num_images x num_texts]
        text_index = argmax(variance(sims, axis=0))   # text explaining the most variance across images
        basis.append(text_pool[text_index])
        # Remove the chosen direction from both the outputs and the remaining candidates.
        head_outputs = remove_projection(head_outputs, text_projs[text_index])
        text_projs = remove_projection(text_projs, text_projs[text_index])
    return basis

# Decompose the class-token output into image patches
image_rep = 0
for patch in image_patches:
    attn_weight = attention(cls_token, patch)  # attention weight from the class token to this patch
    patch_rep = value_embed(patch)             # the patch's value (times the output projection)
    image_rep += attn_weight * patch_rep

# Joint decomposition: attribute each head's text directions back to image patches
for head in heads:
    for patch in patches:
        patch_head_sim = dot_product(head_contribution(patch, head), text_embed)
        visualize(patch, patch_head_sim)
This shows the core ideas for decomposing the representation into layers, heads, patches, and jointly. TextSpan finds a textual basis for each head’s output space. The joint decomposition attributes contributions for basis vectors back to input patches.
Detailed Pseudo Code
Here is some more detailed pseudocode to implement the key algorithms from the paper:
import numpy as np

# TextSpan: find a textual basis for one attention head's output space.
# head_outputs: [num_images, d] matrix of the head's direct contributions (class token).
# text_pool: list of candidate descriptions; text_embeds: [num_texts, d] CLIP text embeddings.
def textspan(head_outputs, text_pool, text_embeds, num_basis):
    basis = []
    # Project the text embeddings onto the span of the head outputs so that
    # candidate directions live inside the head's output space.
    U, _, _ = np.linalg.svd(head_outputs.T, full_matrices=False)  # orthonormal basis of that span
    text_projs = text_embeds @ U @ U.T                            # [num_texts, d]
    residuals = head_outputs.copy()
    for _ in range(num_basis):
        # Similarity of every image's residual output to every candidate direction.
        sims = residuals @ text_projs.T                           # [num_images, num_texts]
        # Find the text whose direction explains the most variance across images.
        text_index = int(np.argmax(sims.var(axis=0)))
        basis_text = text_pool[text_index]
        basis_vec = text_projs[text_index]
        basis_vec = basis_vec / (np.linalg.norm(basis_vec) + 1e-8)
        # Update residuals: remove the chosen direction from the residuals and the candidates.
        residuals -= np.outer(residuals @ basis_vec, basis_vec)
        text_projs -= np.outer(text_projs @ basis_vec, basis_vec)
        # Add to basis.
        basis.append((basis_text, basis_vec))
    return basis
# Image patch decomposition (pseudocode; embed_text, patchify, attention_weights,
# and value_project stand in for the corresponding CLIP encoder internals)
def patch_decomposition(image, class_text):
    # Embed the class description with CLIP's text encoder.
    text_embed = embed_text(class_text)
    # Run the image through the encoder and collect per-patch contributions
    # to the class token at the late attention layers.
    patch_embeds = patchify(image)                             # split the image into ViT patches
    attn_weights = attention_weights(cls_token, patch_embeds)  # class-token attention to each patch
    patch_outs = [w * value_project(p) for w, p in zip(attn_weights, patch_embeds)]
    # Similarity of each patch's contribution to the text embedding gives a heatmap.
    heatmap = []
    for i, patch_out in enumerate(patch_outs):
        heatmap.append((i, dot_product(patch_out, text_embed)))
    return heatmap
# Joint decomposition: attribute a head's basis directions back to image patches
def joint_decomposition(image, head, basis):
    # head_patch_contributions stands for the per-patch contributions of this head
    # to the class token (the same quantities summed in patch_decomposition above).
    patch_contribs = head_patch_contributions(image, head)
    patch_sim_maps = []
    for text, basis_vec in basis:
        sim_map = [(i, dot_product(c, basis_vec)) for i, c in enumerate(patch_contribs)]
        patch_sim_maps.append((text, sim_map))
    return patch_sim_maps
This shows more implementation detail: projecting the text pool onto a head’s output space, computing similarities, and tracking the residuals when finding the textual basis vectors with TextSpan. The joint decomposition then projects each patch’s per-head contribution onto every basis direction, producing one similarity map per text label.
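As a quick illustration of how the textspan function above might be called, here is a toy example with random stand-ins for the real CLIP activations and text embeddings (shapes are illustrative):

import numpy as np

# Toy usage of the textspan function defined above.
num_images, num_texts, d = 512, 1000, 64
head_outputs = np.random.randn(num_images, d)  # one head's direct contributions over a set of images
text_pool = [f"description {j}" for j in range(num_texts)]
text_embeds = np.random.randn(num_texts, d)    # CLIP text embeddings of the candidate pool

basis = textspan(head_outputs, text_pool, text_embeds, num_basis=5)
for text, _ in basis:
    print(text)  # the descriptions that best span this head's output space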