Enhancing Unimodal Latent Representations in Multimodal VAEs through Iterative Amortized Inference

Yuta Oshima, Masahiro Suzuki, Yutaka Matsuo
2024-10-15
Abstract: Multimodal variational autoencoders (VAEs) aim to capture shared latent representations by integrating information from different data modalities. A significant challenge is accurately inferring representations from any subset of modalities without training an impractical number (2^M) of inference networks for all possible modality combinations. Mixture-based models simplify this by requiring only as many inference models as there are modalities, aggregating unimodal inferences. However, they suffer from information loss when modalities are missing. Alignment-based VAEs address this by aligning unimodal inference models with a multimodal model through minimizing the Kullback-Leibler (KL) divergence but face issues due to amortization gaps, which compromise inference accuracy. To tackle these problems, we introduce multimodal iterative amortized inference, an iterative refinement mechanism within the multimodal VAE framework. This method overcomes information loss from missing modalities and minimizes the amortization gap by iteratively refining the multimodal inference using all available modalities. By aligning unimodal inference to this refined multimodal posterior, we achieve unimodal inferences that effectively incorporate multimodal information while requiring only unimodal inputs during inference. Experiments on benchmark datasets show that our approach improves inference performance, evidenced by higher linear classification accuracy and competitive cosine similarity, and enhances cross-modal generation, indicated by lower FID scores. This demonstrates that our method enhances inferred representations from unimodal inputs.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
### Problems the paper attempts to solve

This paper addresses two main problems that multimodal VAEs face when some modalities are missing:

1. **Information loss**: In mixture-based models, missing modalities cause modality-specific information to be lost, which degrades inference performance.
2. **Amortization gap**: In alignment-based models, the multimodal approximate inference has an amortization gap, so the unimodal inference models are aligned to an imperfect target and cannot closely approximate the true posterior, which reduces inference accuracy.

To overcome these problems, the paper introduces an iterative amortized inference mechanism within the multimodal VAE framework, called multimodal iterative amortized inference. By iteratively optimizing the multimodal evidence lower bound (ELBO), the method recovers information from all modalities and reduces the amortization gap. In addition, by minimizing the Kullback-Leibler (KL) divergence between each unimodal inference model and the iteratively refined multimodal approximate posterior, the method yields unimodal inferences that carry multimodal information while requiring only unimodal input.

### Formula summary

1. **Evidence lower bound (ELBO) of the multimodal VAE**:

\[ L(\theta, \phi; X)=\mathbb{E}_{q_{\phi}(z|X)}[\log p_{\theta}(X|z)] - D_{\text{KL}}[q_{\phi}(z|X)\|p(z)] \]

2. **Multimodal ELBO in the mixture model**:

\[ L_{M}(\theta, \phi; X)\geq\sum_{S: X_{S}\in P(X)}\omega_{S}\left(\mathbb{E}_{q_{\text{PoE}}^{\phi}(z|X_{S})}[\log p_{\theta}(X|z)] - D_{\text{KL}}[q_{\text{PoE}}^{\phi}(z|X_{S})\|p(z)]\right) \]

3. **Objective function of the alignment model**:

\[ \mathbb{E}_{q_{\phi}(z|X)}[\log p_{\theta}(X|z)] - D_{\text{KL}}[q_{\phi}(z|X)\|p(z)]-\sum_{m = 1}^{M}\pi_{m}D_{\text{KL}}(q_{\phi}(z|X)\|q_{\lambda_{m}}(z|x_{m})) \]

4. **Update formula for iterative approximate inference**:

\[ \mu_{m}^{t + 1},\sigma_{m}^{t + 1}=f_{w}(x_{m},\mu_{m}^{t},\sigma_{m}^{t},\nabla_{\mu_{m}^{t}}L,\nabla_{\sigma_{m}^{t}}L) \]

5. **Alignment of unimodal inference and multimodal iterative approximate inference**:

\[ D(\lambda; X)=\sum_{m = 1}^{M}D_{\text{KL}}[q_{\phi}(z_{T}|x_{m})\|q_{\lambda_{m}}(z|x_{m})] \]

### Experimental results

The experiments were carried out on two standard benchmark datasets: MNIST-SVHN-Text and Caltech Birds (CUB). The results show that the proposed method improves inference performance, as shown by higher linear classification accuracy and competitive cosine similarity of latent representations. In addition, the learned representations effectively capture the distribution of other modalities, as verified by lower Fréchet Inception Distance (FID) scores in cross-modal generation tasks.
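### Illustrative sketch of the refinement and alignment steps

To make the update formula (4) and the alignment term (5) more concrete, the following is a minimal NumPy sketch, not the authors' implementation. It rests on several assumptions that do not come from the paper: a toy linear-Gaussian decoder `p(x|z) = N(Wz, I)` with a standard normal prior, so that the ELBO and its gradients are available in closed form; two hypothetical modalities with decoder matrices `W1` and `W2`; and plain gradient ascent standing in for the learned refinement function `f_w`, which in the paper takes the gradients as inputs. All function and variable names are illustrative.

```python
import numpy as np


def elbo_and_grads(x, W, mu, log_sigma):
    """ELBO (up to an additive constant) and its gradients.

    Assumed toy model, not taken from the paper:
    p(z) = N(0, I), p(x|z) = N(Wz, I), q(z|x) = N(mu, diag(sigma^2)).
    """
    sigma2 = np.exp(2.0 * log_sigma)
    col_norms = np.sum(W ** 2, axis=0)            # ||W[:, j]||^2 per latent dim
    resid = x - W @ mu
    recon = -0.5 * (resid @ resid + sigma2 @ col_norms)   # E_q[log p(x|z)]
    kl = 0.5 * np.sum(mu ** 2 + sigma2 - 1.0 - 2.0 * log_sigma)
    grad_mu = W.T @ resid - mu
    grad_log_sigma = 1.0 - sigma2 * (col_norms + 1.0)
    return recon - kl, grad_mu, grad_log_sigma


def refine(x, W, mu0, log_sigma0, steps=1000, lr=0.02):
    """Iteratively refine (mu, sigma) using ELBO gradients.

    Plain gradient ascent is a stand-in for the learned update f_w.
    """
    mu, log_sigma = mu0.copy(), log_sigma0.copy()
    for _ in range(steps):
        _, g_mu, g_ls = elbo_and_grads(x, W, mu, log_sigma)
        mu = mu + lr * g_mu
        log_sigma = log_sigma + lr * g_ls
    return mu, log_sigma


def gaussian_kl(mu_q, log_sigma_q, mu_p, log_sigma_p):
    """KL[N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)] for diagonal Gaussians."""
    var_q = np.exp(2.0 * log_sigma_q)
    var_p = np.exp(2.0 * log_sigma_p)
    return 0.5 * np.sum(
        2.0 * (log_sigma_p - log_sigma_q)
        + (var_q + (mu_q - mu_p) ** 2) / var_p
        - 1.0
    )


# Two hypothetical modalities that share one 2-dimensional latent variable.
rng = np.random.default_rng(0)
W1 = 0.5 * rng.normal(size=(4, 2))
W2 = 0.5 * rng.normal(size=(3, 2))
z_true = rng.normal(size=2)
x1 = W1 @ z_true + 0.1 * rng.normal(size=4)
x2 = W2 @ z_true + 0.1 * rng.normal(size=3)

# With independent Gaussian decoders, the multimodal ELBO equals the ELBO of
# the stacked observation, so both modalities are refined jointly.
W = np.vstack([W1, W2])
x = np.concatenate([x1, x2])

mu0, log_sigma0 = np.zeros(2), np.zeros(2)        # crude initial posterior
elbo_init, _, _ = elbo_and_grads(x, W, mu0, log_sigma0)
mu_T, log_sigma_T = refine(x, W, mu0, log_sigma0)
elbo_final, _, _ = elbo_and_grads(x, W, mu_T, log_sigma_T)

# Alignment term for one unimodal encoder; the initial guess plays the role
# of an untrained unimodal posterior q_lambda_m(z|x_m).
align_loss = gaussian_kl(mu_T, log_sigma_T, mu0, log_sigma0)

print("ELBO before / after refinement:", elbo_init, elbo_final)
print("alignment KL to refined posterior:", align_loss)
```

In this toy setting the refined mean converges to the closed-form optimum `(WᵀW + I)⁻¹Wᵀx`, which gives a convenient check that the refinement loop behaves as intended. In a trained model, the alignment KL would be minimized with respect to the unimodal encoder parameters rather than simply evaluated.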