ELBOing Stein: Variational Bayes with Stein Mixture Inference

Ola Rønning, Eric Nalisnick, Christophe Ley, Padhraic Smyth, Thomas Hamelryck
2024-10-30
Abstract: Stein variational gradient descent (SVGD) [Liu and Wang, 2016] performs approximate Bayesian inference by representing the posterior with a set of particles. However, SVGD suffers from variance collapse, i.e. poor predictions due to underestimating uncertainty [Ba et al., 2021], even for moderately-dimensional models such as small Bayesian neural networks (BNNs). To address this issue, we generalize SVGD by letting each particle parameterize a component distribution in a mixture model. Our method, Stein Mixture Inference (SMI), optimizes a lower bound to the evidence (ELBO) and introduces user-specified guides parameterized by particles. SMI extends the Nonlinear SVGD (NSVGD) framework [Wang and Liu, 2019] to the case of variational Bayes. SMI effectively avoids variance collapse, judging by a previously described test developed for this purpose, and performs well on standard data sets. In addition, SMI requires considerably fewer particles than SVGD to accurately estimate uncertainty for small BNNs. The synergistic combination of NSVGD, ELBO optimization and user-specified guides establishes a promising approach towards variational Bayesian inference in the case of tall and wide data.
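As a point of reference for the particle representation described above, here is a minimal NumPy sketch of the standard SVGD update of Liu and Wang (2016), \( x_i \leftarrow x_i + \epsilon\, \frac{1}{n} \sum_j \left[ k(x_j, x_i) \nabla \log p(x_j) + \nabla_{x_j} k(x_j, x_i) \right] \). The RBF kernel, the median-heuristic bandwidth, the 1-D standard normal target, the step size and all function names are illustrative assumptions, not details taken from this paper.

```python
import numpy as np


def rbf_kernel(x):
    """RBF kernel matrix and summed repulsive gradients for particles x of shape (n, d)."""
    n = x.shape[0]
    diffs = x[:, None, :] - x[None, :, :]      # diffs[j, i] = x_j - x_i
    sq_dists = np.sum(diffs**2, axis=-1)        # (n, n) squared distances
    # Median-heuristic bandwidth in the spirit of Liu and Wang (2016)
    h = np.median(sq_dists) / np.log(n + 1)
    k = np.exp(-sq_dists / h)                   # k[j, i] = k(x_j, x_i)
    # grad_{x_j} k(x_j, x_i) = -2 (x_j - x_i) / h * k(x_j, x_i), summed over j
    grad_k = np.sum(-2.0 / h * diffs * k[:, :, None], axis=0)  # (n, d)
    return k, grad_k


def svgd_step(x, grad_log_p, step_size):
    """One SVGD update: x_i <- x_i + step_size * phi(x_i)."""
    n = x.shape[0]
    k, grad_k = rbf_kernel(x)
    # phi(x_i) = 1/n * sum_j [k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i)]
    phi = (k.T @ grad_log_p(x) + grad_k) / n
    return x + step_size * phi


def grad_log_std_normal(x):
    # Score of the illustrative target N(0, I): grad log p(x) = -x
    return -x


rng = np.random.default_rng(0)
particles = rng.normal(loc=5.0, scale=0.5, size=(50, 1))  # start far from the target
for _ in range(4000):
    particles = svgd_step(particles, grad_log_std_normal, step_size=0.05)
print(particles.mean(), particles.std())
```

The first term in the update pulls each particle toward regions of high posterior density; the second, repulsive term pushes particles apart, which is what lets a finite particle set represent the spread of the posterior rather than collapsing onto its mode.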
Machine Learning
What problem does this paper attempt to address?
The paper addresses the variance collapse problem of Stein variational gradient descent (SVGD) in Bayesian inference. SVGD represents the posterior distribution with a set of particles, but for moderately dimensional models, such as small Bayesian neural networks (BNNs), it tends to underestimate uncertainty, which leads to poor predictions. To overcome this problem, the authors introduce Stein Mixture Inference (SMI), which generalizes SVGD by letting each particle parameterize a component distribution in a mixture model.

### Main Contributions

1. **Introduction of SMI**: The Nonlinear SVGD (NSVGD) framework is extended to variational Bayesian inference, with each particle parameterizing one component distribution of a mixture model.
2. **Experimental verification**: Experiments show that SMI estimates uncertainty more efficiently than SVGD, requiring considerably fewer particles to do so for small BNNs.
3. **Avoidance of variance collapse**: SMI avoids variance collapse on small- to medium-sized models such as small BNNs, judging by a test previously developed for this purpose.

### Method Overview

Rather than treating the particles directly as posterior samples, SMI approximates the posterior with a mixture model. It maximizes a lower bound on the log evidence (the ELBO). Because the gap between \( \log p(D) \) and the ELBO equals \( \mathrm{KL}(q(\theta|\rho_m) \,\|\, p(\theta|D)) \), maximizing the bound brings the mixture closer to the true posterior. The bound is:

\[ L(\rho_m) = \frac{1}{m} \sum_{\ell = 1}^m \mathbb{E}_{q(\theta|\psi_\ell)} \left[ \log \frac{p(\theta, D)}{q(\theta|\rho_m)} \right] \leq \log p(D) \]

Here, \( q(\theta|\rho_m) = \frac{1}{m} \sum_{\ell = 1}^m q(\theta|\psi_\ell) \) is a uniform mixture of \( m \) user-specified guide distributions, parameterized by the particles \( \{\psi_\ell\}_{\ell = 1}^m \). In addition, SMI adds an entropy term \( H[\rho_m] \), weighted by \( \alpha \), to encourage particle diversity and prevent mode collapse.
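The bound \( L(\rho_m) \) can be estimated by Monte Carlo: draw samples from each guide \( q(\theta|\psi_\ell) \), score them against the full mixture \( q(\theta|\rho_m) \), and average over the \( m \) components. The following is a minimal NumPy sketch, assuming 1-D Gaussian guides with \( \psi_\ell = (\mu_\ell, \sigma_\ell) \) and a hypothetical toy model whose posterior is \( \mathcal{N}(0, 1) \) and whose log evidence is set to \( -2 \); it is not the authors' implementation, and all names are illustrative.

```python
import numpy as np


def gaussian_logpdf(x, mu, sigma):
    """Elementwise log N(x; mu, sigma^2)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def log_mixture_density(theta, mus, sigmas):
    """log q(theta | rho_m) for a uniform mixture of m Gaussian guides."""
    # Component log-densities, shape (num_samples, m)
    comp = gaussian_logpdf(theta[:, None], mus[None, :], sigmas[None, :])
    # Numerically stable log of the average over components: logsumexp - log m
    return np.logaddexp.reduce(comp, axis=1) - np.log(len(mus))


def mixture_elbo(log_joint, mus, sigmas, num_samples, rng):
    """Monte Carlo estimate of L(rho_m) with psi_l = (mus[l], sigmas[l])."""
    total = 0.0
    for mu, sigma in zip(mus, sigmas):
        # Draw from the l-th guide, but score against the whole mixture q(theta | rho_m)
        theta = rng.normal(mu, sigma, size=num_samples)
        total += np.mean(log_joint(theta) - log_mixture_density(theta, mus, sigmas))
    return total / len(mus)


LOG_EVIDENCE = -2.0  # hypothetical log p(D) for the toy model


def log_joint(theta):
    # Toy model whose posterior is N(0, 1): p(theta, D) = p(D) * N(theta; 0, 1)
    return LOG_EVIDENCE + gaussian_logpdf(theta, 0.0, 1.0)


rng = np.random.default_rng(0)
matched = mixture_elbo(log_joint, np.zeros(3), np.ones(3), num_samples=1000, rng=rng)
spread = mixture_elbo(
    log_joint, np.array([-3.0, 0.0, 3.0]), np.full(3, 0.5), num_samples=1000, rng=rng
)
print(matched, spread)  # matched recovers LOG_EVIDENCE; spread is lower by roughly the KL gap
```

When every component equals the normalized posterior, each sample's log ratio is exactly \( \log p(D) \), so the estimate hits the bound; any mismatch lowers it by the KL divergence. The entropy term and the particle updates that optimize the full objective are deliberately left out of this sketch.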
The final optimization objective is:

\[ \rho_m^* = \arg \max_{\rho_m} \left[ L(\rho_m) + \alpha H[\rho_m] \right] \]

### Experimental Results

The experiments show that SMI improves on SVGD in the following respects:

- **Higher particle efficiency**: SMI requires fewer particles to accurately estimate uncertainty.
- **Avoidance of variance collapse**: On small- to medium-sized models, SMI does not exhibit variance collapse.
- **Better performance**: SMI performs better than SVGD and ASVGD on standard datasets.

In conclusion, the paper proposes Stein Mixture Inference (SMI), a variational Bayesian inference method that avoids the variance collapse of SVGD, and supports this with experiments on standard datasets and small BNNs.