Improving Zero-Shot Generalization for CLIP with Synthesized Prompts
Authors: Zhengbo Wang, Jian Liang, Ran He, Nan Xu, Zilei Wang, Tieniu Tan
Abstract: With the growing interest in pretrained vision-language models like CLIP, recent research has focused on adapting these models to downstream tasks. Despite achieving promising results, most existing methods require labeled data for all classes, which may not hold in real-world applications due to the long tail and Zipf’s law. For example, some classes may lack labeled data entirely, such as emerging concepts. To address this problem, we propose a plug-and-play generative approach called SyntHesIzed Prompts (SHIP) to improve existing fine-tuning methods. Specifically, we follow variational autoencoders to introduce a generator that reconstructs the visual features by inputting the synthesized prompts and the corresponding class names to the textual encoder of CLIP. In this manner, we easily obtain the synthesized features for the remaining label-only classes. Thereafter, we fine-tune CLIP with off-the-shelf methods by combining labeled and synthesized features. Extensive experiments on base-to-new generalization, cross-dataset transfer learning, and generalized zero-shot learning demonstrate the superiority of our approach. The code is available at github.com/mrflogs/SHIP.
What, Why and How
Here is a summary of the key points in this paper:
What:
- The paper proposes a new method called Synthesized Prompts (SHIP) to improve zero-shot generalization for CLIP.
Why:
- Existing methods for fine-tuning CLIP require labeled data for all classes. This is impractical in real-world applications where some classes lack labeled data entirely (long-tail distribution).
- Methods like prompt tuning significantly degrade performance on classes without labeled data.
- The goal is an approach that works for both classes with and without labeled data, while remaining data efficient.
How:
- Uses a VAE as a generator to reconstruct visual features from synthesized prompts and class names fed into CLIP’s text encoder.
- The generator is trained on limited labeled data from the base classes; it can then generate features for new classes from just their names.
- Combines real features of base classes with synthesized features of new classes, and fine-tunes CLIP on this data using existing methods like prompt tuning.
- Experiments show SHIP boosts performance on new classes without hurting base classes, improving base-to-new generalization, cross-dataset transfer, and generalized zero-shot learning.
In summary, SHIP is a generative approach to synthesize features for new classes without labeled data. It augments existing methods to improve CLIP’s generalization, especially under class distribution skew.
Main Contributions
Here are the main contributions of this paper:
- Proposes a new generative method, SHIP, that synthesizes features for classes without labeled data by leveraging CLIP’s pretrained text encoder.
- Integrates the synthesized data seamlessly with existing methods such as prompt tuning or adapter tuning to improve their generalization.
- Achieves state-of-the-art performance on base-to-new generalization, cross-dataset transfer, and generalized zero-shot learning benchmarks.
- Demonstrates improved performance on new (tail) classes without hurting base (head) classes, reducing the performance gap between them.
- Provides a simple and effective way to handle new or emerging classes given only their class names.
- Analyzes different generative models and prompt formulations to justify the design choices.
- Shows that synthesizing prompts, rather than features directly, yields better generalization.
- Offers interpretability by inspecting the learned prompts.
In summary, the key contributions are proposing SHIP for data synthesis, integrating it with existing fine-tuning methods, and achieving strong generalization on multiple benchmarks, especially for classes without labeled data.
Method Section
Here is a summary of the method section from the paper:
- The goal is to improve fine-tuning of CLIP for both base classes (with labeled data) and new classes (without labeled data).
- Proposes a generative model, SHIP, to synthesize features for the new classes.
- Uses a VAE as the framework instead of a GAN, since VAEs are easier to train with limited data.
- The generator takes a latent code z and a class name c as input and reconstructs the corresponding visual feature x.
- The generator is designed to output synthetic prompts rather than features directly.
- Prompts consist of global context vectors (randomly initialized and shared across classes) plus a local bias vector computed from z via an MLP.
- The prompts and the class name are fed into CLIP’s frozen text encoder to obtain the reconstructed feature.
- Only the VAE encoder, the MLP, and the global prompts are trained; CLIP itself stays frozen.
- After training, sample z from the prior and generate prompts for new classes using only their names; the text encoder then yields synthetic features.
- Combine real features of base classes with synthetic features of new classes.
- Use existing methods like prompt tuning to fine-tune CLIP on this combined dataset.
In summary, the key ideas are using a VAE framework, leveraging CLIP’s frozen text encoder via prompt synthesis, and mixing synthesized features of new classes with real features of base classes for fine-tuning CLIP. A sketch of the prompt generator follows.
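The following is a minimal PyTorch sketch of the prompt generator (the decoder side of the VAE). The module names and dimensions (latent_dim, n_ctx, ctx_dim) are illustrative assumptions, not the authors’ exact implementation:

import torch
import torch.nn as nn

class PromptGenerator(nn.Module):
    # Decoder of the SHIP VAE: maps a latent code z to soft prompt vectors
    def __init__(self, latent_dim=128, n_ctx=4, ctx_dim=512):
        super().__init__()
        # Global context vectors, shared across all classes (learnable)
        self.global_ctx = nn.Parameter(0.02 * torch.randn(n_ctx, ctx_dim))
        # MLP that maps the latent code to an instance-specific local bias
        self.bias_net = nn.Sequential(
            nn.Linear(latent_dim, ctx_dim),
            nn.ReLU(),
            nn.Linear(ctx_dim, ctx_dim),
        )
    def forward(self, z):
        bias = self.bias_net(z).unsqueeze(1)        # (batch, 1, ctx_dim)
        return self.global_ctx.unsqueeze(0) + bias  # (batch, n_ctx, ctx_dim)

The resulting prompt vectors are prepended to the token embeddings of the class name and passed through CLIP’s frozen text encoder, whose output serves as the reconstructed visual feature; gradients flow only into global_ctx, bias_net, and the VAE encoder.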
High-Level Pseudo Code
Here is the high-level pseudo code for the SHIP method:
# Train SHIP generator
for x, c in base_dataset:
    z = encoder(x)                     # encode real visual feature
    p = generate_prompts(z)            # synthesize prompts from z
    x_recon = CLIP_text_encoder(p, c)  # reconstruct feature via text encoder
    loss = VAE_loss(x, x_recon)        # VAE reconstruction + prior loss
    update(encoder, mlp, prompts)

# Generate synthetic dataset
new_features = []
for c in new_classes:
    z = sample_prior()
    p = generate_prompts(z)
    x = CLIP_text_encoder(p, c)
    new_features.append((x, c))

# Fine-tune CLIP
combined_data = base_dataset + new_features
fine_tune_CLIP(combined_data)  # using any existing method
The key steps are:
- Train SHIP generator on base classes
- Use it to generate synthetic features for new classes
- Combine real and synthetic datasets
- Fine-tune CLIP on combined data
This allows improving CLIP’s generalization to new classes in a data-efficient way.
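As one concrete instantiation of the final fine-tuning step, here is a minimal sketch that trains a linear classifier head on the combined real and synthetic CLIP features. This is a stand-in for off-the-shelf methods like adapter tuning, not the paper’s only option; the tensors base_feats/base_labels and syn_feats/syn_labels are assumed precomputed, and the hyperparameters are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

def fit_classifier(base_feats, base_labels, syn_feats, syn_labels,
                   num_classes, feat_dim=512, epochs=50, lr=1e-3):
    # Pool real (base-class) and synthetic (new-class) features
    feats = F.normalize(torch.cat([base_feats, syn_feats]), dim=-1)
    labels = torch.cat([base_labels, syn_labels])
    head = nn.Linear(feat_dim, num_classes)
    opt = torch.optim.AdamW(head.parameters(), lr=lr)
    for _ in range(epochs):
        loss = F.cross_entropy(head(feats), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return head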
Detailed Pseudo Code
Here is more detailed pseudo code to implement the SHIP method:
# SHIP generator components (PyTorch-style pseudo code)

class Encoder(nn.Module):
    # VAE encoder: maps a visual feature x to the latent posterior q(z|x)
    def __init__(self, feat_dim, latent_dim):
        super().__init__()
        self.net = nn.Linear(feat_dim, latent_dim * 2)
    def forward(self, x):
        mu, logvar = self.net(x).chunk(2, dim=-1)
        return mu, logvar

class MLP(nn.Module):
    # Maps a latent code z to a local bias vector r in prompt space
    def __init__(self, latent_dim, ctx_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, ctx_dim),
                                 nn.ReLU(),
                                 nn.Linear(ctx_dim, ctx_dim))
    def forward(self, z):
        return self.net(z)

def generate_prompts(z, global_prompts, mlp):
    r = mlp(z)                 # local bias from latent code
    return global_prompts + r  # add bias to shared global prompts

# Train the generator on base classes (CLIP stays frozen)
encoder = Encoder(feat_dim, latent_dim)
mlp = MLP(latent_dim, ctx_dim)
global_prompts = nn.Parameter(torch.randn(L, ctx_dim))  # L context vectors
optimizer = torch.optim.Adam([*encoder.parameters(), *mlp.parameters(), global_prompts])

for x, c in base_dataset:
    mu, logvar = encoder(x)
    z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterization trick
    p = generate_prompts(z, global_prompts, mlp)
    x_recon = CLIP_text_encoder(p, c)
    loss = VAE_loss(x, x_recon, mu, logvar)
    optimizer.zero_grad(); loss.backward(); optimizer.step()

# Generate synthetic features for new classes
new_features = []
for c in new_classes:
    z = torch.randn(latent_dim)  # sample latent code from the prior
    p = generate_prompts(z, global_prompts, mlp)
    x = CLIP_text_encoder(p, c)
    new_features.append((x, c))

# Fine-tune CLIP with any off-the-shelf method
combined_data = base_dataset + new_features
fine_tune_CLIP(combined_data)
where VAE_loss is the MSE reconstruction loss plus the KL divergence between the latent posterior q(z|x) and the standard normal prior.
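A minimal sketch of such a loss, assuming the encoder outputs the mean and log-variance of a diagonal Gaussian posterior (the paper may weight the KL term differently):

import torch
import torch.nn.functional as F

def VAE_loss(x, x_recon, mu, logvar, kl_weight=1.0):
    # Reconstruction term: match the synthesized feature to the real one
    recon = F.mse_loss(x_recon, x, reduction="sum")
    # KL term: pull q(z|x) = N(mu, exp(logvar)) toward the prior N(0, I)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl_weight * kl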