DPM-OT: A New Diffusion Probabilistic Model Based on Optimal Transport
Authors: Zezeng Li, ShengHao Li, Zhanpeng Wang, Na Lei, Zhongxuan Luo, Xianfeng Gu
Abstract: Sampling from diffusion probabilistic models (DPMs) can be viewed as a piecewise distribution transformation, which generally requires hundreds or thousands of steps of the inverse diffusion trajectory to get a high-quality image. Recent progress in designing fast samplers for DPMs achieves a trade-off between sampling speed and sample quality by knowledge distillation or by adjusting the variance schedule or the denoising equation. However, these samplers cannot be optimal in both aspects and often suffer from mode mixture in short steps. To tackle this problem, we innovatively regard inverse diffusion as an optimal transport (OT) problem between latents at different stages and propose DPM-OT, a unified learning framework for fast DPMs with a direct expressway represented by the OT map, which can generate high-quality samples within around 10 function evaluations. By calculating the semi-discrete optimal transport map between the data latents and the white noise, we obtain an expressway from the prior distribution to the data distribution, while significantly alleviating the problem of mode mixture. In addition, we give the error bound of the proposed method, which theoretically guarantees the stability of the algorithm. Extensive experiments validate the effectiveness and advantages of DPM-OT in terms of speed and quality (FID and mode mixture), thus representing an efficient solution for generative modeling. Source code is available at DPM-OT.
What, Why and How
Here is a summary of the key points from the paper:
What:
- The paper proposes a new fast diffusion probabilistic model framework called DPM-OT, which incorporates optimal transport into diffusion models for efficient high-quality image generation.
Why:
- Existing diffusion models require sampling over hundreds or thousands of steps, which is very slow. Prior fast diffusion models sacrifice sample quality or suffer from mode mixture.
- This paper proposes using optimal transport to map directly from the prior noise distribution to target latents in a few steps, avoiding mode mixture and enabling fast sampling while maintaining quality.
How:
- They compute the semi-discrete optimal transport (SDOT) map between the noise latents and target latents using the Brenier potential. This provides the optimal map in one step.
- The SDOT map is discontinuous at singularities, preserving separation between modes/classes to avoid mixture.
- After mapping with SDOT, only around 10 steps of diffusion are needed to refine the samples.
- They prove a bound on the sampling error compared to standard diffusion models.
- Experiments show DPM-OT produces higher-quality samples with less mode mixture using only about 5 to 10 function evaluations, compared with the hundreds to thousands needed by standard models.
In summary, DPM-OT incorporates optimal transport into diffusion models to enable fast high-quality sampling while avoiding mode mixture, with theoretical guarantees on sampling error. This represents an efficient solution for generative modeling.
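To make the mode-separation point concrete, here is a minimal NumPy sketch of how a semi-discrete OT map is evaluated once its height vector is known. It assumes the standard quadratic-cost form, in which each noise point is sent to the target `y_i` maximising `<x, y_i> + h_i`. The function name `sdot_map`, the two target points and the zero heights are illustrative choices, not values from the paper.

```python
import numpy as np

def sdot_map(x, targets, h):
    """Send each row of x to the target y_i that maximises <x, y_i> + h_i."""
    scores = x @ targets.T + h        # shape (num_points, num_targets)
    idx = np.argmax(scores, axis=1)   # index of the winning cell for each point
    return targets[idx], idx

# Two well-separated "modes" in 2-D with zero heights (illustrative values only).
targets = np.array([[-1.0, 0.0], [1.0, 0.0]])
h = np.zeros(len(targets))

# Two noise points that are close together but lie on opposite sides of the
# cell boundary x[0] = 0.
x = np.array([[-0.01, 0.3], [0.01, 0.3]])
mapped, idx = sdot_map(x, targets, h)
print(idx)     # [0 1]
print(mapped)  # every output is exactly one of the targets
```

The map jumps at the cell boundary: nearby inputs land on different targets, and no output is a blend of two targets. That jump is the discontinuity the summary credits with avoiding mode mixture.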
Main Contributions
Here are the main contributions of this paper:
- They propose DPM-OT, a new fast sampling framework for diffusion probabilistic models that incorporates optimal transport. This allows high-quality sample generation in around 10 steps.
- They compute the semi-discrete optimal transport map using the Brenier potential to map directly from the prior noise distribution to target data latents. This provides an “expressway” for fast sampling.
- The discontinuities in the optimal transport map preserve separation between modes/classes, avoiding mode mixture problems in generated samples.
- They theoretically analyze the single-step error and prove an upper bound on the error between the generated distribution and target data distribution.
- Extensive experiments validate DPM-OT outperforms state-of-the-art methods in sample quality and mode mixture with only 5-10 function evaluations needed.
In summary, the main contributions are:
- A new DPM framework DPM-OT incorporating optimal transport for fast high-quality sampling.
- Computing the optimal transport map between latents for an express sampling trajectory.
- Theoretical analysis bounding sampling error.
- Empirical demonstrations of faster sampling with improved quality and mode separation.
Method Section
Here is a summary of the method section:
The key idea is to incorporate optimal transport into diffusion models to create an “expressway” for fast high-quality sampling.
- Optimal Trajectory and Sampler
  - Define the optimal trajectory from latent x_T to x_M as the transport map g* that minimizes the transport cost between the two distributions.
  - This g* provides a direct map from prior noise to target latents in one step.
  - The DPM-OT sampler first maps x_T to x_M using g*, then performs M steps of diffusion starting from x_M to refine the sample.
- Computing g* via Semi-Discrete Optimal Transport
  - g* is computed as the gradient of the Brenier potential, which is the solution to a semi-discrete optimal transport problem.
  - The Brenier potential leads to a transport map g* that is discontinuous at singularities.
  - This preserves separation between modes/classes, avoiding mode mixture.
  - They use Monte Carlo sampling to approximate g*.
- Sampling Algorithm
  - Given g*, sampling is done by first mapping x_T to x_M, then performing M steps of diffusion using a pre-trained model s_theta.
  - M is a small number like 10.
  - This provides a fast expressway for high-quality sampling in very few steps.
- Error Analysis
  - Prove bounds on the single-step error and error between generated and target distributions.
  - Show the error is controllable and upper bounded by the standard diffusion model error.
So in summary, the key method is to use optimal transport to create an express sampling trajectory from noise to target latents, enabling fast high-quality sampling with theoretical guarantees.
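For reference, the Brenier potential and the map g* described above can be written in the standard semi-discrete form for quadratic cost. This is the textbook formulation and is assumed, not verified, to match the paper's notation. The symbols are assumptions introduced here: y_i are the N target latents at step M, mu is the noise distribution, nu_i is the weight of target y_i, h is the height vector, and W_i(h) is the cell of noise points sent to y_i.

```latex
\begin{aligned}
u_h(x) &= \max_{1 \le i \le N} \bigl\{ \langle x, y_i \rangle + h_i \bigr\}, \\
W_i(h) &= \bigl\{ x : \langle x, y_i \rangle + h_i \ge \langle x, y_j \rangle + h_j \ \ \forall j \bigr\}, \\
g^*(x) &= \nabla u_h(x) = y_i \quad \text{for } x \text{ in the interior of } W_i(h), \\
E(h) &= \int u_h(x)\, d\mu(x) - \sum_{i=1}^{N} h_i \nu_i, \qquad
\frac{\partial E}{\partial h_i} = \mu\bigl(W_i(h)\bigr) - \nu_i .
\end{aligned}
```

The energy E is convex in h, and its gradient vanishes exactly when every cell carries its target mass, mu(W_i(h)) = nu_i. The cell masses are the only quantities that need to be estimated, which is why Monte Carlo sampling of the noise is enough to fit h.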
High-Level Pseudo Code
Here is the high-level pseudo code for the key algorithms in the paper:
    # Compute optimal transport map g*
    target_data = {y1, y2, ..., yN}  # target data points
    num_steps = M

    # Forward diffuse target data to step M
    diffused_data = forward_diffuse(target_data, num_steps)

    # Approximate OT map between noise (the Gaussian prior) and diffused_data
    g = ot_map(noise, diffused_data)

    # DPM-OT sampling
    x_T = sample_noise()
    x_M = g(x_T)  # Map noise to target latents
    for t in range(M, 0, -1):
        x_{t-1} = diffusion_step(x_t)  # Refine sample
    return x_0
The key steps are:
- Forward diffuse the target data
- Approximate the OT map between noise and diffused latents
- Sample by mapping noise to latents via OT, then refine with diffusion
This provides an expressway from noise to data latents for fast high-quality sampling.
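The sampling loop above can be run end to end on a toy problem. The sketch below is only an illustration under stated assumptions: `toy_g` and `toy_score` are hypothetical stand-ins for the fitted OT map and the pretrained score model, the noise levels in `sigmas` are arbitrary, and the update rule copies the simplified form used in this summary instead of the paper's exact discretization.

```python
import numpy as np

def dpm_ot_sample(x_T, g, score_fn, sigmas, rng):
    """Jump x_T -> x_M with g, then run len(sigmas) refinement steps."""
    x = g(x_T)                            # expressway: a single OT jump
    for sigma in sigmas:                  # ordered from step M down to step 1
        z = rng.standard_normal(x.shape)
        x = x + sigma**2 * score_fn(x, sigma) + sigma * z
    return x

# Toy setting: the "data" consists of two points (illustrative values only).
targets = np.array([[-2.0, 0.0], [2.0, 0.0]])

def toy_g(x):
    # Placeholder map: pick the target with the largest inner product (h = 0).
    return targets[np.argmax(x @ targets.T, axis=1)]

def toy_score(x, sigma):
    # Score of an isotropic Gaussian centred on the nearest target.
    d = ((x[:, None, :] - targets[None]) ** 2).sum(-1)
    centre = targets[np.argmin(d, axis=1)]
    return (centre - x) / sigma**2

rng = np.random.default_rng(0)
x_T = rng.standard_normal((8, 2))
sigmas = np.geomspace(0.5, 0.01, num=10)  # M = 10 refinement steps
x_0 = dpm_ot_sample(x_T, toy_g, toy_score, sigmas, rng)
print(x_0.round(2))
```

Each sample costs one evaluation of `g` plus M calls to the score function, which is where the "around 10 function evaluations" budget comes from.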
Detailed Pseudo Code
Here is some more detailed pseudo code to implement the key algorithms in the paper:
    # Compute OT map g* (Algorithm 1)
    target_data = {y1, y2, ..., yN}
    num_steps = M
    diffused_data = forward_diffuse(target_data, M)

    # Monte Carlo OT
    for iteration in range(num_iterations):
        sample noise points {x1, x2, ..., xN}
        compute Brenier potential u_h
        update u_h to minimize energy
    g = gradient(u_h)  # OT map g*

    # DPM-OT sampling (Algorithm 2)
    def sample(theta, M, schedule):
        x_T = sample_gaussian_noise()
        # Map noise to target latents
        x_M = g(x_T)
        # Refine with reverse diffusion
        for t in range(M, 0, -1):
            z = sample_gaussian_noise()
            x_{t-1} = x_t + schedule.sigma(t)^2 * s_theta(x_t) + schedule.sigma(t) * z
        return x_0
Where:
- forward_diffuse(): performs forward diffusion on data
- u_h: Brenier potential approximated by Monte Carlo
- schedule: diffusion schedule coefficients
- s_theta: pretrained diffusion model
This implements the key steps of computing the OT map g* and sampling via mapping and reverse diffusion.
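The Monte Carlo loop of Algorithm 1 can be sketched in NumPy as follows. This is a simplified illustration, not the authors' implementation. It assumes uniform target weights, plain gradient descent on the height vector `h`, and the standard semi-discrete energy whose gradient is "estimated cell mass minus target mass". The name `solve_sdot_heights`, the `sample_noise` callable and all hyperparameters are hypothetical.

```python
import numpy as np

def solve_sdot_heights(targets, sample_noise, num_iters=500, batch=4096,
                       lr=0.5, seed=0):
    """Fit heights h so that every cell of the map x -> argmax_i <x, y_i> + h_i
    receives roughly equal noise mass."""
    rng = np.random.default_rng(seed)
    n = len(targets)
    nu = np.full(n, 1.0 / n)              # uniform target weights (assumption)
    h = np.zeros(n)
    for _ in range(num_iters):
        x = sample_noise(rng, batch)      # fresh Monte Carlo noise batch
        idx = np.argmax(x @ targets.T + h, axis=1)
        w = np.bincount(idx, minlength=n) / batch  # estimated cell masses
        h -= lr * (w - nu)                # gradient step on the convex energy
        h -= h.mean()                     # remove the free additive constant
    return h

# Toy usage: three 2-D targets and standard Gaussian noise (illustrative only).
targets = np.array([[-1.0, 0.0], [1.0, 0.0], [0.0, 3.0]])
gauss = lambda rng, n: rng.standard_normal((n, 2))
h = solve_sdot_heights(targets, gauss)

# Check the fit on an independent noise batch.
x = np.random.default_rng(1).standard_normal((20000, 2))
masses = np.bincount(np.argmax(x @ targets.T + h, axis=1), minlength=3) / 20000
print(h.round(3), masses.round(3))
```

With `h` fitted, g* is evaluated by the same argmax used inside the loop: a noise point is sent to the target whose cell contains it. A practical solver would likely use a larger batch and an adaptive optimizer; plain gradient descent is used here only to keep the sketch short.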