Augmenting CLIP with Improved Visio-Linguistic Reasoning

Authors: Samyadeep Basu, Maziar Sanjabi, Daniela Massiceti, Shell Xu Hu, Soheil Feizi

Abstract: Image-text contrastive models such as CLIP are useful for a variety of downstream applications including zero-shot classification, image-text retrieval and transfer learning. However, these contrastively trained vision-language models often fail on compositional visio-linguistic tasks such as Winoground with performance equivalent to random chance. In our paper, we address this issue and propose a sample-efficient light-weight method called SDS-CLIP to improve the compositional visio-linguistic reasoning capabilities of CLIP. The core idea of our method is to use differentiable image parameterizations to fine-tune CLIP with a distillation objective from large text-to-image generative models such as Stable-Diffusion which are relatively good at visio-linguistic reasoning tasks. On the challenging Winoground compositional reasoning benchmark, our method improves the absolute visio-linguistic performance of different CLIP models by up to 7%, while on the ARO dataset, our method improves the visio-linguistic performance by up to 3%. As a byproduct of inducing visio-linguistic reasoning into CLIP, we also find that the zero-shot performance improves marginally on a variety of downstream datasets. Our method reinforces that carefully designed distillation objectives from generative models can be leveraged to extend existing contrastive image-text models with improved visio-linguistic reasoning capabilities.

What, Why and How

Here is a summary of the key points from the paper:

What:

  • The paper proposes a method called SDS-CLIP to improve the compositional visio-linguistic reasoning abilities of CLIP models.

Why:

  • CLIP models fail on compositional visio-linguistic reasoning tasks like Winoground, performing close to random chance.
  • In contrast, text-to-image models like Stable Diffusion perform much better on these tasks.
  • The denoising diffusion objective in text-to-image models induces better correspondences between images and text.

How:

  • The paper distills knowledge from Stable Diffusion into CLIP using a score-distillation sampling (SDS) objective.
  • This regularizes the contrastive loss during CLIP fine-tuning with a small paired image-text dataset.
  • Only the LayerNorm parameters of CLIP are tuned, making the method very lightweight.
  • Results show improved performance on Winoground (absolute gains of 1.5-7%), while downstream zero-shot performance is preserved or marginally improved.

In summary, the paper shows that distilling knowledge from text-to-image models can improve the compositional reasoning abilities of contrastive vision-language models like CLIP in a sample- and parameter-efficient way.
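Since only the LayerNorm parameters of CLIP are updated, one concrete way to set this up in PyTorch is sketched below. This is a minimal illustration rather than the paper's code: it assumes clip is an already-loaded CLIP model whose submodules use torch.nn.LayerNorm, and the helper name layernorm_parameters is ours.

import torch

def layernorm_parameters(model: torch.nn.Module):
    """Yield only the parameters that belong to LayerNorm modules."""
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            yield from module.parameters()

# Freeze everything, then re-enable gradients for LayerNorm parameters only.
for p in clip.parameters():
    p.requires_grad = False
for p in layernorm_parameters(clip):
    p.requires_grad = True

optimizer = torch.optim.Adam(
    [p for p in clip.parameters() if p.requires_grad], lr=5e-5
)

Freezing everything first and then re-enabling the LayerNorm parameters keeps the trainable set small, which is what makes the fine-tuning step lightweight.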

Main Contributions

Here are the main contributions of this paper:

  • Highlights the strong visio-linguistic reasoning capabilities of text-to-image generative models like Stable Diffusion, attributed to the denoising diffusion objective.

  • Introduces a novel fine-tuning method called SDS-CLIP to distill knowledge from Stable Diffusion into CLIP to improve its compositional reasoning abilities.

  • Shows SDS-CLIP is highly sample- and parameter-efficient, using only ~118k image-text pairs from MS-COCO and tuning just the LayerNorm parameters.

  • Demonstrates improved performance on challenging visio-linguistic reasoning benchmarks like Winoground (gains of 1.5-7%) without harming downstream zero-shot performance.

  • Provides evidence that distillation from generative models can help improve contrastive vision-language models like CLIP.

  • Highlights the potential to augment existing models with improved reasoning in a lightweight post-hoc tuning step.

In summary, the key contributions are introducing the SDS-CLIP distillation method, showing its sample/parameter efficiency, demonstrating improved reasoning abilities on benchmarks, and highlighting the potential of distillation to enhance contrastive vision-language models.

Method Section

Here is a summary of the method section from the paper:

  • The core idea is to regularize the contrastive loss in CLIP with the denoising diffusion score from Stable Diffusion during fine-tuning.

  • This is implemented using score-distillation sampling (SDS), where the CLIP image embeddings are mapped to the input space of the Stable Diffusion UNet.

  • Specifically, the CLIP image encoder output is passed through a learned linear mapping into the UNet’s input space.

  • The denoising diffusion loss, i.e., the error between the noise added to the mapped embedding and the noise predicted by the UNet, is computed as an additional regularization term.

  • This SDS loss encourages CLIP to learn embeddings aligned under both contrastive and denoising diffusion objectives.

  • Only the LayerNorm parameters of CLIP are tuned, keeping all other parameters fixed during fine-tuning.

  • The overall loss is a combination of the original CLIP contrastive loss and the SDS regularization loss.

  • Optimization is performed over CLIP's LayerNorm parameters and the learned linear map, with the UNet kept frozen.

  • Fine-tuning is done on a small dataset of 118k COCO image-text pairs for just 5 epochs.

In summary, the method regularizes CLIP’s contrastive fine-tuning with a diffusion distillation loss using SDS. This aligns the CLIP embeddings to have strong visio-linguistic correspondences induced by the diffusion model.
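Putting the pieces above together, the training objective can be written roughly as follows. This is a sketch based on the description in this section (omitting any per-timestep weighting), where f_θ is CLIP's image encoder, h_w the learned linear map into the UNet's input space, ε_φ the frozen UNet's noise prediction, α_t and σ_t the diffusion noise schedule, c the caption, and λ the regularization strength:

\mathcal{L}_{SDS}(x, c) = \mathbb{E}_{t,\, \epsilon \sim \mathcal{N}(0, I)} \Big[ \big\| \epsilon_\phi\big(\alpha_t\, h_w(f_\theta(x)) + \sigma_t\, \epsilon,\; t,\; c\big) - \epsilon \big\|_2^2 \Big]

\mathcal{L}_{total} = \mathcal{L}_{CLIP} + \lambda \cdot \mathcal{L}_{SDS}

Gradients of L_total flow into CLIP's LayerNorm parameters and the linear map h_w only; the UNet parameters φ stay fixed.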

High-Level Pseudo Code

Here is the high-level pseudo code for the method proposed in the paper:

# Hyperparameters
num_epochs = 5
batch_size = 32
lambda_sds = 0.001  # Regularization strength ("lambda" is a reserved word in Python)

# Load pre-trained CLIP model (only its LayerNorm parameters are trainable)
clip = CLIPModel()

# Load and freeze the Stable Diffusion UNet
unet = StableDiffusionUNet()
unet.freeze()

# Optimizer over the trainable CLIP parameters
optimizer = Adam(trainable_parameters(clip))

for epoch in range(num_epochs):

  # Iterate over image-text batches
  for batch in dataloader(batch_size=batch_size):

    # Get CLIP embeddings
    image_embeds = clip.image_encoder(batch.images)
    text_embeds = clip.text_encoder(batch.captions)

    # SDS loss: distillation signal from the frozen UNet
    sds_loss = get_sds_loss(image_embeds, batch, unet)

    # Standard CLIP contrastive loss
    clip_loss = get_contrastive_loss(image_embeds, text_embeds)

    # Total loss: contrastive objective regularized by the SDS term
    loss = clip_loss + lambda_sds * sds_loss

    # Update the trainable CLIP parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Save fine-tuned CLIP model
save_model(clip)

This shows the high-level steps:

  1. Sample image-text batches
  2. Get CLIP embeddings
  3. Compute SDS and contrastive losses
  4. Update CLIP parameters via gradient descent on total loss

The key components are the SDS loss computed with the frozen UNet and the composite loss that distills this signal into CLIP; a sketch of how get_sds_loss could be implemented is given below.
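The following is a minimal sketch of get_sds_loss, assuming a diffusers-style UNet2DConditionModel and DDPMScheduler. The caption conditioning (caption_embeds, produced by Stable Diffusion's text encoder) and the learned linear_proj are passed in explicitly rather than bundled into batch, and all names beyond the diffusers API are illustrative:

import torch
import torch.nn.functional as F

def get_sds_loss(image_embeds, caption_embeds, unet, linear_proj, scheduler):
    """SDS-style denoising loss on CLIP image embeddings.

    image_embeds:   CLIP image embeddings, shape (B, D)
    caption_embeds: caption embeddings from Stable Diffusion's text encoder,
                    shape (B, seq_len, hidden_dim), used to condition the UNet
    """
    b = image_embeds.shape[0]
    c, s = unet.config.in_channels, unet.config.sample_size

    # Map CLIP embeddings into the UNet's latent input space and reshape.
    latents = linear_proj(image_embeds).view(b, c, s, s)

    # Sample a timestep per example and Gaussian noise, then apply the
    # forward diffusion step to obtain noisy latents.
    t = torch.randint(0, scheduler.config.num_train_timesteps, (b,),
                      device=latents.device)
    noise = torch.randn_like(latents)
    noisy_latents = scheduler.add_noise(latents, noise, t)

    # The frozen UNet predicts the added noise, conditioned on the caption.
    pred_noise = unet(noisy_latents, t, encoder_hidden_states=caption_embeds).sample

    # Squared error between predicted and true noise is the distillation signal.
    return F.mse_loss(pred_noise, noise)

Because the UNet is frozen, gradients from this loss only reach the linear projection and, through the image embeddings, CLIP's trainable parameters.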

Detailed Pseudo Code

Here is a more detailed pseudo code to implement the method proposed in this paper:

# Load pre-trained CLIP
clip = CLIP(config)  # Initializes image encoder, text encoder
clip.load_state_dict(clip_pretrained_weights)

# Load pre-trained Stable Diffusion UNet (kept frozen throughout)
unet = StableDiffusionUNet(config)
unet.load_state_dict(unet_pretrained_weights)
unet.freeze()  # Freeze UNet parameters

# Linear layer mapping CLIP image embeddings into the UNet's latent input space
linear_proj = nn.Linear(clip.embed_dim, unet.in_channels * unet.image_size**2)

# Only CLIP's LayerNorm parameters and the linear map are fine-tuned
trainable_params = list(layernorm_parameters(clip)) + list(linear_proj.parameters())
optimizer = torch.optim.Adam(trainable_params, lr=5e-5)

num_epochs = 5
lambda_sds = 0.001  # SDS regularization strength
dataset = COCODataset(size=118000)  # ~118k MS-COCO image-text pairs

for epoch in range(num_epochs):

  for batch in dataset:

    # Get CLIP embeddings
    image_embeds = clip.image_encoder(batch.images)
    text_embeds = clip.text_encoder(batch.captions)

    # Map CLIP embeddings into the UNet's latent space and reshape
    latents = linear_proj(image_embeds)
    latents = latents.view(-1, unet.in_channels, unet.image_size, unet.image_size)

    # Sample diffusion timesteps and Gaussian noise, then noise the latents
    t = uniform_sample_timesteps(latents.shape[0])
    noise = torch.randn_like(latents)
    noisy_latents = add_noise(latents, noise, t)  # forward diffusion step

    # The frozen UNet predicts the added noise, conditioned on the caption
    # (captions are encoded with Stable Diffusion's text encoder)
    pred_noise = unet(noisy_latents, t, sd_text_encoder(batch.captions))

    # SDS loss: error between the predicted and true noise
    sds_loss = ((pred_noise - noise) ** 2).mean()

    # Standard CLIP contrastive (InfoNCE) loss
    clip_loss = InfoNCE(image_embeds, text_embeds)

    # Total loss
    loss = clip_loss + lambda_sds * sds_loss

    # Optimization step (updates only the LayerNorm params and linear_proj)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Evaluate compositional reasoning after each epoch
  evaluate(clip, winoground_dataset)
 

This shows more implementation details: the linear projection layer, sampling timesteps and noise, passing the noised embeddings through the frozen UNet conditioned on the caption, and the optimization loop over the LayerNorm and projection parameters.
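The evaluate call above is left abstract. A hedged sketch of the standard Winoground scoring protocol follows; the per-example field names (image_0, caption_0, ...) and the clip.similarity helper are assumptions for illustration, not the paper's code:

def evaluate(clip, winoground_dataset):
    """Winoground text/image/group scores for a CLIP-style model.

    Each example has two images and two captions; credit is given only when
    the cross-matched image-text similarities are ordered correctly.
    """
    text_correct = image_correct = group_correct = 0
    for ex in winoground_dataset:
        sim = clip.similarity  # cosine similarity between an image and a caption
        c0i0, c0i1 = sim(ex.image_0, ex.caption_0), sim(ex.image_1, ex.caption_0)
        c1i0, c1i1 = sim(ex.image_0, ex.caption_1), sim(ex.image_1, ex.caption_1)

        text_ok = c0i0 > c1i0 and c1i1 > c0i1    # correct caption for each image
        image_ok = c0i0 > c0i1 and c1i1 > c1i0   # correct image for each caption
        text_correct += text_ok
        image_correct += image_ok
        group_correct += text_ok and image_ok

    n = len(winoground_dataset)
    return text_correct / n, image_correct / n, group_correct / n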