PromptStyler: Prompt-driven Style Generation for Source-free Domain Generalization

Authors: Junhyeong Cho, Gilhyun Nam, Sungyeon Kim, Hunmin Yang, Suha Kwak

Abstract: In a joint vision-language space, a text feature (e.g., from “a photo of a dog”) could effectively represent its relevant image features (e.g., from dog photos). Inspired by this, we propose PromptStyler which simulates various distribution shifts in the joint space by synthesizing diverse styles via prompts without using any images to deal with source-free domain generalization. Our method learns to generate a variety of style features (from “a S* style of a”) via learnable style word vectors for pseudo-words S*. To ensure that learned styles do not distort content information, we force style-content features (from “a S* style of a [class]”) to be located nearby their corresponding content features (from “[class]”) in the joint vision-language space. After learning style word vectors, we train a linear classifier using synthesized style-content features. PromptStyler achieves the state of the art on PACS, VLCS, OfficeHome and DomainNet, although it does not require any images and takes just ~30 minutes for training using a single GPU.

What, Why and How

Here is a summary of the key points from this paper:

What:

  • Proposes a new method called PromptStyler for source-free domain generalization. This means it can improve model generalization to new domains without requiring any source domain data.

Why:

  • Existing domain generalization methods require multiple source domains which can be costly/infeasible to collect. Leveraging pre-trained models can avoid this, but it’s unclear how to simulate domain shifts without source data.

How:

  • Leverages joint vision-language space of models like CLIP where text can represent image concepts.

  • Learns “style word vectors” for pseudo-words S*_i that can be inserted into prompts to synthesize diverse “style features” (e.g., from “a S*_1 style of a”) and style-content features (e.g., from “a S*_1 style of a dog”).

  • Learns each style word vector by minimizing a style diversity loss (pushing new styles away from previously learned ones) and a content consistency loss (preserving class content), yielding varied yet meaningful styles.

  • Trains a linear classifier for the target classes on style-content features synthesized by the text encoder. At inference, the image encoder extracts image features, which the linear classifier then classifies.

  • Achieves state-of-the-art results on domain generalization benchmarks without using any source images, with only ~30 minutes of training; inference is also cheaper than zero-shot CLIP because only the image encoder and a linear classifier are needed.

In summary, it simulates domain shifts and trains a classifier fully in latent vision-language space using prompts and style word vectors, avoiding the need for source domain data.
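
To make the prompt-based setup concrete, here is a minimal Python sketch of the three prompt templates involved; the template strings follow the paper, while the helper function names and example class names are illustrative:

# Minimal sketch of the three prompt templates used by PromptStyler.
# "S*_i" is only a placeholder in the string; its word embedding is a
# learnable vector inside the text encoder (see the sketch after the
# Method Section summary).

class_names = ["dog", "elephant", "giraffe"]   # example class names

def style_prompt(i):
    return f"a S*_{i} style of a"

def content_prompt(class_name):
    return class_name

def style_content_prompt(i, class_name):
    return f"a S*_{i} style of a {class_name}"

print(style_prompt(1))                 # "a S*_1 style of a"
print(style_content_prompt(1, "dog"))  # "a S*_1 style of a dog"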

Main Contributions

The main contributions of this paper are:

  1. Proposes a novel method to simulate domain shifts and improve generalization capability using only a pre-trained vision-language model, without requiring any source domain images. This enables source-free domain generalization.

  2. Introduces the idea of learning “style word vectors” that can be inserted into prompts to synthesize diverse style features in the joint vision-language space. Maximizes style diversity while preserving content consistency.

  3. Demonstrates state-of-the-art performance on multiple domain generalization benchmarks (PACS, VLCS, OfficeHome, DomainNet) without using any source domain data.

  4. Shows high efficiency - takes only ~30 minutes of training on a single GPU and has lower inference cost than CLIP.

  5. Opens up a new paradigm of training in the joint vision-language latent space using only class names and prompts, avoiding the need for collecting labeled data from multiple source domains.

In summary, the key contributions are proposing the idea of prompt-based style generation for source-free domain generalization, and showing its effectiveness empirically across benchmarks. The method is highly practical as it only requires class names, and efficient in training and inference.

Method Section

Here is a summary of the method section from the paper:

  • The key idea is to learn “style word vectors” that can be inserted into prompts to generate diverse styles in the joint vision-language space.

  • Three types of prompts are used:

    • Style prompt: “a S*_i style of a”
    • Content prompt: “[class]_m” (e.g. “dog”)
    • Style-content prompt: “a S*_i style of a [class]_m”
  • To learn a style word vector s_i:

    • Minimizes a style diversity loss L_style that pushes the style feature from the prompt “a S*_i style of a” toward orthogonality with the previously learned style features.
    • Minimizes a content consistency loss L_content that keeps the style-content feature from “a S*_i style of a [class]_m” most similar (in cosine similarity) to its corresponding content feature from “[class]_m” among all content features.
    • Trains s_i to minimize the total loss L_prompt = L_style + L_content.
  • Learns K style word vectors {s_i}^K_{i=1} sequentially using the above method.

  • Uses the learned style vectors and class names to generate KN style-content features via the text encoder.

  • Trains a linear classifier on these features using ArcFace loss and class names as labels.

  • At inference, uses pre-trained image encoder to extract image features and classify with the trained linear model.

So in summary, it learns to synthesize diverse styles via prompts while preserving content, uses them to generate features for training a classifier, and infers on real images using the image encoder. The entire training is done without real images. A sketch of how the pseudo-word embedding enters the text encoder follows.
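
The pseudo-word S*_i never appears as a real token: its word embedding is a learnable vector spliced into the token-embedding sequence before the transformer. Below is a minimal sketch of one way to do this, assuming the OpenAI CLIP implementation (github.com/openai/CLIP), whose text encoder exposes token_embedding, positional_embedding, transformer, ln_final and text_projection; the placeholder token "X" and the splicing logic are illustrative assumptions, not the authors' code.

import torch
import clip

device = "cpu"
model, _ = clip.load("ViT-B/16", device=device)
model = model.float()                       # avoid fp16 on CPU
for p in model.parameters():                # CLIP stays frozen
    p.requires_grad_(False)

def encode_text_with_style(style_vec, template="a X style of a"):
    # Encode a prompt whose placeholder token "X" is replaced by a
    # learnable style word vector of shape [word_embedding_dim].
    tokens = clip.tokenize(template).to(device)            # [1, 77]
    x = model.token_embedding(tokens).clone()              # [1, 77, width]
    x_id = clip.tokenize("X")[0, 1].item()                 # token id of "X"
    pos = (tokens[0] == x_id).nonzero(as_tuple=True)[0]
    x[0, pos] = style_vec                                  # splice in the learnable vector
    # The rest mirrors CLIP's encode_text.
    x = x + model.positional_embedding
    x = model.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
    x = model.ln_final(x)
    feat = x[torch.arange(x.shape[0]), tokens.argmax(dim=-1)] @ model.text_projection
    return feat                                            # [1, embed_dim]

style_vec = torch.randn(model.token_embedding.embedding_dim, requires_grad=True)
f_style = encode_text_with_style(style_vec)                # differentiable w.r.t. style_vec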

High-Level Pseudo Code

Here is the high-level pseudo code for the method proposed in this paper:

# Requirement: pre-trained text encoder T, predefined class names
 
# Input: number of styles K, training iterations L 
 
# Output: KN style-content features
 
# Randomly initialize K style word vectors {s_i}^K_{i=1}
 
# Learn K style word vectors sequentially:
for i in range(K):
  
  # L training iterations for each word vector   
  for iteration in range(L):
 
    # Compute style diversity loss  
    L_style = style_diversity_loss(s_i, {s_j}^{i-1}_{j=1})  
 
    # Compute content consistency loss
    L_content = content_consistency_loss(s_i) 
 
    # Total prompt loss
    L_prompt = L_style + L_content 
 
    # Update s_i by gradient descent on L_prompt
 
# After learning all K style vectors, generate KN
# style-content features using class names via T
 
# Train linear classifier on style-content features 
# using ArcFace loss and class names as labels

In summary, it randomly initializes style vectors, then learns them sequentially by optimizing diversity and consistency losses. The learned vectors are used with class names to synthesize features for training classifier.
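
As a concrete illustration of the two losses in the loop above, here is a small PyTorch sketch that operates on L2-normalized text features; the function names are hypothetical and the exact reduction details are assumptions rather than the authors' implementation:

import torch
import torch.nn.functional as F

def style_diversity_loss(f_style_i, prev_style_feats):
    # Push the i-th style feature toward orthogonality with previously
    # learned style features (mean absolute cosine similarity -> 0).
    if len(prev_style_feats) == 0:
        return f_style_i.new_zeros(())
    f_i = F.normalize(f_style_i, dim=-1)                         # [d]
    f_prev = F.normalize(torch.stack(prev_style_feats), dim=-1)  # [i-1, d]
    return (f_prev @ f_i).abs().mean()

def content_consistency_loss(f_style_content, f_content):
    # Each style-content feature should be most similar to the content
    # feature of its own class among all N content features.
    z = F.normalize(f_style_content, dim=-1)   # [N, d]
    k = F.normalize(f_content, dim=-1)         # [N, d]
    logits = z @ k.t()                         # [N, N] cosine similarities
    targets = torch.arange(z.size(0), device=z.device)
    return F.cross_entropy(logits, targets)

# Per-iteration total loss for the i-th style word vector:
# L_prompt = style_diversity_loss(f_style_i, prev) + content_consistency_loss(z_i, k)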

Detailed Pseudo Code

Here is a more detailed pseudo code to implement the method proposed in this paper:

# Pre-trained text encoder T 
# Pre-defined class names [class]_m, m=1..N
 
# Hyperparameters
K = ...  # number of style word vectors
L = ...  # training iterations per style word vector
D = ...  # dimension of each style word vector (the text encoder's word embedding dimension)
 
# Randomly initialize K style word vectors, each of dimension D
{s_i}^K_{i=1} = random_initialize(K, D)
 
# Learn style vectors sequentially 
for i in range(K):
 
  # Train s_i for L iterations
  for iter in range(L):
 
    # Style prompt 
    P^style_i = "a S*_i style of a"
 
    # Compute style features for the current and previously learned styles
    f^style_i = T(P^style_i)
    {f^style_j}^{i-1}_{j=1} = T(P^style_j), j = 1..i-1
 
    # Style diversity loss (absolute cosine similarity between
    # L2-normalized style features; defined as 0 when i = 1)
    L_style = 1/(i-1) * Σ^{i-1}_{j=1} |f^style_i · f^style_j|
 
    # Content prompts
    P^content_m = "[class]_m", m = 1..N
 
    # Compute content features
    {f^content_m}^N_{m=1} = T(P^content_m)
 
    # Style-content prompts 
    P^stylecontent_im = "a S*_i style of a [class]_m"
 
    # Compute style-content features
    {f^stylecontent_im}^N_{m=1} = T(P^stylecontent_im) 
 
    # Content consistency loss (cosine similarities between L2-normalized
    # features; each style-content feature should be most similar to the
    # content feature of its own class)
    L_content = -Σ^N_{m=1} log (exp(f^stylecontent_im · f^content_m) / Σ^N_{n=1} exp(f^stylecontent_im · f^content_n))
 
    # Total prompt loss
    L_prompt = L_style + L_content
 
    # Update s_i using L_prompt and gradient descent
 
# After learning all K style vectors, generate KN style-content features
features = {T(P^stylecontent_im) for all i = 1..K, m = 1..N}
 
# Train classifier on features using ArcFace loss
# Infer on images using pre-trained image encoder
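
To make the final two comments concrete, here is a hedged PyTorch sketch of a linear classifier trained with an ArcFace-style angular-margin loss on the synthesized style-content features, plus inference with a frozen image encoder; the scale/margin values, the optimizer settings, and the class count are illustrative assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcFaceLinear(nn.Module):
    # Linear classifier whose training logits use an additive angular margin.
    def __init__(self, feat_dim, num_classes, scale=5.0, margin=0.5):
        super().__init__()
        self.weight = nn.Parameter(0.01 * torch.randn(num_classes, feat_dim))
        self.scale, self.margin = scale, margin

    def forward(self, feats, labels=None):
        cos = F.normalize(feats, dim=-1) @ F.normalize(self.weight, dim=-1).t()
        if labels is None:                       # inference: plain cosine logits
            return cos
        theta = torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7))
        cos_margin = torch.cos(theta + self.margin)
        onehot = F.one_hot(labels, cos.size(1)).float()
        return self.scale * (onehot * cos_margin + (1 - onehot) * cos)

# Training on the KN synthesized style-content features (text space).
# feats: [K*N, d] text features, labels: [K*N] class indices.
classifier = ArcFaceLinear(feat_dim=512, num_classes=7)   # e.g., 7 classes in PACS
opt = torch.optim.SGD(classifier.parameters(), lr=0.005, momentum=0.9)
# for feats, labels in loader:
#     loss = F.cross_entropy(classifier(feats, labels), labels)
#     opt.zero_grad(); loss.backward(); opt.step()

# Inference on real images with the frozen image encoder (e.g., CLIP's):
# img_feat = clip_model.encode_image(preprocess(image).unsqueeze(0)).float()
# pred = classifier(img_feat).argmax(dim=-1)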