E^2VPT: An Effective and Efficient Approach for Visual Prompt Tuning
Authors: Cheng Han, Qifan Wang, Yiming Cui, Zhiwen Cao, Wenguan Wang, Siyuan Qi, Dongfang Liu
Abstract: As the size of transformer-based models continues to grow, fine-tuning these large-scale pretrained vision models for new tasks has become increasingly parameter-intensive. Parameter-efficient learning has been developed to reduce the number of tunable parameters during fine-tuning. Although these methods show promising results, there is still a significant performance gap compared to full fine-tuning. To address this challenge, we propose an Effective and Efficient Visual Prompt Tuning (E^2VPT) approach for large-scale transformer-based model adaptation. Specifically, we introduce a set of learnable key-value prompts and visual prompts into self-attention and input layers, respectively, to improve the effectiveness of model fine-tuning. Moreover, we design a prompt pruning procedure to systematically prune low importance prompts while preserving model performance, which largely enhances the model’s efficiency. Empirical results demonstrate that our approach outperforms several state-of-the-art baselines on two benchmarks, with considerably low parameter usage (e.g., 0.32% of model parameters on VTAB-1k). Our code is available at E2VPT.
What, Why and How
Here is a summary of the key points from this paper:
What:
- The paper proposes E^2VPT, a novel framework for efficiently fine-tuning large pretrained vision transformers. It introduces learnable visual prompts and key-value prompts into the model architecture.
Why:
- To address limitations of prior work on parameter-efficient fine-tuning, which:
  - does not consider the transformer's architectural design, leading to a performance gap with full fine-tuning
  - often still requires tuning a large number of parameters
How:
- Visual prompts are added to the input sequence of transformer layers.
- Key-value prompts are integrated into the key and value matrices in self-attention.
- A pruning strategy removes redundant/unnecessary prompts to further improve efficiency.
- Experiments on image classification benchmarks show E^2VPT outperforms prior parameter-efficient methods and achieves results competitive with full fine-tuning while using far fewer parameters.
In summary, E^2VPT is an effective and efficient approach to visual prompt tuning of transformers: it combines architecture-aware prompts with a pruning technique to optimize both accuracy and parameter usage, and it demonstrates strong performance on benchmark datasets compared to prior parameter-efficient methods.
Main Contributions
Here are the main contributions of this paper:
- Proposes E^2VPT, a new framework for efficient fine-tuning of vision transformers using visual and key-value prompts.
- Introduces learnable visual prompts in the input sequence and key-value prompts in the self-attention modules of transformers. This architecture-aware design improves effectiveness.
- Develops a pruning strategy to remove redundant prompts, further enhancing efficiency without sacrificing accuracy.
- Conducts experiments on image classification benchmarks showing E^2VPT outperforms prior methods such as VPT and achieves comparable or better accuracy than full fine-tuning while using far fewer parameters.
- Provides extensive ablation studies analyzing the impact of each component of E^2VPT.
- Visualizes learned embeddings to demonstrate the effectiveness of E^2VPT in separating classes.
- Discusses the algorithmic generalization of E^2VPT to natural language processing tasks.
In summary, the key contributions are: 1) the novel E^2VPT framework for efficient transformer fine-tuning using architecture-aware prompts; 2) a pruning strategy to optimize efficiency; 3) extensive experiments demonstrating effectiveness and efficiency compared to prior methods; and 4) ablation studies and visualizations providing further insight.
Method Section
Here is a summary of the method section from the paper:
The paper proposes E^2VPT, a framework for effective and efficient visual prompt tuning of transformers.
It has two main components:
- Effective prompting:
  - Introduces visual prompts in the input sequence of each transformer layer. These learn task-specific embeddings.
  - Adds key-value prompts by concatenating small learnable matrices to the key and value matrices in self-attention. This helps capture new attention patterns (a minimal sketch follows this list).
  - Shares parameters of the key-value prompts within each layer for efficiency.
- Efficient prompting:
  - Applies a pruning technique to remove redundant input prompts (a sketch follows the summary below):
    - Token-wise pruning: calculates an importance score for each prompt token and removes tokens with low scores.
    - Segment-wise pruning: divides the remaining prompt tokens into segments and prunes segments with negative importance scores.
    - Rewinding: retrains the soft-filtered prompts after pruning.
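To make the key-value prompting concrete, here is a minimal PyTorch sketch of a single-head self-attention block with learnable key and value prompts appended to K and V. It illustrates the idea rather than the authors' implementation; the class and parameter names (PromptedSelfAttention, num_kv_prompts), the single-head simplification, and the initialization scale are assumptions.

import torch
import torch.nn as nn

class PromptedSelfAttention(nn.Module):
    """Single-head self-attention with learnable key-value prompts (illustrative sketch)."""
    def __init__(self, dim, num_kv_prompts):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        # Learnable key/value prompts, shared within the layer (hypothetical init scale)
        self.key_prompts = nn.Parameter(torch.randn(num_kv_prompts, dim) * 0.02)
        self.value_prompts = nn.Parameter(torch.randn(num_kv_prompts, dim) * 0.02)
        self.scale = dim ** -0.5

    def forward(self, x):                                        # x: (B, N, D)
        B = x.size(0)
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # Append prompts along the sequence dimension of K and V only
        k = torch.cat([k, self.key_prompts.expand(B, -1, -1)], dim=1)    # (B, N+P, D)
        v = torch.cat([v, self.value_prompts.expand(B, -1, -1)], dim=1)  # (B, N+P, D)
        attn = (q @ k.transpose(-2, -1)) * self.scale            # (B, N, N+P)
        attn = attn.softmax(dim=-1)
        return attn @ v                                          # (B, N, D)

Only K and V are lengthened, so the query sequence and the layer's output shape stay the same; the prompts simply add extra positions that existing tokens can attend to, which is how new attention patterns are introduced without touching the frozen backbone weights.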
In summary, the method improves effectiveness via architecture-aware visual and key-value prompts, and efficiency by pruning unnecessary prompts before retraining. The overall framework enables effective and efficient tuning of transformers for new tasks.
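For the efficient-prompting step, the sketch below shows one plausible way to score and soft-prune visual prompts. The first-order (gradient-times-value) sensitivity score, the equal-width segment layout, and all function names are assumptions made for illustration, not the paper's exact formulation.

import torch

def token_scores(prompts: torch.Tensor, grads: torch.Tensor) -> torch.Tensor:
    """Token-wise sensitivity proxy |p · dL/dp| per visual prompt token.
    prompts, grads: (num_prompts, dim) -> (num_prompts,)."""
    return (prompts * grads).sum(dim=-1).abs()

def segment_scores(prompts: torch.Tensor, grads: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Segment-wise sensitivity: split each token's embedding into equal segments
    (dim must be divisible by num_segments) and score each segment separately."""
    P, D = prompts.shape
    return (prompts * grads).view(P, num_segments, D // num_segments).sum(dim=-1)

def prune_masks(prompts: torch.Tensor, grads: torch.Tensor,
                prune_rate: float, num_segments: int) -> torch.Tensor:
    """Combine token-wise and segment-wise pruning into one (num_prompts, dim) soft mask."""
    P, D = prompts.shape
    # Token-wise: keep the top-(1 - prune_rate) fraction of prompt tokens
    t_scores = token_scores(prompts, grads)
    num_keep = max(1, int(round(P * (1.0 - prune_rate))))
    token_mask = torch.zeros(P, device=prompts.device)
    token_mask[t_scores.topk(num_keep).indices] = 1.0
    # Segment-wise: zero out segments with negative sensitivity scores
    seg_mask = (segment_scores(prompts, grads, num_segments) > 0).float()   # (P, S)
    seg_mask = seg_mask.repeat_interleave(D // num_segments, dim=1)         # (P, D)
    return token_mask.unsqueeze(-1) * seg_mask

Rewinding then amounts to multiplying the prompts (or their gradients) by this mask and continuing training, so the surviving prompt positions are re-trained for the remaining steps.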
High-Level Pseudo Code
Here is the high-level pseudo code for the E^2VPT method:
# E^2VPT Framework
# Input:
#   T: pretrained transformer backbone
#   D: dataset for the new task
# Hyperparameters:
#   pv_len: number of visual prompt tokens
#   pkv_len: number of key-value prompt tokens
#   prune_rate: fraction of visual prompt tokens to prune

# Initialize visual prompts pv and key-value prompts pk, vk
pv = init_prompts(pv_len)
pk, vk = init_prompts(pkv_len), init_prompts(pkv_len)

# Effective prompting: attach prompts to the frozen backbone
for layer in T.layers:
    layer.input = concat(pv, layer.input)    # prepend visual prompts to the token sequence
    layer.attn_k = concat(layer.attn_k, pk)  # append key prompts to the key matrix
    layer.attn_v = concat(layer.attn_v, vk)  # append value prompts to the value matrix

# Efficient prompting: prune low-importance visual prompts
mask = prune(pv, prune_rate)                 # importance-based pruning mask
pv_pruned = pv * mask                        # soft-prune (zero out) masked prompt tokens

# Fine-tune (rewind): only the prompts and the classifier are trained
for x, y in D:
    logits = T(pv_pruned, x)
    loss = cross_entropy(logits, y)
    update(pv_pruned, pk, vk, classifier)    # backbone stays frozen

# Evaluate on the new task
accuracy = evaluate(T, D_test)
In summary, the key steps are:
- Initialize visual and key-value prompts
- Incorporate prompts into transformer architecture
- Prune redundant visual prompts
- Fine-tune prompts and classifier on new task
- Evaluate adapted model on test set
Detailed Pseudo Code
Here is more detailed pseudo code for implementing the E^2VPT method:
import torch
import torch.nn as nn

# Initialize prompts (dim = hidden size of the backbone)
def init_prompts(length, dim):
    return nn.Parameter(torch.randn(length, dim))

pv = init_prompts(pv_len, dim)     # visual prompts
pk = init_prompts(pkv_len, dim)    # key prompts
vk = init_prompts(pkv_len, dim)    # value prompts
# Add prompts to the transformer (schematic: layer.attn_k / layer.attn_v stand for the
# layer's key / value projections, and layer(h, k, v) for its attention + MLP block)
class E2VPT_Transformer(nn.Module):
    def __init__(self, backbone, pv, pk, vk):
        super().__init__()
        self.transformer = backbone    # frozen pretrained backbone
        self.pv = pv                   # visual prompts
        self.pk = pk                   # key prompts
        self.vk = vk                   # value prompts

    def forward(self, x):
        h = x                          # (batch, seq_len, dim)
        B = h.size(0)
        for layer in self.transformer.layers:
            # Prepend visual prompts to the input token sequence
            h = torch.cat([self.pv.expand(B, -1, -1), h], dim=1)
            # Append key-value prompts along the sequence dimension of K and V
            k = torch.cat([layer.attn_k(h), self.pk.expand(B, -1, -1)], dim=1)
            v = torch.cat([layer.attn_v(h), self.vk.expand(B, -1, -1)], dim=1)
            h = layer(h, k, v)         # self-attention with prompted K/V, then MLP
        return h                       # output features
# Pruning: token-wise importance from gradients (assumes loss.backward() has already run)
def prune(prompt, rate):
    score = (prompt * prompt.grad).sum(dim=-1).abs()    # first-order sensitivity per prompt token
    threshold = torch.quantile(score, rate)             # prune the lowest `rate` fraction
    mask = (score > threshold).float().unsqueeze(-1)    # (num_prompts, 1), broadcastable
    return mask
# Fine-tuning loop with pruning and rewinding
# (backbone, loader, criterion, optim, num_epochs, prune_epoch are assumed to be defined)
model = E2VPT_Transformer(backbone, pv, pk, vk)
mask = torch.ones(pv_len, 1)                       # all visual prompt tokens active at first
for epoch in range(num_epochs):
    for x, y in loader:
        logits = model(x)                          # forward pass through the prompted backbone
        loss = criterion(logits, y)
        optim.zero_grad()
        loss.backward()
        with torch.no_grad():
            model.pv.grad *= mask                  # pruned prompt tokens receive no updates
        optim.step()                               # only prompts and the classifier are optimized
    if epoch == prune_epoch:                       # prune once, then rewind (keep training)
        mask = prune(model.pv, prune_rate)
        model.pv.data *= mask                      # soft-prune the low-importance tokens
The key aspects are:
- Initializing the visual and key-value prompts
- Incorporating the prompts into the transformer layers
- Pruning prompt tokens by thresholding gradient-based importance scores
- Updating the prompts with masked gradients and rewinding (retraining) the surviving prompts
This implements the overall E^2VPT framework for efficient transformer fine-tuning.
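As a usage note, the sketch below shows how one might freeze the backbone and hand only the prompts and the classification head to the optimizer, which is where the parameter savings come from. The helper name, the AdamW choice, and the hyperparameter values are assumptions for illustration, not taken from the paper.

import torch

def build_prompt_optimizer(model, classifier, lr=1e-3, weight_decay=1e-4):
    """Freeze the pretrained backbone and optimize only the prompts and the classifier head
    (hypothetical helper; values are placeholders)."""
    for p in model.transformer.parameters():
        p.requires_grad_(False)                              # backbone stays frozen
    tunable = [model.pv, model.pk, model.vk] + list(classifier.parameters())
    n_tunable = sum(p.numel() for p in tunable)
    n_frozen = sum(p.numel() for p in model.transformer.parameters())
    print(f"tunable: {n_tunable:,} / {n_tunable + n_frozen:,} parameters "
          f"({100.0 * n_tunable / (n_tunable + n_frozen):.2f}%)")
    return torch.optim.AdamW(tunable, lr=lr, weight_decay=weight_decay)

For a ViT-B/16-scale backbone, the tunable fraction ends up well under 1%, in line with the parameter usage reported in the abstract.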