Improving Gradient Flow with Unrolled Highway Expectation Maximization

Chonghyuk Song, Eunseok Kim, Inwook Shim
DOI: https://doi.org/10.48550/arXiv.2012.04926
2020-12-09
Abstract: Integrating model-based machine learning methods into deep neural architectures allows one to leverage both the expressive power of deep neural nets and the ability of model-based methods to incorporate domain-specific knowledge. In particular, many works have employed the expectation maximization (EM) algorithm in the form of an unrolled layer-wise structure that is jointly trained with a backbone neural network. However, it is difficult to discriminatively train the backbone network by backpropagating through the EM iterations as they are prone to the vanishing gradient problem. To address this issue, we propose Highway Expectation Maximization Networks (HEMNet), which consists of unrolled iterations of the generalized EM (GEM) algorithm based on the Newton-Raphson method. HEMNet features scaled skip connections, or highways, along the depths of the unrolled architecture, resulting in improved gradient flow during backpropagation while incurring negligible additional computation and memory costs compared to standard unrolled EM. Furthermore, HEMNet preserves the underlying EM procedure, thereby fully retaining the convergence properties of the original EM algorithm. We achieve significant improvement in performance on several semantic segmentation benchmarks and empirically show that HEMNet effectively alleviates gradient decay.
Machine Learning, Computer Vision and Pattern Recognition, Neural and Evolutionary Computing
What problem does this paper attempt to address?
The problem this paper addresses is the vanishing gradient problem that arises when model-based machine learning methods, in particular the expectation-maximization (EM) algorithm, are trained jointly with deep neural networks. Specifically:

1. **Background and challenges**:
   - The EM algorithm is widely used in statistical learning for maximum-likelihood estimation in latent variable models.
   - Many recent works combine EM with deep neural networks by unrolling the EM iterations as network layers and training them jointly with a backbone network.
   - However, backpropagating through the unrolled EM iterations is prone to vanishing gradients, which makes it difficult to train the backbone network discriminatively.
2. **Proposed method**:
   - To address this, the authors propose **Highway Expectation Maximization Networks (HEMNet)**.
   - HEMNet replaces the standard M-step with a Newton-Raphson step (the N-step) of the generalized EM (GEM) algorithm, which introduces scaled skip connections, or highways, along the depth of the unrolled architecture.
   - This design improves gradient flow during backpropagation at negligible extra computation and memory cost, while preserving the underlying EM procedure and its convergence properties.
3. **Results**:
   - Experiments show that HEMNet significantly improves performance on several semantic segmentation benchmarks and empirically alleviates gradient decay.

In summary, the paper introduces HEMNet to mitigate vanishing gradients in unrolled EM iterations within deep neural networks, thereby improving both training and model performance.
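The effect of the highway on gradient flow can be illustrated with a toy NumPy sketch. This is not the authors' implementation: it assumes a Gaussian-kernel E-step with a fixed bandwidth `sigma` and an N-step of the form μ_k ← (1 - η)·μ_k + η·(responsibility-weighted mean), where η = 1 gives standard EM. The helper names (`e_step`, `n_step`, `unroll`, `jacobian_norm`) and the synthetic two-cluster data are hypothetical. The sketch estimates, by central differences, how strongly the final means depend on the initial means after T unrolled iterations.

```python
import numpy as np


def e_step(x, mu, sigma):
    """Responsibilities gamma[n, k] = softmax over k of -||x_n - mu_k||^2 / sigma^2."""
    d2 = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)  # (N, K) squared distances
    logits = -d2 / sigma**2
    logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    g = np.exp(logits)
    return g / g.sum(axis=1, keepdims=True)


def n_step(x, mu, gamma, eta):
    """Assumed GEM update: (1 - eta) * mu + eta * weighted mean; eta = 1 is the plain M-step."""
    n_k = gamma.sum(axis=0)                       # N_k, shape (K,)
    weighted_mean = (gamma.T @ x) / n_k[:, None]  # shape (K, D)
    return (1.0 - eta) * mu + eta * weighted_mean


def unroll(x, mu0, sigma, eta, num_iters):
    """Run num_iters unrolled E-step / N-step iterations starting from mu0."""
    mu = mu0
    for _ in range(num_iters):
        mu = n_step(x, mu, e_step(x, mu, sigma), eta)
    return mu


def jacobian_norm(x, mu0, sigma, eta, num_iters, eps=1e-4):
    """Frobenius norm of d mu^(T) / d mu^(0), estimated with central differences."""
    flat = mu0.ravel()
    cols = []
    for i in range(flat.size):
        step = np.zeros_like(flat)
        step[i] = eps
        plus = unroll(x, (flat + step).reshape(mu0.shape), sigma, eta, num_iters)
        minus = unroll(x, (flat - step).reshape(mu0.shape), sigma, eta, num_iters)
        cols.append(((plus - minus) / (2 * eps)).ravel())
    return np.linalg.norm(np.stack(cols, axis=1))


# Hypothetical toy data: two well-separated 2-D clusters and a sharp kernel (sigma = 0.5).
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-3.0, 0.3, (50, 2)), rng.normal(3.0, 0.3, (50, 2))])
mu0 = np.array([[-2.5, -2.5], [2.5, 2.5]])

std_norm = jacobian_norm(x, mu0, sigma=0.5, eta=1.0, num_iters=3)  # standard unrolled EM
hem_norm = jacobian_norm(x, mu0, sigma=0.5, eta=0.5, num_iters=3)  # with (1 - eta) highway
print(f"standard EM: {std_norm:.3e}, highway (eta=0.5): {hem_norm:.3e}")
```

In this toy setting the responsibilities saturate at 0 or 1, so with η = 1 the final means do not depend on the initial means at all and the estimated Jacobian norm is zero. With η = 0.5 the skip term alone contributes (1 - η)^T·I, a norm of about 0.25 for T = 3. This isolates one mechanism only; it does not reproduce the paper's experiments, which concern gradients reaching the backbone network.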
### Formula Summary

- **E-step and M-step of the EM algorithm**, with \(N_k^{(t+1)}=\sum_{n=1}^N\gamma_{nk}^{(t+1)}\):

\[ \text{E-step: } \gamma_{nk}^{(t+1)}=\frac{\exp\left(-\|x_n-\mu_k^{(t)}\|^2/\sigma^2\right)}{\sum_{j=1}^K\exp\left(-\|x_n-\mu_j^{(t)}\|^2/\sigma^2\right)} \]

\[ \text{M-step: } \mu_k^{(t+1)}=\frac{1}{N_k^{(t+1)}}\sum_{n=1}^N\gamma_{nk}^{(t+1)}x_n \]

- **N-step of the GEM algorithm**, a Newton-Raphson step with step size \(\eta\) on the expected complete-data log-likelihood \(Q\):

\[ \mu_k^{(t+1)}=\mu_k^{(t)}-\eta\left(\frac{\partial^2Q}{\partial\mu_k\partial\mu_k}\right)^{-1}\frac{\partial Q}{\partial\mu_k}=(1-\eta)\,\mu_k^{(t)}+\eta\left(\frac{1}{N_k^{(t+1)}}\sum_{n=1}^N\gamma_{nk}^{(t+1)}x_n\right) \]

These formulas summarize the key steps of the EM and GEM algorithms. Setting \(\eta=1\) recovers the standard M-step; for \(0<\eta<1\), the term \((1-\eta)\,\mu_k^{(t)}\) carries a scaled copy of the previous estimate forward, which is the scaled skip connection (highway) that improves gradient flow through the unrolled iterations.
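The step from the Newton-Raphson form of the N-step to its convex-combination form can be checked directly. The following is a short derivation, assuming \(Q\) is the expected complete-data log-likelihood that matches the E-step kernel (uniform mixing weights, shared isotropic covariance, constants dropped), with \(N_k^{(t+1)}=\sum_{n=1}^N\gamma_{nk}^{(t+1)}\) and \(\bar{x}_k=\frac{1}{N_k^{(t+1)}}\sum_{n=1}^N\gamma_{nk}^{(t+1)}x_n\):

\[
\begin{aligned}
Q(\mu) &= -\frac{1}{\sigma^2}\sum_{n=1}^N\sum_{k=1}^K\gamma_{nk}^{(t+1)}\|x_n-\mu_k\|^2+\text{const},\\
\frac{\partial Q}{\partial\mu_k} &= \frac{2}{\sigma^2}\sum_{n=1}^N\gamma_{nk}^{(t+1)}(x_n-\mu_k)=\frac{2N_k^{(t+1)}}{\sigma^2}\left(\bar{x}_k-\mu_k\right),\\
\frac{\partial^2Q}{\partial\mu_k\partial\mu_k} &= -\frac{2N_k^{(t+1)}}{\sigma^2}\,I,\\
\mu_k^{(t+1)} &= \mu_k^{(t)}+\eta\,\frac{\sigma^2}{2N_k^{(t+1)}}\cdot\frac{2N_k^{(t+1)}}{\sigma^2}\left(\bar{x}_k-\mu_k^{(t)}\right)=(1-\eta)\,\mu_k^{(t)}+\eta\,\bar{x}_k.
\end{aligned}
\]

Because the Hessian is a negative multiple of the identity, the Newton step moves \(\mu_k\) a fraction \(\eta\) of the way toward the weighted mean \(\bar{x}_k\); with \(\eta=1\) it lands exactly on it, reproducing the standard M-step.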