Improved sampling via learned diffusions

Lorenz Richter, Julius Berner
2024-05-23
Abstract: Recently, a series of papers proposed deep learning-based approaches to sample from target distributions using controlled diffusion processes, being trained only on the unnormalized target densities without access to samples. Building on previous work, we identify these approaches as special cases of a generalized Schrödinger bridge problem, seeking a stochastic evolution between a given prior distribution and the specified target. We further generalize this framework by introducing a variational formulation based on divergences between path space measures of time-reversed diffusion processes. This abstract perspective leads to practical losses that can be optimized by gradient-based algorithms and includes previous objectives as special cases. At the same time, it allows us to consider divergences other than the reverse Kullback-Leibler divergence that is known to suffer from mode collapse. In particular, we propose the so-called log-variance loss, which exhibits favorable numerical properties and leads to significantly improved performance across all considered approaches.
Machine Learning, Optimization and Control, Probability
What problem does this paper attempt to address?
This paper addresses how to improve sampling with learned diffusion processes when only the unnormalized target density is available and no samples from the target can be used. The authors unify existing methods in a single framework: a variational formulation of controlled diffusion processes based on divergences between path space measures of time-reversed diffusions. Within this framework they propose a new loss, the log-variance loss, which mitigates the mode collapse associated with the reverse Kullback-Leibler (KL) divergence and improves numerical stability and performance.

The main contributions of the paper are:
1. A unified framework that connects learned diffusion-based sampling with path space measures and time-reversed SDEs, linking methods such as the Schrödinger bridge problem, the time-reversed diffusion sampler, and the denoising diffusion sampler.
2. A path space perspective that permits any divergence as the optimization objective, not only the reverse KL divergence.
3. The log-variance divergence, which avoids differentiating through the SDE solver and balances exploration and exploitation, significantly improving numerical stability and performance.

The authors compare the reverse KL divergence and the log-variance divergence across different sampling methods (such as bridge-based and reference-process-based methods). Their numerical experiments on several benchmark problems show that the log-variance divergence helps prevent mode collapse and gives better numerical stability and performance than the reverse KL divergence.
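
To make the contrast between the two objectives concrete, here is a minimal sketch in Python (NumPy). It is an illustration under stated assumptions, not the authors' implementation. It assumes we already have one value per simulated trajectory of the log Radon-Nikodym derivative between the controlled path measure and the target path measure. The function names `reverse_kl_estimate` and `log_variance_estimate` are hypothetical, and the input values are synthetic stand-ins for quantities that would come from an SDE simulation. The reverse KL estimate is the sample mean of these values, while the log-variance loss is their sample variance.

```python
import numpy as np


def reverse_kl_estimate(log_rnd: np.ndarray) -> float:
    """Sample mean of the per-trajectory log Radon-Nikodym derivative.

    Assumes the trajectories were drawn from the controlled process itself.
    """
    return float(np.mean(log_rnd))


def log_variance_estimate(log_rnd: np.ndarray) -> float:
    """Sample variance of the per-trajectory log Radon-Nikodym derivative.

    The trajectories may come from any reference process; the variance is
    unchanged by additive constants such as an unknown log-normalizer.
    """
    return float(np.var(log_rnd, ddof=1))


# Synthetic stand-in values (assumption): in practice these would be computed
# from simulated trajectories of the controlled diffusion.
rng = np.random.default_rng(0)
log_rnd = rng.normal(loc=0.5, scale=0.3, size=1024)

# Hypothetical unknown log-normalizing constant of the target density.
log_z = 2.0
shifted = log_rnd + log_z

kl_plain = reverse_kl_estimate(log_rnd)
kl_shifted = reverse_kl_estimate(shifted)
lv_plain = log_variance_estimate(log_rnd)
lv_shifted = log_variance_estimate(shifted)

print("reverse KL shift:", kl_shifted - kl_plain)      # equals log_z up to rounding
print("log-variance shift:", lv_shifted - lv_plain)    # zero up to rounding
```

Because the variance ignores additive constants, the unknown normalizing constant of the target drops out of the log-variance loss. The loss is also zero exactly when the log Radon-Nikodym derivative is constant across trajectories. Since the trajectories can be drawn from a reference process that is treated as fixed, gradients do not need to pass through the SDE solver, which is the property the summary above refers to.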