Approximated Prompt Tuning for Vision-Language Pre-trained Models

Authors: Qiong Wu, Shubin Huang, Yiyi Zhou, Pingyang Dai, Annan Shu, Guannan Jiang, Rongrong Ji

Abstract: Prompt tuning is a parameter-efficient way to deploy large-scale pre-trained models to downstream tasks by adding task-specific tokens. In terms of vision-language pre-trained (VLP) models, prompt tuning often requires a large number of learnable tokens to bridge the gap between the pre-training and downstream tasks, which greatly exacerbates the already high computational overhead. In this paper, we revisit the principle of prompt tuning for Transformer-based VLP models, and reveal that the impact of soft prompt tokens can be actually approximated via independent information diffusion steps, thereby avoiding the expensive global attention modeling and reducing the computational complexity to a large extent. Based on this finding, we propose a novel Approximated Prompt Tuning (APT) approach towards efficient VL transfer learning. To validate APT, we apply it to two representative VLP models, namely ViLT and METER, and conduct extensive experiments on a bunch of downstream tasks. Meanwhile, the generalization of APT is also validated on CLIP for image classification and StableDiffusion for text-to-image generation. The experimental results not only show the superior performance gains and computation efficiency of APT against the conventional prompt tuning methods, e.g., +7.01% accuracy and -82.30% additional computation overhead on METER, but also confirm its merits over other parameter-efficient transfer learning approaches.

What, Why and How

Here is a summary of the key points in this paper:

What:

  • This paper proposes a novel Approximated Prompt Tuning (APT) method for efficiently adapting vision-language pre-trained (VLP) models like ViLT and METER to downstream tasks.

Why:

  • Conventional prompt tuning methods for VLP models require a large number of trainable tokens, which greatly increases the computational overhead due to self-attention.
  • The information diffusion from prompt tokens to the input sequence is also inefficient in existing prompt tuning methods.

How:

  • APT separates the prompt tokens from expensive global self-attention.
  • It approximates the impact of prompts via independent low-rank aggregation steps rather than global attention.
  • This allows more efficient information diffusion from the prompts to the input sequence while reducing computational complexity.
  • Experiments show that APT outperforms conventional prompt tuning methods in both performance and efficiency.
  • APT also achieves better results than other parameter-efficient transfer learning approaches on VLP models.

In summary, this paper proposes a more efficient prompt tuning approach called APT to address the high computational overhead and inefficient adaptation of existing prompt tuning methods for VLP models. APT approximates the impact of the prompt tokens independently of global self-attention, enabling more effective and efficient adaptation.
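
To make the cost argument concrete, here is a back-of-the-envelope comparison of the extra attention cost per Transformer layer. The sizes, the multiply-accumulate counting, and the omission of the FFN and multi-head details are all simplifying assumptions for illustration, not figures from the paper:

# Rough per-layer attention cost (multiply-accumulate counts); all values are illustrative.
n, p, d, r = 400, 100, 768, 4  # input length, prompt length, hidden size, low rank

def attn_cost(seq_len, d):
  # QKV + output projections, plus the two (seq_len x seq_len x d) attention matmuls
  return 4 * seq_len * d * d + 2 * seq_len * seq_len * d

base = attn_cost(n, d)                      # frozen backbone, no prompts
prompt = attn_cost(n + p, d)                # conventional prompt tuning: prompts join global attention
apt = base + 2 * n * p * d + 2 * p * d * r  # APT: independent prompt aggregation + low-rank transform

print(f"prompt tuning adds {(prompt - base) / base:.1%} attention cost")  # ~31%
print(f"APT adds {(apt - base) / base:.1%} attention cost")               # ~5%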

Main Contributions

Here are the main contributions of this paper:

  • Identifies the key challenges of applying prompt tuning to vision-language pre-trained (VLP) models, including excessive computation overhead and inefficient prompt tuning.

  • Proposes a novel Approximated Prompt Tuning (APT) method that approximates the influence of prompt tokens via independent low-rank aggregation steps rather than expensive global attention. This allows more efficient information diffusion for adaptation.

  • Validates APT on two representative VLP models (ViLT and METER) on three vision-language benchmarks. Shows superior performance and efficiency gains over conventional prompt tuning methods.

  • Demonstrates better performance of APT over other parameter-efficient transfer learning approaches on the VLP models.

  • Generalizes APT to CLIP for image classification and StableDiffusion for text-to-image generation, confirming its effectiveness beyond the VLP models initially tested.

In summary, the key contributions are: 1) identifying inefficiencies in existing prompt tuning for VLP models; 2) proposing a more efficient approximation method called APT; 3) demonstrating performance improvements over prompt tuning and other methods on VLP models; and 4) generalizing APT to other models like CLIP and StableDiffusion.

Method Section

Here is a summary of the method section:

The key idea of Approximated Prompt Tuning (APT) is to separate the impacts of prompt tokens from expensive global self-attention.
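
Why such a separation is possible follows from a standard decomposition of softmax attention (a known identity, not a result specific to this paper): for any input token xᵢ, attention over the concatenated sequence [X; P] splits exactly into the original attention over X plus an aggregation over the prompts:

Attn(xᵢ, [X; P]) = (1 − λᵢ) ⋅ Attn(xᵢ, X) + λᵢ ⋅ Attn(xᵢ, P)

Where λᵢ is the total softmax mass that xᵢ assigns to the prompt keys. APT keeps the first term as the frozen self-attention SA(X) and approximates the second term with a cheap, independent computation.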

Specifically, APT first approximates the information aggregation from prompt tokens via:

ΔX' = σ(X(P'W1W2 + P')ᵀ)P'

Where P' denotes the prompt tokens, W1 and W2 are low-rank matrices that transform the prompts from the V space to the K space, and σ is the softmax function.

This allows prompt token impacts to be approximated independently without global attention.
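
As a quick sanity check of this step, here is a minimal PyTorch sketch (sizes are illustrative; P, W1, and W2 stand for the P', W1, and W2 above) showing that every input token attends only to the p prompt tokens:

import torch

n, p, d, r = 40, 10, 768, 4   # illustrative sizes
X = torch.randn(n, d)         # input sequence features
P = torch.randn(p, d)         # prompt tokens (P')
W1, W2 = torch.randn(d, r), torch.randn(r, d)

attn = torch.softmax(X @ (P @ W1 @ W2 + P).T, dim=-1)  # (n, p): weights over prompts only
dX = attn @ P                                          # (n, d): information diffused from prompts
print(attn.shape, dX.shape)                            # torch.Size([40, 10]) torch.Size([40, 768])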

To merge the prompt information with the original self-attention output, APT rescales this aggregation as:

ΔX = α ⋅ ψ(X(P'W1W2 + P')ᵀ)P'

Where ψ is the ReLU activation function and α is the maximum of the prompt attention weights.

Finally, the overall APT update is:

X' = eˢ ⋅ (SA(X) + α ⋅ ψ(X(P'W1W2 + P')ᵀ)P')

Where s is a learned scaling factor and SA(X) is the standard self-attention over the input sequence.

In summary, APT approximates the impact of prompt tokens via independent low-rank aggregation and activation, avoiding expensive global self-attention while allowing more effective information diffusion for adaptation.

High-Level Pseudo Code

Here is the high-level pseudo code for the proposed Approximated Prompt Tuning (APT) method:

# Input:
# X: input sequence features
# P: prompt tokens
# Learnable: W1, W2 (low-rank matrices), s (scale)
 
# Self-attention module
def SA(X):
  # Original (frozen) self-attention on X
  A_X = self_attention(X)
  return A_X
 
# APT module  
def APT(X, P):
    
  # Low-rank approximation of the key-space prompts
  K_p = P @ W1 @ W2 + P

  # Attention weights from input tokens to prompts (no global attention)
  A = X @ transpose(K_p)

  # Scale: maximum of the softmax-normalized prompt attention weights
  α = max(softmax(A))

  # Information diffusion from prompts to inputs
  ΔX = α * relu(A) @ P
 
  return ΔX
 
# Full APT update
def forward(X, P):
 
  # Self-attention outputs
  A_X = SA(X)  
  
  # APT outputs   
  ΔX = APT(X, P)
  
  # Learnable scale
  s = learnable_scalar
  
  # Final outputs
  X' = exp(s) * (A_X + ΔX)
 
  return X'

In summary, the key steps are:

  1. Perform standard self-attention on inputs
  2. Approximate prompt impacts via APT module
  3. Merge self-attention and APT outputs with learnable scale
  4. Return final adapted features

Detailed Pseudo Code

Here is a more detailed pseudo code implementation of the Approximated Prompt Tuning (APT) method:

import torch
import torch.nn as nn

# Hyperparameters (illustrative sizes)
n = 40   # Input sequence length
p = 10   # Number of prompt tokens
d = 768  # Hidden dimension
r = 4    # Low-rank size

# Input embeddings
X = torch.randn(n, d)  # Input sequence features

# Learnable APT parameters
P = nn.Parameter(torch.randn(p, d))          # Prompt tokens
W1 = nn.Parameter(torch.randn(d, r) * 0.02)  # Low-rank down-projection (small init)
W2 = nn.Parameter(torch.randn(r, d) * 0.02)  # Low-rank up-projection (small init)
s = nn.Parameter(torch.tensor(0.0))          # Learnable log-scale

# Self-attention module (stand-in for the frozen pre-trained attention)
class SA(nn.Module):
  def __init__(self, d, num_heads=12):
    super().__init__()
    self.attn = nn.MultiheadAttention(d, num_heads)

  def forward(self, X):
    A_X, _ = self.attn(X, X, X)  # Original self-attention on X
    return A_X

# APT module
class APT(nn.Module):
  def forward(self, X, P):
    # Low-rank approximation of the key-space prompts
    K_p = P @ W1 @ W2 + P                  # (p, d)

    # Attention weights from input tokens to prompts only
    A = X @ K_p.T                          # (n, p)

    # Scale: maximum of the softmax-normalized prompt attention weights
    alpha = torch.softmax(A, dim=-1).max()

    # Information diffusion from prompts to inputs
    dX = alpha * torch.relu(A) @ P         # (n, d)
    return dX

# Full model
class Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.sa = SA(d)
    self.apt = APT()

  def forward(self, X, P):
    A_X = self.sa(X)
    dX = self.apt(X, P)
    return torch.exp(s) * (A_X + dX)

# Initialize the model
model = Model()
out = model(X, P)  # Adapted features, shape (n, d)
# Optimizer, loss, and training loop: see the sketch below

In summary, the key aspects are:

  • Instantiating the SA and APT modules
  • Passing inputs X, P through the modules
  • Merging the outputs with a learnable scalar s
  • Initializing, optimizing, and training the full model
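
As a follow-up to the listing above, a minimal optimization sketch is given below. It assumes the definitions from the detailed pseudo code (model, X, P, W1, W2, s), keeps the pre-trained self-attention frozen, trains only the APT parameters, and uses a placeholder loss in place of the actual downstream task loss:

# Freeze the pre-trained self-attention; only APT parameters receive gradients
for param in model.sa.parameters():
  param.requires_grad = False

# Only the prompt tokens, low-rank matrices, and scale are trained
optimizer = torch.optim.AdamW([P, W1, W2, s], lr=1e-3)

for step in range(100):        # placeholder training loop
  out = model(X, P)            # adapted features, shape (n, d)
  loss = out.pow(2).mean()     # placeholder; use the downstream task loss in practice
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()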
