Learned Thresholds Token Merging and Pruning for Vision Transformers

Authors: Maxim Bonnaerens, Joni Dambre

Abstract: Vision transformers have demonstrated remarkable success in a wide range of computer vision tasks in recent years. However, their high computational costs remain a significant barrier to their practical deployment. In particular, the complexity of transformer models is quadratic with respect to the number of input tokens. Therefore, techniques that reduce the number of input tokens that need to be processed have been proposed. This paper introduces Learned Thresholds token Merging and Pruning (LTMP), a novel approach that leverages the strengths of both token merging and token pruning. LTMP uses learned threshold masking modules that dynamically determine which tokens to merge and which to prune. We demonstrate our approach with extensive experiments on vision transformers on the ImageNet classification task. Our results show that LTMP achieves state-of-the-art accuracy across reduction rates while requiring only a single fine-tuning epoch, which is an order of magnitude faster than previous methods. Code is available at ltmp.

What, Why and How

Here is a summary of the key points from this paper:

What:

  • The paper introduces Learned Thresholds token Merging and Pruning (LTMP), a new method to reduce the computational cost of vision transformers.

Why:

  • Vision transformers have high computational costs due to their quadratic complexity with respect to the number of input tokens. Reducing tokens is an effective way to reduce cost.

  • Previous token reduction methods like pruning discard information while merging preserves it. This paper combines both to get the benefits of each.

How:

  • LTMP has learned threshold modules that determine which tokens to prune and which to merge in each layer.

  • Thresholds are learned using a budget-aware loss function that targets a specific FLOPs reduction. This allows creation of models of any desired size.

  • LTMP needs just 2 learnable parameters per layer, so it converges in a single epoch, much faster than other methods, which need 30-100 epochs.

  • Experiments on ImageNet show that LTMP matches or exceeds the accuracy of other methods across model sizes while needing far fewer fine-tuning epochs. It also outperforms using merging alone or pruning alone.

In summary, LTMP introduces an efficient way to create reduced vision transformers of any size by combining pruning and merging with learned thresholds, reaching high accuracy with minimal fine-tuning.

Main Contributions

The main contributions of this paper are:

  • Proposes combining token merging and token pruning to reduce vision transformers, gaining the benefits of both techniques.

  • Introduces learned threshold masking modules with only 2 learnable parameters per transformer block. This allows thresholds to be learned for merging and pruning.

  • Defines a budget-aware training loss function with a target FLOPs reduction. This enables creating models of any desired size by distributing reduction across layers.

  • Achieves state-of-the-art accuracy with only one epoch of fine-tuning, an order of magnitude faster convergence than prior methods.

  • Comprehensive experiments analyzing design choices and comparing against prior work on the ImageNet classification task using DeiT and ViT models.

  • Visualizations showing how merging and pruning are distributed across the layers and how they reduce patches/tokens for sample images.

In summary, the key contributions are efficiently combining pruning and merging with learned thresholds, the budget-aware loss function, much faster convergence, extensive comparisons and analysis, and visualizations.

Method Section

Here is a summary of the method section from the paper:

  • LTMP adds learned threshold token merging (LTM) and learned threshold token pruning (LTP) components between the multi-head self-attention (MSA) and MLP in each transformer block.

  • LTP computes an importance score for each token. A learned threshold then determines which tokens to keep: tokens whose importance exceeds the threshold are retained, and the rest are pruned via masking.

  • Masking is made differentiable using a straight-through estimator, and the attention mechanism is modified to ignore masked (pruned) tokens (a sketch of this masking follows this section's summary).

  • LTM is similar but uses a token similarity score and threshold to determine which tokens to merge by averaging.

  • The overall training objective combines cross-entropy loss with a regularization loss based on the difference between target and actual FLOPs reduction.

  • The FLOPs reduction factor is computed from the layer-wise token masks and transformer component FLOPs formulas.

  • Only the thresholds are updated during training while other parameters are fixed. This allows convergence in 1 epoch.

  • The importance score for pruning is the mean column attention score, i.e. the average attention a token receives from the other tokens. Merging is applied before pruning, as merging is the more accurate of the two operations.

  • Experiments use image classification on ImageNet with ViT and DeiT models. Hyperparameters are tuned on a small validation set.

In summary, the key aspects of the method are the learned threshold modules for merging and pruning, the modified attention mechanism, the budget-aware loss function, and training for just 1 epoch to fine-tune the thresholds.
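
To make the learned threshold masking concrete, the following is a minimal PyTorch-style sketch of a threshold masking module and of the mean column attention score, based on the description above. The module and helper names, and the sigmoid surrogate used for the straight-through estimator, are illustrative assumptions rather than the authors' exact implementation.

import torch
import torch.nn as nn

class LearnedThresholdMask(nn.Module):
    """Keeps tokens whose score exceeds a single learned threshold."""

    def __init__(self):
        super().__init__()
        self.threshold = nn.Parameter(torch.zeros(1))  # one learnable parameter

    def forward(self, scores):
        # scores: (batch, num_tokens), e.g. importance or similarity scores
        soft = torch.sigmoid(scores - self.threshold)  # differentiable surrogate
        hard = (scores > self.threshold).float()       # hard 0/1 mask
        # Straight-through estimator: the forward pass uses the hard mask,
        # gradients flow through the soft surrogate to the threshold
        return hard + soft - soft.detach()

def mean_column_attention(attn):
    # attn: (batch, heads, queries, keys) attention matrix from the MSA;
    # averaging over heads and queries gives the attention each token receives
    return attn.mean(dim=(1, 2))

The same masking module can be applied to importance scores for pruning and to similarity scores for merging, so each LTMP block adds only two learnable thresholds.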

High-Level Pseudo Code

Here is the high-level pseudo code for the LTMP method:

# Input: pretrained vision transformer model
# Hyperparameters: FLOPs reduction target r_target, loss weight lambda_reg
 
# Add LTM and LTP modules to each transformer block (merging is applied before pruning)
for block in model.blocks:
  block.add_module('LTM', LTM(r_target))
  block.add_module('LTP', LTP(r_target))
 
# Optimizer over the LTP and LTM thresholds only; all other weights stay frozen
optimizer = Optimizer(get_ltp_ltm_thresholds(model))

# Train loop
for image, label in dataloader:
 
  # Forward pass  
  output = model(image)
  
  # Compute loss  
  ce_loss = cross_entropy(output, label)
  flops_reduction = compute_flops_reduction(model) 
  reg_loss = (r_target - flops_reduction)**2
  loss = ce_loss + lambda_reg * reg_loss
 
  # Backward pass; only the LTP and LTM thresholds receive gradients
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
# After 1 epoch, remove unnecessary params and convert to inference model

The key steps are:

  1. Add learned pruning (LTP) and merging (LTM) modules to transformer blocks
  2. Forward pass to compute cross-entropy and regularization loss
  3. Update only thresholds in LTP and LTM modules
  4. After 1 epoch, export model for inference

This shows how LTMP modifies the model, the budget-aware loss, and efficient fine-tuning of just the thresholds.
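
As a concrete illustration of step 3, the snippet below sketches one way the get_ltp_ltm_thresholds helper used above could be implemented in PyTorch, assuming model is the pretrained vision transformer with the LTM and LTP modules already added. The attribute name threshold, the choice of AdamW, and the learning rate are illustrative assumptions; the paper only specifies that all weights except the thresholds remain frozen.

import torch

def get_ltp_ltm_thresholds(model):
    # Freeze every weight of the pretrained transformer ...
    for p in model.parameters():
        p.requires_grad = False

    # ... then re-enable gradients only for the learned thresholds (2 per block)
    thresholds = []
    for name, p in model.named_parameters():
        if name.endswith("threshold"):
            p.requires_grad = True
            thresholds.append(p)
    return thresholds

# Optimizer over the thresholds only (assumed optimizer and learning rate)
optimizer = torch.optim.AdamW(get_ltp_ltm_thresholds(model), lr=1e-3)

Because only a couple of scalars per block receive gradients, a single pass over the training data is enough for the thresholds to converge, which is why fine-tuning takes just one epoch.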

Detailed Pseudo Code

Here is more detailed pseudo code to implement the LTMP method:

import torch
import torch.nn as nn
from vit import ViT # pretrained vision transformer 
 
# LTP module
class LTP(nn.Module):
  
  def __init__(self, reduction_target):
    super().__init__()
    self.threshold = nn.Parameter(torch.zeros(1)) # learnable threshold 
    self.reduction_target = reduction_target
  
  def forward(self, x, mask):
    # Compute token importance scores,
    # e.g. the mean attention each token receives in the preceding MSA
    importance_scores = self.get_importance_scores(x)

    # Keep tokens whose importance exceeds the learned threshold
    # (made differentiable with a straight-through estimator)
    threshold_mask = get_threshold_mask(importance_scores, self.threshold)

    # Update the running mask so the attention mechanism ignores pruned tokens
    mask = update_mask(mask, threshold_mask)

    return x, mask
 
# LTM module  
class LTM(nn.Module):
 
  def __init__(self, reduction_target):
    ... # similar to LTP
 
  def forward(self, x, mask):
    # Compute a similarity score for each token
    similarity_scores = self.get_similarity_scores(x)

    # Select tokens whose similarity exceeds the learned threshold
    threshold_mask = get_threshold_mask(similarity_scores, self.threshold)

    # Merge the selected tokens by averaging them
    x = merge_tokens(x, threshold_mask)

    return x, mask
 
# Compute the achieved FLOPs reduction factor r_FLOPs from the layer-wise
# token masks and the per-component FLOPs formulas of the paper
def compute_flops_reduction(model):
  ...  # a possible implementation is sketched at the end of this section
 
# Training loop; model is the pretrained ViT with LTM and LTP modules added,
# and only their thresholds receive gradients
optimizer = torch.optim.AdamW(get_ltp_ltm_thresholds(model))

for img, label in dataloader:

  output = model(img)

  ce_loss = nn.functional.cross_entropy(output, label)

  r_FLOPs = compute_flops_reduction(model)

  reg_loss = (r_target - r_FLOPs)**2

  loss = ce_loss + lambda_reg * reg_loss

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

This shows the key components needed to implement LTMP: the threshold modules, the FLOPs-based regularization, and the training loop that fine-tunes only the thresholds.
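
To complete the picture, the sketch below shows one possible implementation of compute_flops_reduction under common approximations for transformer FLOPs (an MSA over n tokens of dimension d costs roughly 4*n*d^2 + 2*n^2*d, and an MLP with a 4d hidden size costs roughly 8*n*d^2). The attributes model.num_tokens, model.blocks and block.kept_token_count(), as well as the DeiT-S embedding dimension used as a default, are illustrative assumptions that simplify the paper's exact formulation.

def compute_flops_reduction(model, embed_dim=384):
    # Approximate per-block FLOPs for n tokens of dimension d
    def msa_flops(n, d):
        return 4 * n * d ** 2 + 2 * n ** 2 * d

    def mlp_flops(n, d):
        return 8 * n * d ** 2  # assumes the usual 4x hidden expansion

    d = embed_dim
    n_full = float(model.num_tokens)   # original token count (patches + class token)
    n_msa = n_full                     # tokens entering the MSA of the first block
    flops_full, flops_reduced = 0.0, 0.0

    for block in model.blocks:
        # Differentiable token count after this block's LTM + LTP step,
        # assumed to be exposed as the sum of the block's token masks
        n_mlp = block.kept_token_count()

        flops_full += msa_flops(n_full, d) + mlp_flops(n_full, d)
        flops_reduced += msa_flops(n_msa, d) + mlp_flops(n_mlp, d)

        n_msa = n_mlp  # the next block's MSA sees the reduced token set

    # Fraction of the original FLOPs used by the reduced model;
    # r_target should be expressed on the same scale
    return flops_reduced / flops_full

Because the kept-token counts are derived from the differentiable masks, the regularization term (r_target - r_FLOPs)**2 back-propagates into the thresholds, which is what lets the budget-aware loss distribute the reduction across layers.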