Contextual Prompt Learning for Vision-Language Understanding
Authors: Koustava Goswami, Srikrishna Karanam, Joseph K. J, Prateksha Udhayanan, Balaji Vasan Srinivasan
Abstract: Recent advances in multimodal learning have resulted in powerful vision-language models whose representations generalize across a variety of downstream tasks. Recently, their generalizability has been further extended by incorporating trainable prompts, borrowed from the natural language processing literature. While such prompt learning techniques have shown impressive results, we identify that these prompts are trained on global image features, which is limiting in two ways: First, by using global features, these prompts may focus less on the discriminative foreground regions of the image, resulting in poor generalization to various out-of-distribution test cases. Second, existing work weights all prompts equally, whereas our intuition is that different prompts are relevant to different types of images. We address these issues as part of our proposed Contextual Prompt Learning (CoPL) framework, which is capable of aligning the prompts to the localized features of the image.
What, Why and How
Here is a summary of the key points from the paper:
What:
- The paper presents a new method called Contextual Prompt Learning (CoPL) for vision-language understanding.
- It improves upon prior work like CoOp and CoCoOp by aligning prompt vectors to local image features and learning to weight the prompts based on relevance to the local context.
Why:
- Prior methods use global image features, which focus less on discriminative regions, and they weight all prompts equally.
- Aligning prompts to local features allows capturing salient visual concepts, and learning prompt relevance weights makes the prompts more semantically meaningful.
How:
- Uses patch embeddings from the vision transformer as local features and aligns them to the prompts using content-based attention (a short sketch of the global-vs-local distinction follows the summary below).
- Learns a context vector per patch that weights the prompts based on semantic relevance.
- Combines the weighted prompts into image-specific prompt vectors for classification.
- Extensive experiments on 11 datasets show improved generalization over baselines, especially for unseen classes and few-shot scenarios.
In summary, CoPL improves vision-language models by learning prompt vectors that are aligned to local image context and weighted based on semantic relevance. This provides better generalization as shown through strong experimental results.
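To make the global-vs-local distinction concrete, here is a minimal PyTorch sketch. It assumes the usual CLIP-style vision transformer token layout, in which the encoder emits one class token followed by one token per image patch; the tensor shapes are illustrative, not taken from the paper.

import torch

# Hypothetical ViT output for a batch of images: one class token followed by
# P patch tokens, i.e. shape (batch, 1 + P, D).
batch, num_patches, dim = 2, 49, 512
tokens = torch.randn(batch, 1 + num_patches, dim)

# Global feature (what CoCoOp conditions its prompts on): the pooled class token.
global_feature = tokens[:, 0]        # (batch, D)

# Local features (what CoPL conditions its prompts on): one embedding per patch.
local_features = tokens[:, 1:]       # (batch, P, D)

CoPL's prompt conditioning operates on the per-patch embeddings, so salient foreground regions can influence the prompts directly instead of being averaged into a single global vector.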
Main Contributions
The main contributions of this paper are:
- Identifies two key issues with existing prompt learning methods: equal weighting of all prompts and no contextual alignment between image features and prompts.
- Proposes Contextual Prompt Learning (CoPL) to address these issues.
- Learns to align prompts to local image features using content-based attention, which captures salient concepts.
- Learns a context vector per image region that weights prompts based on semantic relevance, making the prompts more meaningful.
- Achieves state-of-the-art results on 11 datasets, outperforming the prior prompt learning methods CoOp and CoCoOp.
- Demonstrates improved few-shot learning and generalization to unseen classes due to contextual prompt learning.
- Provides extensive experiments analyzing the impact of local vs. global features, seen vs. unseen classes, and one-shot training scenarios.
- Shows CoPL's efficacy on diverse tasks such as classification, fine-grained recognition, and texture recognition.
In summary, the key contributions are proposing contextual prompt learning for better alignment with image semantics, and experimentally demonstrating its effectiveness for vision-language understanding.
Method Section
Here is a summary of the method section from the paper:
- CoPL builds on top of CLIP, which has separate text and image encoders aligned using contrastive learning.
- CoOp introduced learning continuous prompt vectors via backpropagation with the CLIP weights kept fixed.
- CoCoOp improved on CoOp by conditioning the prompts on global image features using a meta-net.
- CoPL identifies issues with the global features and equal prompt weighting in CoCoOp.
- Uses local features from CLIP's patch embeddings instead of global features.
- Learns a lightweight network to get per-patch conditional vectors.
- Computes attention between each patch vector and the prompt vectors to get alignment weights (illustrated in the sketch after this summary).
- Uses the alignment weights to get a context vector per patch that is a weighted combination of the prompt vectors.
- Updates the prompt vectors using the summation of the context vectors.
- The final prompts are image-specific, aligned to the local context, and weighted based on relevance.
- For classification, combines the updated prompts with the class embedding and computes similarity with the image embedding.
In summary, CoPL aligns prompts to local image semantics by using patch features, content-based attention, and weighted context vectors to update prompts. This gives improved generalization.
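The alignment and context-vector steps referenced above can be sketched in PyTorch as follows. This is a reconstruction under assumptions rather than the authors' code: W is an assumed learned linear projection implementing the content-based score, the softmax runs over the M prompts for each patch, and the dimensions are illustrative.

import torch
import torch.nn.functional as F

# Assumed shapes: M prompt vectors and P patch conditional vectors, all of size D.
M, P, D = 4, 49, 512
prompts = torch.nn.Parameter(torch.randn(M, D))       # v1..vM
patch_cond = torch.randn(P, D)                        # output of the lightweight network
W = torch.nn.Linear(2 * D, 1)                         # assumed content-score projection

# Content-based attention: score every (patch, prompt) pair, softmax over prompts.
pairs = torch.cat([patch_cond.unsqueeze(1).expand(P, M, D),
                   prompts.unsqueeze(0).expand(P, M, D)], dim=-1)   # (P, M, 2D)
alignment = F.softmax(torch.tanh(W(pairs)).squeeze(-1), dim=-1)     # (P, M)

# One context vector per patch, then image-specific prompts.
context = alignment @ prompts                          # (P, D)
updated_prompts = prompts + context.sum(dim=0)         # (M, D), i.e. vm(I) = vm + sum_p context_p

Because the softmax is taken per patch, each image region distributes its attention across the prompt vectors, and the final update adds the summed patch contexts to every prompt, yielding image-specific prompts.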
High-Level Pseudo Code
Here is the high-level pseudo code for the method presented in the paper:
# CoPL: Contextual Prompt Learning
# Input:
# image I
# learned prompt vectors v1, v2, ..., vM
# patch features s1, s2, ..., sP from image I
# Output:
# updated image-specific prompt vectors
# Get patch conditional vectors
for p in 1...P:
    sp = lightweight_network(sp)

# Align patches to prompts
for p in 1...P:
    # Content-based attention: softmax over the M prompts for this patch
    alignment_weights[p] = softmax([score(sp, vi) for i in 1...M])
    context_vector[p] = sum(alignment_weights[p][i] * vi for i in 1...M)

# Update prompt vectors with the accumulated patch context
for m in 1...M:
    vm(I) = vm + sum(context_vector[p] for p in 1...P)

# Classification: each class prompt [v1(I), ..., vM(I), class_embedding] is passed
# through the text encoder and compared with the image embedding
label_probs = softmax(sim(image_embedding, text_encoder([v1(I), ..., vM(I), class_embedding])))
This summarizes the key steps:
- Get conditional vectors for each image patch
- Align patches to prompts using content-based attention
- Compute weighted context vectors for each patch
- Update prompt vectors using context vectors
- Use updated, image-specific prompts for classification
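The classification step above compresses several operations. The sketch below expands it under the assumption, consistent with CLIP/CoCoOp-style pipelines, that each class k gets its own token sequence [v1(I), ..., vM(I), c_k] which is run through the frozen text encoder before the cosine-similarity softmax; text_encoder, class_token_embeds, and the temperature value are stand-ins, not the paper's exact interfaces.

import torch
import torch.nn.functional as F

def classify(image_feat, updated_prompts, class_token_embeds, text_encoder,
             temperature=0.01):        # stand-in for CLIP's learned temperature
    """image_feat:         (D,)   image embedding from the frozen image encoder
       updated_prompts:    (M, D) image-specific prompt vectors v1(I)..vM(I)
       class_token_embeds: (K, D) token embedding of each class name
       text_encoder:       maps an (M+1, D) token sequence to a (D,) text feature."""
    text_feats = []
    for k in range(class_token_embeds.shape[0]):
        # Per-class prompt: the image-specific prompts followed by the class token.
        seq = torch.cat([updated_prompts, class_token_embeds[k:k + 1]], dim=0)
        text_feats.append(text_encoder(seq))
    text_feats = torch.stack(text_feats)                            # (K, D)
    logits = F.cosine_similarity(image_feat.unsqueeze(0), text_feats, dim=-1)
    return F.softmax(logits / temperature, dim=-1)                  # (K,) class probabilities

# Toy usage with a stand-in text encoder that mean-pools its input sequence.
M, D, K = 4, 512, 10
probs = classify(torch.randn(D), torch.randn(M, D), torch.randn(K, D),
                 text_encoder=lambda seq: seq.mean(dim=0))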
Detailed Pseudo Code
Here is a more detailed pseudo code for implementing the Contextual Prompt Learning method:
# Hyperparameters
num_prompts M
num_patches P
patch_dim D

# Input
image I
prompt_vectors v1, v2, ..., vM in R^D
patch_features s1, s2, ..., sP in R^D

# Lightweight network
def lightweight_network(patch):
    # e.g. a 1-layer MLP
    return relu(linear(patch))

# Content function for alignment (W is a learned projection)
def content_fn(patch, prompt):
    return tanh(W @ concat(patch, prompt))

# Align patches to prompts
alignment_weights = []
for p in 1...P:
    patch_cond_vector = lightweight_network(sp)
    # score this patch against every prompt, then softmax over the M prompts
    scores = [content_fn(patch_cond_vector, vi) for i in 1...M]
    alignment_weight_p = [exp(s) / sum(exp(scores)) for s in scores]
    alignment_weights.append(alignment_weight_p)

# Context vectors: weighted combination of the prompts for each patch
context_vectors = []
for p in 1...P:
    context_vector_p = sum(alignment_weights[p][i] * vi for i in 1...M)
    context_vectors.append(context_vector_p)

# Update prompt vectors with the accumulated context
updated_prompts = []
for m in 1...M:
    vm_updated = vm + sum(context_vectors)
    updated_prompts.append(vm_updated)

# Classification
image_embedding = image_encoder(I)
# each class prompt [updated_prompts, class_embedding_k] goes through the text encoder
text_embeddings = [text_encoder([updated_prompts, class_embedding_k]) for each class k]
label_probs = softmax(cosine_sim(image_embedding, text_embeddings))
This provides more implementation details like using an MLP for the lightweight network, tanh content function, and accumulating the context vectors.
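For completeness, below is a self-contained PyTorch sketch of how the detailed pseudo code might be organized as a module. It is an illustrative reconstruction under assumptions (single image, stand-in text encoder, hypothetical dimensions), not the authors' released implementation; in line with the paper, the CLIP encoders would stay frozen and only the prompt vectors, the lightweight network, and the score projection would be trained.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CoPLPromptLearner(nn.Module):
    """Sketch: prompts are aligned to patch features via content-based attention
    and shifted by the summed per-patch context vectors (hypothetical sizes)."""

    def __init__(self, num_prompts=4, dim=512):
        super().__init__()
        # Learnable prompt vectors v1..vM.
        self.prompts = nn.Parameter(0.02 * torch.randn(num_prompts, dim))
        # Lightweight network producing per-patch conditional vectors.
        self.lightweight = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        # Projection used in the content-based score tanh(W [sp; vi]).
        self.score = nn.Linear(2 * dim, 1)

    def forward(self, patch_feats):
        """patch_feats: (P, D) local features -> (M, D) image-specific prompts."""
        cond = self.lightweight(patch_feats)                        # (P, D)
        P, D = cond.shape
        M = self.prompts.shape[0]
        pairs = torch.cat([cond.unsqueeze(1).expand(P, M, D),
                           self.prompts.unsqueeze(0).expand(P, M, D)], dim=-1)
        weights = F.softmax(torch.tanh(self.score(pairs)).squeeze(-1), dim=-1)  # (P, M)
        context = weights @ self.prompts                            # (P, D)
        return self.prompts + context.sum(dim=0)                    # (M, D)

def copl_logits(image_feat, patch_feats, class_token_embeds,
                prompt_learner, text_encoder, temperature=0.01):
    """Cosine-similarity logits over K classes for a single image."""
    updated = prompt_learner(patch_feats)                           # (M, D)
    text_feats = torch.stack([
        text_encoder(torch.cat([updated, class_token_embeds[k:k + 1]], dim=0))
        for k in range(class_token_embeds.shape[0])])               # (K, D)
    return F.cosine_similarity(image_feat.unsqueeze(0), text_feats, dim=-1) / temperature

# Toy usage with stand-in encoders and random features (not real CLIP outputs).
D, P, K = 512, 49, 10
learner = CoPLPromptLearner(num_prompts=4, dim=D)
logits = copl_logits(torch.randn(D), torch.randn(P, D), torch.randn(K, D),
                     learner, text_encoder=lambda seq: seq.mean(dim=0))
loss = F.cross_entropy(logits.unsqueeze(0), torch.tensor([3]))      # standard CE objective
loss.backward()                                                     # only the learner's parameters receive gradients

In a real pipeline, patch_feats would come from CLIP's vision transformer patch embeddings and class_token_embeds from the tokenized class names, with both encoders kept frozen during prompt training.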