DPM-OT: A New Diffusion Probabilistic Model Based on Optimal Transport

Authors: Zezeng Li, ShengHao Li, Zhanpeng Wang, Na Lei, Zhongxuan Luo, Xianfeng Gu

Abstract: Sampling from diffusion probabilistic models (DPMs) can be viewed as a piecewise distribution transformation, which generally requires hundreds or thousands of steps of the inverse diffusion trajectory to obtain a high-quality image. Recent progress in designing fast samplers for DPMs achieves a trade-off between sampling speed and sample quality through knowledge distillation or by adjusting the variance schedule or the denoising equation. However, these methods cannot be optimal in both aspects and often suffer from mode mixture when the number of steps is small. To tackle this problem, we innovatively regard inverse diffusion as an optimal transport (OT) problem between latents at different stages and propose DPM-OT, a unified learning framework for fast DPMs with a direct expressway represented by the OT map, which can generate high-quality samples within around 10 function evaluations. By computing the semi-discrete optimal transport map between the data latents and the white noise, we obtain an expressway from the prior distribution to the data distribution, while significantly alleviating the problem of mode mixture. In addition, we give the error bound of the proposed method, which theoretically guarantees the stability of the algorithm. Extensive experiments validate the effectiveness and advantages of DPM-OT in terms of speed and quality (FID and mode mixture), thus representing an efficient solution for generative modeling. Source code is available at DPM-OT.

What, Why and How

Here is a summary of the key points from the paper:

What:

  • The paper proposes a new fast diffusion probabilistic model framework called DPM-OT, which incorporates optimal transport into diffusion models for efficient high-quality image generation.

Why:

  • Existing diffusion models require sampling over hundreds or thousands of steps, which is very slow. Prior fast diffusion models sacrifice sample quality or suffer from mode mixture.
  • This paper proposes using optimal transport to map directly from the prior noise distribution to the target latents in a single step, so that only a few refinement steps are needed afterwards; this avoids mode mixture and enables fast sampling while maintaining quality.

How:

  • They compute the semi-discrete optimal transport (SDOT) map between the noise latents and the target latents (obtained by forward-diffusing the data; see the sketch after this list) as the gradient of the Brenier potential. This provides the optimal map in one step.
  • The SDOT map is discontinuous at singularities, preserving separation between modes/classes to avoid mixture.
  • After mapping with SDOT, only around 10 steps of diffusion are needed to refine the samples.
  • They prove a bound on the sampling error compared to standard diffusion models.
  • Experiments show DPM-OT produces higher-quality samples with less mode mixture using only about 5-10 function evaluations, versus hundreds to thousands for standard models.
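
For concreteness, the target latents can be drawn in closed form from the standard DDPM forward marginal q(x_M | x_0) = N(sqrt(a_bar_M) x_0, (1 - a_bar_M) I). A minimal NumPy sketch, where alphas_cumprod (the cumulative products a_bar_t) is an assumed precomputed schedule, not a name from the paper:

import numpy as np

def forward_diffuse(x0, M, alphas_cumprod, rng=None):
    # Sample x_M ~ q(x_M | x_0) via the closed-form DDPM forward marginal.
    rng = rng or np.random.default_rng(0)
    a_bar = alphas_cumprod[M - 1]            # a_bar_M at the chosen step
    eps = rng.standard_normal(x0.shape)      # standard Gaussian noise
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps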

In summary, DPM-OT incorporates optimal transport into diffusion models to enable fast high-quality sampling while avoiding mode mixture, with theoretical guarantees on sampling error. This represents an efficient solution for generative modeling.

Main Contributions

Here are the main contributions of this paper:

  1. They propose DPM-OT, a new fast sampling framework for diffusion probabilistic models that incorporates optimal transport. This allows high-quality sample generation in around 10 steps.

  2. They compute the semi-discrete optimal transport map using the Brenier potential to map directly from the prior noise distribution to target data latents. This provides an “expressway” for fast sampling.

  3. The discontinuities in the optimal transport map preserve separation between modes/classes, avoiding mode mixture problems in generated samples.

  4. They theoretically analyze the single-step error and prove an upper bound on the error between the generated distribution and target data distribution.

  5. Extensive experiments validate that DPM-OT outperforms state-of-the-art methods in sample quality and mode mixture while needing only 5-10 function evaluations.

In summary, the main contributions are:

  • A new DPM framework DPM-OT incorporating optimal transport for fast high-quality sampling.

  • Computing the optimal transport map between latents for an express sampling trajectory.

  • Theoretical analysis bounding sampling error.

  • Empirical demonstrations of faster sampling with improved quality and mode separation.

Method Section

Here is a summary of the method section:

The key idea is to incorporate optimal transport into diffusion models to create an “expressway” for fast high-quality sampling.

  1. Optimal Trajectory and Sampler
  • Define the optimal trajectory from latent x_T to x_M as the transport map g* that minimizes the transport cost between the two distributions (the Monge formulation is written out after this item).

  • This g* provides a direct map from prior noise to target latents in one step.

  • The DPM-OT sampler first maps x_T to x_M using g*, then performs M steps of diffusion starting from x_M to refine the sample.
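
In OT terms this is a Monge problem between the law mu_T of x_T and the law mu_M of x_M; with the usual quadratic cost it reads (standard OT notation, not copied verbatim from the paper):

$$ g^{*} \;=\; \arg\min_{g:\, g_{\#}\mu_T = \mu_M} \int \tfrac{1}{2}\,\lVert x - g(x)\rVert^{2}\, d\mu_T(x) $$

where g_# mu_T denotes the pushforward of mu_T under g.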

  2. Computing g* via Semi-Discrete Optimal Transport
  • g* is computed as the gradient of the Brenier potential, which is the solution to a semi-discrete optimal transport problem.

  • The Brenier potential leads to a transport map g* that is discontinuous at singularities.

  • This preserves separation between modes/classes, avoiding mode mixture.

  • They approximate g* by optimizing the Brenier potential with Monte Carlo sampling (a sketch follows this item).
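
A minimal NumPy sketch of this optimization under illustrative assumptions (Gaussian source measure, uniform target weights, fixed step size; the paper's actual algorithm may differ in these details):

import numpy as np

def sdot_heights(ys, nu, n_iters=2000, batch=4096, lr=0.1, seed=0):
    # Optimize the heights h of the piecewise-linear Brenier potential
    #   u_h(x) = max_i ( <x, y_i> + h_i )
    # so that the cell of each target latent y_i carries its target mass nu_i.
    rng = np.random.default_rng(seed)
    N, d = ys.shape
    h = np.zeros(N)
    for _ in range(n_iters):
        x = rng.standard_normal((batch, d))        # Monte Carlo samples from the prior
        idx = np.argmax(x @ ys.T + h, axis=1)      # cell assignment of each sample
        w = np.bincount(idx, minlength=N) / batch  # empirical cell masses
        h -= lr * (w - nu)                         # gradient step on the SDOT energy
        h -= h.mean()                              # heights are defined up to a constant
    return h

def ot_map(x, ys, h):
    # SDOT map g* = grad u_h: send each x to the target point of its cell.
    return ys[np.argmax(x @ ys.T + h, axis=1)]

Because g* is piecewise constant here, it is discontinuous across cell boundaries, which is precisely the property that keeps modes separated.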

  3. Sampling Algorithm
  • Given g*, sampling is done by first mapping x_T to x_M, then performing M steps of diffusion using a pre-trained model s_theta (each step has the form shown after this item).

  • M is a small number like 10.

  • This provides a fast expressway for high-quality sampling in very few steps.
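
Concretely, each refinement step is a score-based update (this matches the detailed pseudo code later in this summary; the exact coefficients depend on the chosen variance schedule):

$$ x_{t-1} = x_t + \sigma_t^2\, s_\theta(x_t, t) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I) $$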

  4. Error Analysis
  • Prove bounds on the single-step error and error between generated and target distributions.

  • Show the error is controllable and upper bounded by the error of the standard diffusion model (a schematic decomposition follows).
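
The exact statement is in the paper; schematically, such bounds split the discrepancy between the generated and data distributions into the OT approximation error plus accumulated single-step denoising errors, i.e. something of the form

$$ d\big(p_{\mathrm{gen}}, p_{\mathrm{data}}\big) \;\le\; \varepsilon_{\mathrm{OT}} + \sum_{t=1}^{M} \varepsilon_t $$

where epsilon_OT is the error of the approximate SDOT map and epsilon_t is the single-step error at step t (our schematic notation, not the paper's).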

So in summary, the key method is to use optimal transport to create an express sampling trajectory from noise to target latents, enabling fast high-quality sampling with theoretical guarantees.

High-Level Pseudo Code

Here is the high-level pseudo code for the key algorithms in the paper:

# Compute optimal transport map g*
target_data = {y1, y2, ..., yN}   # target data points
num_steps = M

# Forward-diffuse the target data to step M
diffused_data = forward_diffuse(target_data, num_steps)

# Approximate the SDOT map between prior noise and diffused_data
g = ot_map(noise, diffused_data)

# DPM-OT sampling
x = sample_noise()            # x_T ~ N(0, I)
x = g(x)                      # one-step map from noise to the target latent x_M
for t in range(M, 0, -1):
    x = diffusion_step(x, t)  # one reverse-diffusion step: x_t -> x_{t-1}

return x                      # x_0

The key steps are:

  1. Forward diffuse the target data
  2. Approximate the OT map between noise and diffused latents
  3. Sample by mapping noise to latents via OT, then refine with diffusion

This provides an expressway from noise to data latents for fast high-quality sampling.

Detailed Pseudo Code

Here is some more detailed pseudo code to implement the key algorithms in the paper:

# Compute OT map g* (Algorithm 1)
target_data = {y1, y2, ..., yN}
num_steps = M
diffused_data = forward_diffuse(target_data, num_steps)

# Monte Carlo optimization of the Brenier potential u_h
h = zeros(N)                        # one height per target latent
for it in range(num_iterations):
    xs = sample_noise_points(N)     # Monte Carlo samples {x1, ..., xN} from the prior
    # evaluate u_h(x) = max_i (<x, y_i> + h_i) on the samples and take a
    # gradient step on the SDOT energy to update the heights h
    h = update_heights(h, xs, diffused_data)

g = gradient(u_h)                   # OT map g* = gradient of the Brenier potential

# DPM-OT sampling (Algorithm 2)
def sample(s_theta, g, M, schedule):
    x = sample_gaussian_noise()     # x_T

    # Map noise to target latents in one step
    x = g(x)                        # x_M

    # Refine with M steps of reverse diffusion
    for t in range(M, 0, -1):
        z = sample_gaussian_noise()
        x = x + schedule.sigma(t)**2 * s_theta(x, t) + schedule.sigma(t) * z

    return x                        # x_0

Where:

  • forward_diffuse(): performs forward diffusion on the data to produce the target latents
  • u_h: Brenier potential, parameterized by heights h and optimized via Monte Carlo
  • schedule: variance-schedule coefficients (e.g. sigma_t)
  • s_theta: pretrained score/diffusion model

This implements the key steps of computing the OT map g* and sampling via mapping and reverse diffusion.
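
Finally, a hypothetical end-to-end composition of the sketches above (all names come from those sketches, not from the paper's released code):

# Hypothetical end-to-end composition of the sketches above
diffused = forward_diffuse(target_data, M, alphas_cumprod)   # target latents x_M
nu = np.full(len(diffused), 1.0 / len(diffused))             # uniform target weights
h = sdot_heights(diffused, nu)
g = lambda x: ot_map(x, diffused, h)                         # the OT expressway g*
x0 = sample(s_theta, g, M, schedule)                         # Algorithm 2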