Seeing in Words: Learning to Classify through Language Bottlenecks

Authors: Khalid Saifullah, Yuxin Wen, Jonas Geiping, Micah Goldblum, Tom Goldstein

Abstract: Neural networks for computer vision extract uninterpretable features despite achieving high accuracy on benchmarks. In contrast, humans can explain their predictions using succinct and intuitive descriptions. To incorporate explainability into neural networks, we train a vision model whose feature representations are text. We show that such a model can effectively classify ImageNet images, and we discuss the challenges we encountered when training it.

What, Why and How

Here is a summary of the key points from the paper:

What:

  • The paper trains an image classification model that uses text as an intermediate representation instead of continuous features.
  • It inserts a “language bottleneck” into a standard image classification pipeline by generating descriptive text tokens from the image and using those tokens to predict the class.

Why:

  • To make image classification models more interpretable by forcing them to generate explanations through text.
  • Text representations can provide insights into how models learn and reason about image contents.

How:

  • Uses a pretrained image-text model (BLIP) to generate text tokens from images. The tokens are fed into a classification head.
  • Explores techniques like similarity loss, language model likelihood loss, and no-repetition sampling to produce better text.
  • Evaluates on ImageNet classification and corrupted ImageNet datasets. The learned language bottleneck substantially improves accuracy over classifying from off-the-shelf BLIP captions.
  • Shows the model can classify images through text tokens alone, suggesting language can serve as a “universal interface” between vision and language.

In summary, the paper inserts a text-description bottleneck into an image classifier to make it more interpretable, and shows that this approach substantially outperforms classifying from pretrained captions alone. The key ideas are generating text to explain predictions and evaluating whether text alone can classify images.
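Conceptually, the pipeline can be sketched as below. The captioner, text_encoder, and classifier names are hypothetical placeholders used for illustration, not the paper’s actual interfaces:

# Minimal sketch of the language-bottleneck idea (hypothetical interfaces):
# the only information the classifier sees is the generated text.
def classify_through_text(image, captioner, text_encoder, classifier):
    # 1. Turn the image into a short textual description (discrete tokens)
    tokens = captioner.generate(image)        # e.g. ["a", "striped", "cat", ...]

    # 2. Embed the tokens and pool them into a single text feature vector
    text_feature = text_encoder(tokens).mean(dim=0)

    # 3. Predict the class from the text feature alone
    return classifier(text_feature).argmax()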

Main Contributions

Here are the key contributions of this paper:

  • Proposes a novel image classification model architecture that uses text as an intermediate representation between the visual encoder and classification head. This “language bottleneck” forces the model to generate textual explanations for its predictions.

  • Demonstrates that an image classifier can achieve decent accuracy on ImageNet even when relying solely on text features generated from the images, without using continuous vision features. This suggests text can serve as a “universal interface” between vision and language.

  • Introduces techniques like token similarity loss, language model likelihood loss, and no-repetition sampling to improve the quality and diversity of the generated textual descriptions.

  • Evaluates the approach comprehensively on ImageNet and corrupted ImageNet benchmarks. Shows substantially improved accuracy over just using captions from a pretrained image-text model like BLIP.

  • Provides analysis and examples illustrating how the generated text can give insights into the model’s reasoning and interpretation of image contents.

  • Discusses implications and future directions for using language bottlenecks to improve model interpretability and transparency.

In summary, the key contributions are proposing the language bottleneck architecture, showing that text alone can classify images, developing techniques to improve the generated text, extensive benchmarking, and analysis of how the generated text can be used to interpret model predictions. The paper offers a new perspective on bringing interpretability to vision models.

Method Section

Here is a summary of the method section:

  • The goal is to create an image classifier that uses text as an intermediate representation rather than continuous features. This text should describe the image contents in detail.

  • They build on top of BLIP, a pretrained image-to-text model that generates image captions.

  • BLIP has a visual transformer encoder that converts images to embeddings, and a language model decoder that generates text tokens using the image embeddings.

  • To make the model differentiable, they feed trainable soft prompts (rather than sampling hard tokens) into the language model.

  • The language model outputs logits over the vocabulary for each prompt. These are softmaxed and multiplied by the vocabulary embedding matrix to obtain soft token vectors.

  • They mean pool the token vectors to get a single aggregated text representation of the image.

  • This pooled vector is fed into a linear classifier to predict the image class.

  • Only the soft prompt and linear classifier parameters are trained. During validation, they use argmax sampling to get hard tokens.

  • They introduce techniques like token similarity loss, language model likelihood loss, and no-repetition sampling to improve the quality of the generated text.

In summary, they insert soft textual prompts between a visual encoder and classifier, and optimize the prompts and classifier in an end-to-end fashion to learn textual representations of images that predict the class. Additional losses are used to improve the diversity and readability of the generated text descriptions.
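The core differentiable step, softmaxing the vocabulary logits and multiplying them by the token embedding matrix before pooling and classifying, can be sketched in PyTorch as follows. Shapes and variable names are illustrative assumptions, not the paper’s exact configuration:

import torch
import torch.nn as nn

# Illustrative sizes (assumptions)
vocab_size, d, num_classes, n_tokens, batch = 30522, 768, 1000, 16, 4

token_embedding_matrix = torch.randn(vocab_size, d)  # frozen LM token embeddings
classifier = nn.Linear(d, num_classes)               # trainable linear head

# Suppose the frozen language model, conditioned on the image and the trainable
# soft prompts, produced these logits over the vocabulary at each position:
token_logits = torch.randn(batch, n_tokens, vocab_size)

# Soft tokens: softmax over the vocabulary, then the expected token embedding.
# This replaces non-differentiable sampling of hard tokens.
token_probs = token_logits.softmax(dim=-1)                 # (batch, n_tokens, vocab)
soft_token_embeds = token_probs @ token_embedding_matrix   # (batch, n_tokens, d)

# Mean-pool the soft token embeddings and classify
text_features = soft_token_embeds.mean(dim=1)              # (batch, d)
class_logits = classifier(text_features)                   # (batch, num_classes)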

High-Level Pseudo Code

Here is the high-level pseudo code for the paper:

# Load pretrained BLIP model
blip = BLIP() 
 
# Freeze BLIP weights
blip.freeze_weights()
 
# Initialize trainable soft prompts and linear classifier
prompts = torch.randn(n, d, requires_grad=True) # n prompts of dimension d
classifier = Linear(d, num_classes)
optimizer = Adam([prompts] + list(classifier.parameters()))
 
# Training loop
for images, labels in dataloader:
 
  # Get image embeddings
  image_embeddings = blip.encode_image(images)  
 
  # Decode soft prompts, conditioned on the image, into soft token embeddings
  token_embs = blip.decode_text(prompts, image_embeddings)
 
  # Pool token embeddings 
  text_features = torch.mean(token_embs, dim=1)
 
  # Get predictions
  preds = classifier(text_features)
 
  # Compute classification loss
  loss = cross_entropy(preds, labels)  
 
  # Add auxiliary losses here (token similarity, LM likelihood)
 
  # Update prompt and classifier weights 
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

The key steps are:

  1. Initialize the soft prompts and classifier head
  2. Pass images through BLIP to get image embeddings
  3. Decode the soft prompts with BLIP, conditioned on the image embeddings, to get soft token embeddings
  4. Pool token embeddings to get aggregated text features
  5. Feed text features to classifier to make predictions
  6. Compute classification loss and additional losses
  7. Update prompts and classifier weights through backprop

The prompts are optimized to generate textual representations that predict the correct image class.
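At validation time the model switches from soft tokens to argmax decoding, and the paper mentions no-repetition sampling to avoid degenerate, repeated words. One plausible implementation (an assumption, since the exact mechanism is not detailed in this summary) masks previously chosen tokens before each argmax:

import torch

def no_repeat_argmax(token_logits):
    # Greedy decoding that forbids re-selecting an already chosen token.
    # token_logits: (n_positions, vocab_size) scores from the text decoder.
    chosen = []
    for logits in token_logits:
        masked = logits.clone()
        if chosen:
            masked[torch.tensor(chosen)] = float("-inf")  # block repeats
        chosen.append(int(masked.argmax()))
    return chosen

# Example: 5 positions over a toy 10-token vocabulary
print(no_repeat_argmax(torch.randn(5, 10)))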

Detailed Pseudo Code

Here is a more detailed pseudo code for implementing the paper:

# Imports (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# Hyperparameters
num_prompts = 16 
prompt_dim = 512
num_classes = 1000
 
# Initialize models
blip = BLIP()
blip.freeze_weights()
prompts = nn.Embedding(num_prompts, prompt_dim) 
classifier = nn.Linear(prompt_dim, num_classes)
 
# Initialize optimizers 
prompt_opt = Adam(prompts.parameters(), lr=0.1)
class_opt = Adam(classifier.parameters(), lr=0.005) 
 
# Load pretrained BLIP tokenizer
tokenizer = BLIPTokenizer()
 
for images, labels in dataloader:
    
  # Get image embeddings from BLIP 
  img_embeds = blip.visual_encoder(images)
  
  # Generate soft prompts
  soft_prompts = prompts(torch.arange(num_prompts))
  
  # Decode prompts with BLIP's text decoder, conditioned on the image
  token_logits = blip.text_decoder(soft_prompts, img_embeds)
  token_probs = token_logits.softmax(dim=-1)

  # Hard tokens via argmax (used at validation time and as LM-loss targets)
  hard_tokens = token_probs.argmax(dim=-1)

  # Soft token embeddings: expected embedding under the token distribution
  # (keeps the pipeline differentiable, unlike sampling hard tokens)
  token_embeds = token_probs @ blip.token_embedding_matrix
 
  # Pool token embeddings
  text_features = token_embeds.mean(dim=1)
 
  # Get class predictions
  class_preds = classifier(text_features)
 
  # Classification loss
  cls_loss = F.cross_entropy(class_preds, labels)

  # Token similarity loss: mean pairwise cosine similarity between token
  # embeddings, penalized to encourage diverse, non-repetitive descriptions
  normed = F.normalize(token_embeds, dim=-1)
  sim_loss = (normed @ normed.transpose(1, 2)).mean()

  # Language model likelihood loss: negative log-probability the decoder
  # assigns to the selected hard tokens (gradients flow through token_probs)
  lm_loss = F.nll_loss(
      token_probs.log().reshape(-1, token_probs.size(-1)),
      hard_tokens.reshape(-1),
  )
 
  # Backpropagate the total loss
  loss = cls_loss + sim_loss + lm_loss
  prompt_opt.zero_grad()
  class_opt.zero_grad()
  loss.backward()

  # Update prompt and classifier weights
  prompt_opt.step()
  class_opt.step()

The key additions are:

  • Initializing the prompt embeddings and classifier head
  • Getting image embeddings from BLIP’s visual encoder
  • Using BLIP’s text decoder to get token probabilities
  • Taking argmax (hard) tokens for validation-time decoding and as LM-loss targets
  • Getting token embeddings from token probabilities
  • Computing token similarity and language model losses
  • Separate optimizers for prompt and classifier
