Multimodal Neurons in Pretrained Text-Only Transformers

Authors: Sarah Schwettmann, Neil Chowdhury, Antonio Torralba

Abstract: Language models demonstrate remarkable capacity to generalize representations learned in one modality to downstream tasks in other modalities. Can we trace this ability to individual neurons? We study the case where a frozen text transformer is augmented with vision using a self-supervised visual encoder and a single linear projection learned on an image-to-text task. Outputs of the projection layer are not immediately decodable into language describing image content; instead, we find that translation between modalities occurs deeper within the transformer. We introduce a procedure for identifying “multimodal neurons” that convert visual representations into corresponding text, and decoding the concepts they inject into the model’s residual stream. In a series of experiments, we show that multimodal neurons operate on specific visual concepts across inputs, and have a systematic causal effect on image captioning.

What, Why and How

Here is a summary of the key points from the paper:

What:

  • The paper studies multimodal neurons that emerge when a pretrained text-only transformer (GPT-J) is augmented with vision using a self-supervised visual encoder (BEIT) and a learned linear projection layer.

  • They introduce a method to identify “multimodal neurons” in the transformer MLP layers that translate visual representations into corresponding text representations.

  • They decode the concepts that multimodal neurons inject into the model’s residual stream and show they operate on specific visual concepts (e.g. detecting horses).

Why:

  • To understand how text-only transformers generalize well to multimodal tasks when augmented with other modalities like vision.

  • To trace the model’s ability to align representations across modalities to individual neurons.

  • To investigate whether multimodal neurons emerge even without joint vision-language supervision; prior work found them only in CLIP, which was trained on paired image-text data.

How:

  • Use gradient-based attribution to identify important neurons for image captioning.

  • Decode the text concepts associated with high-attribution neurons using the pretrained text decoder.

  • Evaluate correspondence between decoded concepts and image content using CLIPScore and BERTScore.

  • Analyze visual selectivity, ablation effects, and related properties to show that the identified neurons translate visual inputs into related text.

In summary, the key contribution is identifying and analyzing multimodal neurons that emerge in a text-only transformer when adapted to multimodal inputs, providing insights into how these models generalize across modalities. The neurons translate visual inputs into related text concepts and play a causal role in multimodal prediction.

Main Contributions

Here are the key contributions of this paper:

  1. Introduces the concept of “multimodal neurons” in transformer MLP layers that translate visual representations into corresponding text.

  2. Provides a method to identify important multimodal neurons using gradient-based attribution for an image captioning task.

  3. Decodes the text representations associated with multimodal neurons using the pretrained text decoder of the transformer.

  4. Evaluates the correspondence between decoded neuron text and image content using CLIPScore and BERTScore.

  5. Demonstrates that multimodal neurons emerge without joint vision-language supervision, in a text-only transformer adapted with a separate visual encoder.

  6. Analyzes the visual selectivity, interpretability, and causal impact on predictions of identified multimodal neurons through ablation studies and other experiments.

  7. Shows multimodal neurons are visually selective for image regions, activate on specific concepts consistently, and have a systematic effect on image captioning.

  8. Provides insight into how text transformers generalize to multimodal tasks when separate modalities are composed, by tracing this ability to individual neurons that bridge vision and language.

  9. Discusses the broader potential for using text transformers as universal computation engines for non-linguistic sequential modeling tasks.

In summary, the key contributions are identifying and rigorously analyzing multimodal neurons that play an important role in adapting text transformers to multimodal inputs and predictions, while also providing insights into their generalization capabilities.

Method Section

Here is a summary of the method section from the paper:

The paper analyzes multimodal neurons in the LiMBeR model, which uses a single learned linear layer to project BEIT image features into the input embedding space of GPT-J; the projected image patches are fed to the frozen language model as a prompt.
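
As a rough illustration of this setup, the sketch below builds a stand-in projection with assumed dimensions (1024-dimensional BEIT patch features, GPT-J's 4096-dimensional embedding space, a 14x14 patch grid); the real LiMBeR projection is learned on an image-to-text task, whereas here it is randomly initialized.

import torch
import torch.nn as nn

# Assumed dimensions for illustration: BEIT-large patch features (1024-d),
# GPT-J embeddings (4096-d), and a 14x14 grid of image patches.
D_BEIT, D_GPTJ, N_PATCHES = 1024, 4096, 196

# Stand-in for the learned LiMBeR projection (randomly initialized here).
projection_layer = nn.Linear(D_BEIT, D_GPTJ)

image_patches = torch.randn(1, N_PATCHES, D_BEIT)   # placeholder BEIT features
soft_prompt = projection_layer(image_patches)        # shape (1, 196, 4096)

# The projected patches are passed to the frozen GPT-J as ordinary input
# embeddings, prepended to any text embeddings that follow.
print(soft_prompt.shape)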

To identify important neurons, they compute gradient-based attribution scores that approximate each neuron’s contribution to an image captioning task. Specifically, a neuron’s attribution score is the product of its activation and the gradient of the logit for the predicted next token with respect to that activation.
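
A minimal sketch of this activation-times-gradient score for one MLP layer is shown below, assuming a Hugging Face-style causal LM whose MLP hidden activations can be hooked; mlp_act_module, inputs_embeds, and target_token_id are placeholder names, and summing over sequence positions is an illustrative aggregation choice rather than necessarily the paper's.

import torch

def mlp_neuron_attributions(model, mlp_act_module, inputs_embeds, target_token_id):
    # Score each hidden unit in one MLP layer by activation x gradient with
    # respect to the logit of the predicted next token.
    inputs_embeds.requires_grad_(True)   # ensure gradients flow even if the model is frozen
    cache = {}

    def hook(module, inputs, output):
        output.retain_grad()             # keep gradients on the hidden activations
        cache["acts"] = output

    handle = mlp_act_module.register_forward_hook(hook)
    logits = model(inputs_embeds=inputs_embeds).logits
    handle.remove()

    # Gradient of the target token's logit at the last position.
    logits[0, -1, target_token_id].backward()

    acts = cache["acts"]                        # (1, seq_len, num_units)
    scores = (acts * acts.grad).sum(dim=1)      # aggregate over positions
    return scores.squeeze(0).detach()           # one score per hidden unit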

To decode the concept associated with a neuron, they apply GPT-J’s pretrained text decoder (its output embedding) to the neuron’s output weights, mapping the vector the neuron writes into the residual stream onto a probability distribution over vocabulary tokens.
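
The sketch below shows one way to perform this decoding under the Hugging Face GPT-J implementation, where fc_out holds the MLP output weights and lm_head is the output embedding; the layer and neuron indices are arbitrary placeholders, and loading GPT-J requires substantial memory.

import torch
from transformers import AutoTokenizer, GPTJForCausalLM

LAYER, NEURON = 10, 1234   # placeholder indices for illustration

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

# Column NEURON of fc_out.weight is the vector this neuron writes into the
# residual stream; projecting it through the output embedding gives the
# vocabulary tokens the neuron promotes.
write_vector = model.transformer.h[LAYER].mlp.fc_out.weight[:, NEURON]
token_scores = model.lm_head.weight @ write_vector       # (vocab_size,)

top = torch.topk(token_scores, k=10)
print(tokenizer.convert_ids_to_tokens(top.indices.tolist()))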

They select the neurons with the top attribution scores for each image, filter for interpretability (the decoded tokens form valid words), and evaluate correspondence to image content using CLIPScore and BERTScore.
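
A hedged sketch of the two scoring steps is given below, assuming the transformers CLIP checkpoint openai/clip-vit-base-patch32 and the bert-score package; decoded_text, image, and reference_caption are placeholders, and the 2.5 scaling follows the standard CLIPScore definition.

import torch
from transformers import CLIPModel, CLIPProcessor
from bert_score import score as bert_score

clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clipscore(decoded_text, image):
    # Scaled cosine similarity between CLIP embeddings of the decoded
    # tokens and the image (CLIPScore = 2.5 * max(cos, 0)).
    inputs = processor(text=[decoded_text], images=image,
                       return_tensors="pt", padding=True)
    with torch.no_grad():
        out = clip(**inputs)
    txt = out.text_embeds / out.text_embeds.norm(dim=-1, keepdim=True)
    img = out.image_embeds / out.image_embeds.norm(dim=-1, keepdim=True)
    return 2.5 * torch.clamp((txt * img).sum(), min=0.0).item()

def bertscore_f1(decoded_text, reference_caption):
    # BERTScore F1 between a neuron's decoded tokens and a reference caption.
    _, _, f1 = bert_score([decoded_text], [reference_caption], lang="en")
    return f1.item()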

For analyzing visual selectivity, they compute activations over image patches to simulate neuron receptive fields. They compare these to COCO object segmentations using Intersection over Union (IoU).
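
Below is a small sketch of the IoU comparison, assuming patch_activations is a 14x14 array of one neuron's activations over the image patches and object_mask is a binary COCO segmentation; the top-quartile threshold is an illustrative choice rather than the paper's exact rule.

import numpy as np
from PIL import Image

def receptive_field_iou(patch_activations, object_mask, image_size=224):
    # Upsample the 14x14 activation grid to image resolution.
    rf = Image.fromarray(patch_activations.astype(np.float32), mode="F")
    rf = np.array(rf.resize((image_size, image_size), Image.BILINEAR))

    # Binarize the "receptive field" (here: top quartile of activations).
    rf_mask = rf >= np.percentile(rf, 75)

    # Resize the COCO object mask to the same resolution.
    seg = Image.fromarray(object_mask.astype(np.uint8) * 255)
    seg = np.array(seg.resize((image_size, image_size), Image.NEAREST)) > 0

    intersection = np.logical_and(rf_mask, seg).sum()
    union = np.logical_or(rf_mask, seg).sum()
    return float(intersection) / union if union else 0.0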

They also evaluate category-specific selectivity by testing whether neurons that receive high attribution scores on training images of a category also activate strongly on ImageNet validation images of that category.
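
One simple way to probe this is sketched below, under the assumption that get_mean_activation(neuron, image) returns a neuron's average activation over an image's patches and class_to_images maps ImageNet class names to lists of validation images; the paper's exact selectivity statistic may differ.

import numpy as np

def class_selectivity(neuron, class_to_images, get_mean_activation):
    # Average a neuron's activation over validation images of each class; a
    # category-selective neuron should peak on the class it was attributed to.
    per_class = {
        label: float(np.mean([get_mean_activation(neuron, img) for img in images]))
        for label, images in class_to_images.items()
    }
    preferred = max(per_class, key=per_class.get)
    return preferred, per_class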

To test for a causal effect, they ablate the top-scoring neurons by zeroing their activations and measure both the change in the target token’s probability and the change in generated captions, the latter quantified with BERTScore.
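
A sketch of this zero-ablation with forward hooks is shown below, assuming the Hugging Face GPT-J implementation, where transformer.h[layer].mlp.act outputs the MLP hidden activations; neurons is a list of (layer, unit) pairs, and the surrounding variable names are placeholders.

import torch

def ablate_neurons(model, neurons):
    # Register forward hooks that zero the selected MLP hidden units.
    # Returns the hook handles so the ablation can be undone later.
    units_by_layer = {}
    for layer, unit in neurons:
        units_by_layer.setdefault(layer, []).append(unit)

    handles = []
    for layer, units in units_by_layer.items():
        def zero_units(module, inputs, output, units=units):
            output = output.clone()
            output[..., units] = 0.0        # zero the ablated hidden units
            return output
        act_module = model.transformer.h[layer].mlp.act
        handles.append(act_module.register_forward_hook(zero_units))
    return handles

# Usage: compare next-token probabilities and captions with and without the
# ablation, e.g.
#   handles = ablate_neurons(model, top_neurons)
#   ablated_logits = model(inputs_embeds=soft_prompt).logits
#   for h in handles:
#       h.remove()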

In summary, the key aspects of the method are computing attributions to identify important multimodal neurons, decoding their associated text, and evaluating visual and causal effects through multiple experiments. This enables analyzing the emergence and role of multimodal neurons in adapting text-only transformers.

High-Level Pseudo Code

Here is the high-level pseudo code for the method described in the paper:

# Load pretrained models
text_model = GPTJ()                # frozen pretrained language model
image_model = BEIT()               # self-supervised visual encoder
projection_layer = LinearLayer()   # learned image-to-text projection
 
# Get image features
image_patches = image_model(image)
 
# Project to text model input space 
text_inputs = projection_layer(image_patches)
 
# Generate caption
caption = text_model(text_inputs) 
 
# Get target token for attribution
target_token = get_target_token(caption)
 
# Compute attributions 
attributions = compute_attributions(text_model, target_token, text_inputs)
 
# Extract high-scoring neurons
top_neurons = get_top_k_neurons(attributions, k=100)
 
# Filter for interpretability
top_neurons = filter_interpretable(top_neurons, text_model)
 
# Decode text concepts 
text_concepts = decode_neurons(top_neurons, text_model)
 
# Evaluate correspondence
image_score = score_concepts(text_concepts, image) 
 
# Evaluate visual selectivity
receptive_fields = get_receptive_fields(top_neurons, image_patches)
iou = compute_iou(receptive_fields, annotations)
 
# Evaluate category selectivity
category_activations = get_activations(top_neurons, category_images)
 
# Ablation study
ablated_prob = ablate_and_predict(top_neurons, text_model, text_inputs)
ablated_caption = ablate_and_generate(top_neurons, text_model, text_inputs)

This covers the overall pipeline of identifying multimodal neurons, decoding and evaluating them, and analyzing their visual and causal effects as described in the paper. The key steps are computing attributions, decoding text concepts, evaluating correspondence, and analyzing selectivity and impacts through ablations.

Detailed Pseudo Code

Here is some more detailed pseudo code to implement the key steps:

# Compute attributions (activation x gradient)
def compute_attributions(model, target_token, inputs):
  # `inputs` stands for the activations being scored; in the paper these are
  # the MLP hidden activations of each layer, captured with forward hooks.
  inputs.requires_grad_(True)

  # Forward pass
  outputs = model(inputs)

  # Logit of the target token at the final sequence position
  logit = outputs[-1, target_token]

  # Backward pass
  logit.backward()

  # Gradients w.r.t. the activations
  gradients = inputs.grad

  # Element-wise activation x gradient
  attributions = inputs * gradients

  return attributions
 
# Extract the top k neurons across all layers
def get_top_k_neurons(attributions, k=100):

  # Reshape attributions into (num_layers, num_units)
  attributions = attributions.reshape(num_layers, num_units)

  # Top k scores over all (layer, unit) pairs
  top_k = torch.topk(attributions.flatten(), k=k)

  # Recover (layer, unit) indices for each selected neuron
  layers = top_k.indices // num_units
  units = top_k.indices % num_units

  return list(zip(layers.tolist(), units.tolist()))
 
# Filter for interpretable units (decoded tokens form valid words)
def filter_interpretable(units, model):

  # Decode each neuron's top tokens with the pretrained text decoder
  interpretations = decode_neurons(units, model)

  # Count valid English words among each neuron's decoded tokens
  word_counts = [count_valid_words(interp) for interp in interpretations]

  # Keep neurons whose decoded tokens contain at least `threshold` valid words
  threshold = 7
  return [u for u, count in zip(units, word_counts) if count >= threshold]
 
# Ablation study: zero the selected neurons and re-run the model
def ablate_and_predict(units, model, inputs):

  # Zero out the chosen MLP hidden units (in practice this is done with
  # forward hooks during the forward pass)
  acts = model.activations(inputs)
  for layer, unit in units:
    acts[layer][..., unit] = 0

  # Rerun the forward pass with the ablated activations
  outputs = model(inputs, activations=acts)

  # Next-token probabilities
  probs = torch.softmax(outputs, dim=-1)

  return probs
 
# Generate a new caption with the ablation in place
def ablate_and_generate(units, model, inputs):

  # Next-token distribution under ablation
  probs = ablate_and_predict(units, model, inputs)

  # Decode a full caption with the ablation still applied; the result is
  # compared to the original caption using BERTScore
  caption = sample_caption(probs)

  return caption

This provides more implementation detail for the key analysis steps: computing attributions, extracting and filtering important neurons, and performing ablation studies to measure causal impact. The main data structures are the projected image embeddings, the MLP hidden activations, and the next-token probability distributions over which the method operates.