Diffusion Model as Representation Learner

Authors: Xingyi Yang, Xinchao Wang

Abstract: Diffusion Probabilistic Models (DPMs) have recently demonstrated impressive results on various generative tasks. Despite their promise, the learned representations of pre-trained DPMs have not been fully understood. In this paper, we conduct an in-depth investigation of the representation power of DPMs, and propose a novel knowledge transfer method that leverages the knowledge acquired by generative DPMs for recognition tasks. Our study begins by examining the feature space of DPMs, revealing that DPMs are inherently denoising autoencoders that balance representation learning against regularizing model capacity. Building on this insight, we introduce a novel knowledge transfer paradigm named RepFusion. Our paradigm extracts representations at different time steps from off-the-shelf DPMs and dynamically employs them as supervision for student networks, in which the optimal time is determined through reinforcement learning. We evaluate our approach on several image classification, semantic segmentation, and landmark detection benchmarks, and demonstrate that it outperforms state-of-the-art methods. Our results uncover the potential of DPMs as a powerful tool for representation learning and provide insights into the usefulness of generative models beyond sample generation. The code is available at \url{Repfusion}.

What, Why and How

Here is a summary of the key points from this paper:

What:

  • This paper investigates the representation learning capabilities of Diffusion Probabilistic Models (DPMs) for recognition tasks. DPMs have shown impressive results for generative modeling but their utility for representation learning is not well understood.

  • The authors establish a connection between DPMs and denoising autoencoders (DAEs), showing that DPMs balance representation learning against regularizing model capacity (a sketch of this connection follows the list).

  • They propose a knowledge transfer method called RepFusion that distills features from pre-trained DPMs to improve student networks on recognition tasks like classification and segmentation.
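
A sketch of this DPM-DAE connection in standard DDPM notation (the notation below is an addition to this summary: ᾱ_t is the cumulative noise schedule, ε_θ is the noise-prediction network, and x_t is the noised input):

x_t = √(ᾱ_t) · x_0 + √(1 − ᾱ_t) · ε,   ε ∼ N(0, I)
L_DPM(θ) = E_{x_0, t, ε} [ ||ε − ε_θ(x_t, t)||^2 ]

Since x̂_0 = (x_t − √(1 − ᾱ_t) · ε_θ(x_t, t)) / √(ᾱ_t), predicting the noise ε is equivalent, up to a time-dependent reweighting, to reconstructing the clean input x_0 from its corrupted version x_t; each timestep t therefore defines a denoising-autoencoder objective whose corruption strength grows with t.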

Why:

  • DPMs have a latent space that evolves over diffusion timesteps. Theoretical analysis shows that features at intermediate timesteps are the most informative, before regularization at larger timesteps takes over (a probing sketch follows this list).

  • But finding the optimal timestep for distillation is challenging. The authors use reinforcement learning to dynamically select timesteps based on feature informativeness.

  • This allows tapping into DPMs’ representation power to boost performance on discriminative downstream tasks.
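
A minimal sketch of how this informativeness-versus-timestep behaviour could be probed, assuming the same teacher.encode(x, t) feature interface used in the pseudo code later in this summary; feat_dim, num_classes, dataloader, val_loader, and evaluate are placeholder names for this sketch. RepFusion replaces this kind of grid search with a learned policy.

# Probe DPM features at a few noise levels with per-timestep linear classifiers.
import torch
import torch.nn as nn
import torch.nn.functional as F

probe_timesteps = [10, 50, 100, 250, 500, 750]   # hypothetical grid of noise levels
probe_accuracy = {}
for t in probe_timesteps:
    probe = nn.Linear(feat_dim, num_classes)      # one linear probe per timestep
    opt = torch.optim.SGD(probe.parameters(), lr=1e-2)
    for x, y in dataloader:
        with torch.no_grad():
            z_t = teacher.encode(x, t)            # frozen DPM features at noise level t
        loss = F.cross_entropy(probe(z_t.flatten(1)), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    probe_accuracy[t] = evaluate(probe, val_loader)  # placeholder evaluation helper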

How:

  • RepFusion extracts features from a pre-trained DPM and uses a distillation loss to transfer this knowledge to student networks.

  • A policy network optimized with REINFORCE learns to pick optimal timesteps based on an informativeness reward.

  • Extensive experiments on image classification, segmentation and landmark detection show RepFusion outperforms baselines by effectively utilizing DPM representations.

In summary, this work provides theoretical and empirical evidence for the representation learning potential of DPMs. The proposed RepFusion method offers a way to leverage pre-trained DPMs to improve recognition tasks through dynamic knowledge distillation.

Main Contributions

Based on my reading, the main contributions of this paper are:

  1. Investigating the representation learning capabilities of diffusion models, which has been relatively unexplored. The authors provide theoretical analysis and experiments to demonstrate DPMs can learn meaningful features for recognition.

  2. Establishing the relationship between DPMs and denoising autoencoders. This insight connects DPMs to a well-studied class of models for representation learning.

  3. Proposing a novel knowledge transfer method called RepFusion that effectively leverages representations from pre-trained DPMs to improve discriminative tasks like classification and segmentation.

  4. Introducing a reinforcement learning based approach to dynamically determine the optimal timestep for distilling knowledge from DPMs. This aligns with the theoretical property of DPM features evolving over time.

  5. Comprehensive experiments validating RepFusion on image classification, segmentation and landmark detection benchmarks. The consistent improvements demonstrate the power of DPM representations.

  6. Providing both theoretical and empirical evidence that generative models like DPMs can go beyond just sample synthesis and serve as useful tools for representation learning as well.

In summary, the key contributions are establishing the representation learning capabilities of DPMs, proposing the RepFusion distillation paradigm to transfer this knowledge, and extensive experiments demonstrating effectiveness across multiple vision tasks. The paper opens up an exciting direction for exploiting generative models to improve discriminative performance.

Method Section

Here is a summary of the method section from the paper:

The paper proposes RepFusion, a novel knowledge transfer approach to leverage pre-trained Diffusion Probabilistic Models (DPMs) for improving recognition tasks.

The key steps are:

  1. Distill knowledge from the DPM teacher:
  • Extract features from the DPM at timestep t: z^(t) = s(x, t; θ∗)
  • Train the student f to mimic z^(t) by minimizing a distillation loss L_kd
  • Use the distilled f to initialize the task network
  2. Reinforced time selection:
  • Formulate timestep selection as a reinforcement learning problem
  • Define the reward from feature informativeness
  • A policy network π parameterizes the timestep distribution, t ∼ Categorical(π(t|x))
  • A decoder g predicts labels from z^(t) to estimate the reward
  3. Optimization:
  • Update the policy π to maximize the expected reward (the policy-gradient objective is sketched after this list)
  • Update the student f with the distillation loss
  • Update the decoder g with the task loss
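
One way to write the reinforced time selection objective, hedged as a sketch under the standard REINFORCE estimator with entropy regularization (r(t, x) = −L_task(y, g(z^(t))) is the reward from the decoder, H is the entropy of the policy, λ_H weights the entropy bonus, and the baseline b is an optional variance-reduction term assumed here, not taken from the paper):

J(π) = E_{t ∼ π(·|x)} [ r(t, x) ] + λ_H · H(π(·|x))
∇_π J ≈ (r(t, x) − b) · ∇_π log π(t|x) + λ_H · ∇_π H(π(·|x)),   t ∼ π(·|x)

The student f and the decoder g are trained with ordinary gradients on L_kd and L_task; only the timestep choice needs the score-function estimator, because sampling t ∼ Categorical(π(t|x)) is not differentiable.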

So in summary, RepFusion distills representations from a pre-trained DPM into student networks, using reinforcement learning to pick the most informative timestep for each input. This taps into the representation power of DPMs to boost performance on downstream tasks, and offers a simple yet effective way to transfer knowledge from a generative model to recognition tasks.

High-Level Pseudo Code

Here is the high-level pseudo code for the RepFusion method proposed in the paper:

# Distill knowledge from DPM teacher
for x, y in dataloader:
  
  # Sample timestep  
  t ~ Categorical(π(t|x)) 
  
  # Get features
  z_t = s(x, t; θ∗)  
 
  # Distill student f
  loss = L_kd(z_t, f(x))
  update_f(loss) 
 
# Optimize timestep selection  
for x, y in dataloader:
 
  # Sample timestep
  t ~ Categorical(π(t|x))  
  
  # Get features
  z_t = s(x, t; θ∗)
  
  # Calculate reward 
  reward = -L_task(y, g(z_t))
  
  # Update policy  
  update_π(reward)
  
  # Update decoder
  loss = L_task(y, g(z_t))
  update_g(loss)
 
# Fine-tune student
for x, y in dataloader:
 
  # Get features
  z = f(x)  
  
  # Calculate task loss
  loss = L_task(y, z)
 
  # Update student 
  update_f(loss)

The key steps are:

  1. Distill knowledge from DPM teacher by mimicking features at sampled timesteps
  2. Optimize timestep selection policy to maximize reward based on feature informativeness
  3. Fine-tune distilled student model on downstream task

This allows transferring useful representations from the DPM teacher to improve the student model. The distillation loss L_kd is left abstract here; one possible instantiation is sketched below.
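
A possible instantiation of L_kd, hedged as a sketch rather than the paper's exact loss: a squared-L2 feature-matching loss with a learnable 1x1-convolution projection to align student and teacher feature shapes (the projection and the interpolation to the teacher's spatial size are assumptions of this sketch).

import torch.nn as nn
import torch.nn.functional as F

class FeatureKD(nn.Module):
    def __init__(self, student_dim, teacher_dim):
        super().__init__()
        # Project student channels to the teacher's channel dimension.
        self.proj = nn.Conv2d(student_dim, teacher_dim, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # Align channels and spatial size, then compare with a squared L2 loss.
        aligned = F.interpolate(self.proj(student_feat), size=teacher_feat.shape[-2:])
        return ((aligned - teacher_feat) ** 2).mean()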

Detailed Pseudo Code

Here is a more detailed pseudo code implementation of the RepFusion method:

# Models
teacher = DPM() # Pre-trained diffusion model
student = CNN() # Student model to be distilled
policy = MLP() # Policy network for timestep selection
decoder = MLP() # Decoder for reward estimation
 
# Optimizers
opt_student = SGD(student.parameters()) 
opt_policy = SGD(policy.parameters())
opt_decoder = SGD(decoder.parameters())
 
# Hyperparameters
T = 1000 # Total timesteps
lambda_H = 0.1 # Entropy regularization 
 
# Distillation stage
for x, y in dataloader:
 
  # Sample timestep
  t = policy(x) # t ~ Categorical(π(t|x))
  
  # Get features
  z_t = teacher.encode(x, t)
  
  # Distillation loss
  loss = ((z_t - student(x)) ** 2).mean()  # squared L2 distillation loss, ||z_t − f(x)||_2^2
 
  # Update student
  loss.backward()
  opt_student.step()
  opt_student.zero_grad()
 
# Reinforce training  
for x, y in dataloader:
 
  # Sample timestep 
  t = policy(x)
 
  # Get features
  z_t = teacher.encode(x, t)
 
  # Reward
  pred = decoder(z_t)
  reward = -cross_entropy_loss(pred, y)
 
  # REINFORCE policy loss with entropy regularization; entropy(x) denotes H(π(·|x))
  # and reward is treated as a constant with respect to the policy parameters
  loss = -(reward * log_prob(t) + lambda_H * entropy(x)).mean()
 
  # Update policy
  loss.backward()
  opt_policy.step()
  opt_policy.zero_grad()
 
  # Update decoder on the task loss
  pred = decoder(z_t)
  loss = cross_entropy_loss(pred, y)
  loss.backward()
  opt_decoder.step()
  opt_decoder.zero_grad()
  
# Fine-tuning
for x, y in dataloader:
  
  # Get student predictions (a task head on top of the distilled backbone is assumed)
  z = student(x)
 
  # Task loss
  loss = cross_entropy_loss(z, y)
 
  # Update student
  loss.backward()
  opt_student.step()
  opt_student.zero_grad()

This implements the key components of RepFusion:

  1. Knowledge distillation from DPM to student
  2. Reinforce training for optimal timestep selection
  3. Fine-tuning student model on downstream task

The pseudo code covers the model definitions, optimization, and overall training flow.
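
A minimal sketch of how the timestep sampling and log-probabilities in the pseudo code above could be implemented, assuming the policy network outputs one logit per timestep; this is an illustration, not the paper's exact implementation. torch.distributions.Categorical provides both sampling and log-probabilities, which is what the REINFORCE update needs.

import torch
from torch.distributions import Categorical

logits = policy(x)                    # assumed shape: (batch, T), one score per timestep
dist = Categorical(logits=logits)
t = dist.sample()                     # one timestep index per example
log_prob_t = dist.log_prob(t)         # enters the policy-gradient term
policy_entropy = dist.entropy()       # enters the lambda_H entropy bonus

# reward is assumed per-example here, e.g. -F.cross_entropy(pred, y, reduction='none')
policy_loss = -(reward.detach() * log_prob_t + lambda_H * policy_entropy).mean()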