Learned Thresholds Token Merging and Pruning for Vision Transformers
Authors: Maxim Bonnaerens, Joni Dambre
Abstract: Vision transformers have demonstrated remarkable success in a wide range of computer vision tasks in recent years. However, their high computational costs remain a significant barrier to their practical deployment. In particular, the complexity of transformer models is quadratic with respect to the number of input tokens. Therefore, techniques that reduce the number of input tokens that need to be processed have been proposed. This paper introduces Learned Thresholds token Merging and Pruning (LTMP), a novel approach that leverages the strengths of both token merging and token pruning. LTMP uses learned threshold masking modules that dynamically determine which tokens to merge and which to prune. We demonstrate our approach with extensive experiments on vision transformers on the ImageNet classification task. Our results show that LTMP achieves state-of-the-art accuracy across reduction rates while requiring only a single fine-tuning epoch, which is an order of magnitude faster than previous methods. Code is available at ltmp.
What, Why and How
Here is a summary of the key points from this paper:
What:
- The paper introduces Learned Thresholds token Merging and Pruning (LTMP), a new method to reduce the computational cost of vision transformers.
Why:
- Vision transformers have high computational costs due to their quadratic complexity with respect to the number of input tokens. Reducing tokens is an effective way to reduce cost.
- Previous token reduction methods like pruning discard information while merging preserves it. This paper combines both to get the benefits of each.
How:
- LTMP has learned threshold modules that determine which tokens to prune and which to merge in each layer.
- Thresholds are learned using a budget-aware loss function that targets a specific FLOPs reduction. This allows creation of models of any desired size.
- LTMP needs just 2 learnable parameters per layer, so it converges in a single epoch, much faster than other methods that need 30-100 epochs (see the note after this list).
- Experiments on ImageNet show LTMP matches or exceeds the accuracy of other methods across model sizes while needing far fewer fine-tuning epochs. It also outperforms using merging alone or pruning alone.
In summary, LTMP introduces an efficient way to create reduced vision transformers of any size by combining pruning and merging with learned thresholds, reaching high accuracy with minimal fine-tuning.
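To make the overhead concrete, here is a quick back-of-the-envelope count, assuming a standard 12-block backbone such as ViT-B or DeiT-S (the block count is illustrative context, not a figure taken from the paper):
num_blocks = 12           # e.g. ViT-B and DeiT-S both use 12 transformer blocks
thresholds_per_block = 2  # one pruning threshold + one merging threshold
trainable_params = num_blocks * thresholds_per_block  # = 24 parameters fine-tuned by LTMP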
Main Contributions
The main contributions of this paper are:
- Proposes combining token merging and token pruning to reduce vision transformers, gaining the benefits of both techniques.
- Introduces learned threshold masking modules with only 2 learnable parameters per transformer block, which allows the merging and pruning thresholds to be learned.
- Defines a budget-aware training loss function with a target FLOPs reduction. This enables creating models of any desired size by distributing the reduction across layers.
- Achieves state-of-the-art accuracy with only 1 epoch of fine-tuning, which is orders of magnitude faster convergence than prior methods.
- Comprehensive experiments analyzing design choices and comparing to prior work on the ImageNet classification task using DeiT and ViT models.
- Visualizations showing how merging and pruning are distributed across layers and how they reduce patches/tokens for sample images.
In summary, the key contributions are efficiently combining pruning and merging with learned thresholds, the budget-aware loss function, much faster convergence, extensive comparisons and analysis, and visualizations.
Method Section
Here is a summary of the method section from the paper:
- LTMP adds learned threshold token merging (LTM) and learned threshold token pruning (LTP) components between the multi-head self-attention (MSA) and MLP modules in each transformer block.
- LTP computes an importance score for each token; a learned threshold determines which tokens to prune by masking out tokens whose importance does not exceed the threshold (a sketch follows this section's summary).
- Masking is made differentiable using a straight-through estimator, and the attention mechanism is modified to ignore masked (pruned) tokens.
- LTM is similar but uses a token similarity score and threshold to determine which tokens to merge by averaging (also sketched below).
- The overall training objective combines the cross-entropy loss with a regularization loss based on the difference between the target and the actual FLOPs reduction (sketched below).
- The FLOPs reduction factor is computed from the layer-wise token masks and the FLOPs formulas of the transformer components.
- Only the thresholds are updated during training while all other parameters are kept frozen, which allows convergence in a single epoch.
- The mean column attention score is used as the importance score for pruning. Merging is done before pruning, as merging is more accurate.
- Experiments use image classification on ImageNet with ViT and DeiT models. Hyperparameters are tuned on a small validation set.
In summary, the key aspects of the method are the learned threshold modules for merging and pruning, the modified attention mechanism, the budget-aware loss function, and training for just 1 epoch to fine-tune the thresholds.
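To make the pruning step concrete, here is a minimal PyTorch sketch of a learned-threshold masking module, assuming the mean column attention score as the importance measure and a sigmoid surrogate for the straight-through estimator; the module name, tensor shapes, and the exact surrogate are assumptions of this sketch, not the paper's reference implementation:
import torch
import torch.nn as nn

class LearnedThresholdPruning(nn.Module):
    """Keeps tokens whose importance exceeds a learned threshold (LTP-style sketch)."""

    def __init__(self):
        super().__init__()
        self.threshold = nn.Parameter(torch.zeros(1))  # the single learnable parameter

    def forward(self, attn, prev_mask):
        # attn: (B, H, N, N) attention probabilities; prev_mask: (B, N), 1 = token still active
        # Importance of token j = attention paid to j, averaged over heads and queries
        importance = attn.mean(dim=1).mean(dim=1)          # (B, N) mean column attention
        soft = torch.sigmoid(importance - self.threshold)  # differentiable surrogate
        hard = (importance > self.threshold).float()       # hard keep/prune decision
        mask = hard + soft - soft.detach()                 # straight-through estimator
        mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, 1:]], dim=1)  # never prune the class token
        return prev_mask * mask                            # previously pruned tokens stay pruned
The modified attention mentioned above would then add a large negative bias to the attention logits of masked tokens, so that pruned tokens no longer influence the remaining ones.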
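For the merging side, the sketch below is a simplified, hard-merging illustration for a single image, loosely following ToMe-style bipartite matching but with a learned threshold instead of a fixed merge count; it is not the paper's exact differentiable procedure, and the function name and plain averaging scheme are assumptions:
import torch
import torch.nn.functional as F

def merge_by_threshold(x, threshold):
    """Merge sufficiently similar token pairs by averaging. x: (N, D) tokens of one image."""
    metric = F.normalize(x, dim=-1)              # cosine-similarity features
    a, b = metric[::2], metric[1::2]             # alternating split into sets A and B (class-token handling omitted)
    sim = a @ b.T                                # (Na, Nb) pairwise similarities
    best_sim, best_idx = sim.max(dim=-1)         # most similar B-partner for each A-token
    merge = best_sim > threshold                 # merge only pairs above the learned threshold

    xa, xb = x[::2].clone(), x[1::2].clone()
    for i in torch.nonzero(merge).flatten():     # average each merged A-token into its partner
        j = best_idx[i]
        xb[j] = (xb[j] + xa[i]) / 2              # plain average; size-weighted averaging omitted
    return torch.cat([xa[~merge], xb], dim=0)    # unmerged A-tokens plus (possibly merged) B-tokens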
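Finally, the budget-aware objective can be sketched as follows. The per-block FLOPs expressions are the standard ViT estimates (MSA ≈ 4ND² + 2N²D, MLP ≈ 8ND² with hidden dimension 4D); these formulas, the helper names, and the exact form of r_FLOPs are assumptions of this sketch rather than the paper's exact accounting. A compute_flops_reduction(model) helper, as used in the pseudo code below, would gather the per-layer token counts from the masks and evaluate something like this:
def block_flops(n_tokens, dim):
    """Standard per-block ViT FLOPs estimate: MSA plus MLP."""
    msa = 4 * n_tokens * dim ** 2 + 2 * n_tokens ** 2 * dim  # QKV/out projections + attention
    mlp = 8 * n_tokens * dim ** 2                            # two linear layers, hidden dim 4*dim
    return msa + mlp

def flops_reduction(token_counts, n_full, dim):
    """r_FLOPs: FLOPs of the reduced model as a fraction of the unreduced model.

    token_counts[l] is the (soft) number of tokens entering block l, obtained by summing
    the differentiable masks, so the whole quantity can be backpropagated to the thresholds.
    """
    reduced = sum(block_flops(n, dim) for n in token_counts)
    full = len(token_counts) * block_flops(n_full, dim)
    return reduced / full

# Budget-aware objective: push the achieved reduction towards the target r_target
# loss = ce_loss + lambda_reg * (r_target - r_flops) ** 2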
High-Level Pseudo Code
Here is the high-level pseudo code for the LTMP method:
# Input: pretrained vision transformer model
# Hyperparameters: reduction target r_target, loss weight lambda_reg

# Add LTP and LTM modules to each transformer block
for block in model.blocks:
    block.add_module('LTP', LTP())
    block.add_module('LTM', LTM())

# Only the LTP/LTM thresholds are trainable; all pretrained weights stay frozen
optimizer = Optimizer(get_ltp_ltm_thresholds(model))  # e.g. AdamW over the thresholds only

# Training loop (a single epoch suffices)
for image, label in dataloader:
    # Forward pass (the LTM/LTP modules merge and prune tokens on the fly)
    output = model(image)

    # Compute loss: cross-entropy plus budget-aware FLOPs regularization
    ce_loss = cross_entropy(output, label)
    flops_reduction = compute_flops_reduction(model)
    reg_loss = (r_target - flops_reduction) ** 2
    loss = ce_loss + lambda_reg * reg_loss

    # Update only the LTP and LTM thresholds
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# After 1 epoch, remove unnecessary parameters and convert to an inference model
The key steps are:
- Add learned pruning (LTP) and merging (LTM) modules to transformer blocks
- Forward pass to compute cross-entropy and regularization loss
- Update only the thresholds in the LTP and LTM modules (a minimal sketch of collecting these threshold parameters follows below)
- After 1 epoch, export model for inference
This shows how LTMP modifies the model, the budget-aware loss, and efficient fine-tuning of just the thresholds.
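The helper get_ltp_ltm_thresholds above is left undefined. A minimal PyTorch sketch, assuming each LTP/LTM module registers its learnable scalar as a parameter named threshold, could look like this:
import torch.nn as nn

def get_ltp_ltm_thresholds(model: nn.Module):
    """Return only the learned threshold parameters (2 per block) and freeze everything else."""
    thresholds = []
    for name, param in model.named_parameters():
        if name.endswith("threshold"):   # parameters registered by the LTP/LTM modules
            param.requires_grad_(True)
            thresholds.append(param)
        else:
            param.requires_grad_(False)  # pretrained backbone weights stay frozen
    return thresholds
The optimizer in the loop above is then built over exactly this list, so a backward pass only produces gradients for the thresholds.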
Detailed Pseudo Code
Here is more detailed pseudo code to implement the LTMP method:
import torch
import torch.nn as nn

# model: a pretrained vision transformer (ViT/DeiT), e.g. loaded via timm

# LTP module: learned threshold token pruning
class LTP(nn.Module):
    def __init__(self):
        super().__init__()
        self.threshold = nn.Parameter(torch.zeros(1))  # single learnable pruning threshold

    def forward(self, x, mask):
        # Compute token importance scores (mean column attention score)
        importance_scores = self.get_importance_scores(x)
        # Keep tokens whose importance exceeds the learned threshold;
        # a straight-through estimator makes the hard decision differentiable
        threshold_mask = get_threshold_mask(importance_scores, self.threshold)
        # Update the running token mask; attention ignores masked (pruned) tokens
        mask = update_mask(mask, threshold_mask)
        return x, mask

# LTM module: learned threshold token merging
class LTM(nn.Module):
    def __init__(self):
        ...  # same structure as LTP: a single learnable threshold

    def forward(self, x, mask):
        # Compute token similarity scores
        similarity_scores = self.get_similarity_scores(x)
        # Tokens whose similarity exceeds the learned threshold are selected for merging
        threshold_mask = get_threshold_mask(similarity_scores, self.threshold)
        # Merge the selected tokens by averaging
        x = merge_tokens(x, threshold_mask)
        return x, mask

# Compute the achieved FLOPs reduction from the layer-wise token masks
def compute_flops_reduction(model):
    # Use the FLOPs formulas from the paper to compute r_FLOPs
    return r_FLOPs

# Training loop: only the LTP and LTM thresholds are updated
optimizer = torch.optim.AdamW(get_ltp_ltm_thresholds(model))  # optimizer choice is illustrative
for img, label in dataloader:
    output = model(img)
    ce_loss = cross_entropy(output, label)
    r_FLOPs = compute_flops_reduction(model)
    reg_loss = (r_target - r_FLOPs) ** 2
    loss = ce_loss + lambda_reg * reg_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
This sketch shows the key components needed to implement LTMP: the learned threshold modules, the FLOPs-based regularization loss, and the training loop that fine-tunes only the thresholds. A hypothetical usage example follows.
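For context, here is a hypothetical end-to-end usage example. The timm model name, the hyperparameter values, and the note that each block's forward must be patched are assumptions for illustration, not details from the paper's code release:
import timm
import torch

model = timm.create_model("deit_small_patch16_224", pretrained=True)
for block in model.blocks:                   # timm ViT/DeiT blocks are stored in model.blocks
    block.add_module("LTP", LTP())
    block.add_module("LTM", LTM())
# Note: registering the modules is not enough; each block's forward pass must also be
# patched so that LTM and LTP run between the MSA and MLP sub-layers.

r_target, lambda_reg = 0.5, 10.0             # illustrative values, not the paper's settings
optimizer = torch.optim.AdamW(get_ltp_ltm_thresholds(model), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

for img, label in train_loader:              # a single epoch of fine-tuning per the paper
    output = model(img)
    loss = criterion(output, label) + lambda_reg * (r_target - compute_flops_reduction(model)) ** 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()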