Optimizing Non-Autoregressive Transformers with Contrastive Learning

Chenxin An, Jiangtao Feng, Fei Huang, Xipeng Qiu, Lingpeng Kong
2023-06-02
Abstract: Non-autoregressive Transformers (NATs) reduce the inference latency of Autoregressive Transformers (ATs) by predicting words all at once rather than in sequential order. They have achieved remarkable progress in machine translation as well as many other applications. However, a long-standing challenge for NATs is the learning of multi-modality data distribution, which is the main cause of the performance gap between NATs and ATs. In this paper, we propose to ease the difficulty of modality learning via sampling from the model distribution instead of the data distribution. We derive contrastive constraints to stabilize the training process and integrate the resulting objective with the state-of-the-art NAT architecture DA-Transformer. Our model CODAT is evaluated on 3 tasks (machine translation, text summarization, and paraphrasing) across 5 benchmarks. Results show that our approach outperforms previous non-autoregressive baselines by a significant margin and establishes new state-of-the-art results for non-autoregressive transformers on all the benchmarks.
Computation and Language
### What problems does this paper attempt to solve?

This paper addresses the difficulty that non-autoregressive Transformers (NATs) have in learning multi-modal data distributions. Specifically:

1. **Difficulty of learning multi-modal data distributions**
   - **Background**: NATs reduce inference latency substantially because they generate the entire sequence in parallel. Autoregressive Transformers (ATs), by contrast, generate it word by word. Parallel generation makes it hard for NATs to learn a multi-modal data distribution, meaning one in which a single input has several valid outputs. This is the main reason NATs underperform ATs.
   - **Problem**: NATs tend to produce outputs that mix tokens from different modes (different valid outputs), which degrades output quality.
2. **Limitations of existing methods**
   - **Knowledge distillation**: Training on data distilled by an autoregressive model reduces the number of modes, but it adds an extra teacher-model pipeline.
   - **Dynamic programming training**: Long decoder lengths combined with alignment-based training objectives alleviate the problem, but the improvement is limited.
3. **The proposed method**
   - **Contrastive learning optimization**: The authors sample from the model distribution instead of the data distribution and use contrastive constraints to stabilize training. They combine the resulting objective with the state-of-the-art NAT architecture DA-Transformer to obtain CODAT (Contrastive Optimization for DA-Transformer).
   - **Objective**: Optimizing the reverse KL divergence together with the contrastive learning objective eases multi-modal distribution learning for NATs and improves their performance.
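The multi-modality problem described above can be seen in a toy calculation. This sketch illustrates the problem under stated assumptions and does not reproduce the paper's setup. It uses a hypothetical source sentence with two equally likely references, "thank you" and "many thanks". It also assumes a vanilla, fully factorized NAT (not DA-Transformer) with a fixed output length of two. Under MLE, such a model fits each position's token marginals independently, so it spreads probability onto sequences that mix the two modes.

```python
import itertools
from collections import Counter

# Hypothetical toy data: two equally likely references (modes) for one source sentence.
references = [("thank", "you"), ("many", "thanks")]
data_prob = {ref: 0.5 for ref in references}


def positionwise_marginals(dist, length):
    """Per-position token marginals: the MLE fit of a fully factorized NAT (an assumption)."""
    marginals = [Counter() for _ in range(length)]
    for seq, prob in dist.items():
        for pos, tok in enumerate(seq):
            marginals[pos][tok] += prob
    return marginals


def factorized_prob(seq, marginals):
    """Probability of a sequence when every position is predicted independently."""
    prob = 1.0
    for pos, tok in enumerate(seq):
        prob *= marginals[pos][tok]
    return prob


marginals = positionwise_marginals(data_prob, length=2)
vocab_per_pos = [sorted(m) for m in marginals]
nat_dist = {seq: factorized_prob(seq, marginals) for seq in itertools.product(*vocab_per_pos)}

for seq, prob in sorted(nat_dist.items()):
    tag = "valid" if seq in data_prob else "mixed modes"
    print(f"{' '.join(seq):<14} {prob:.2f}  ({tag})")
```

Every one of the four combinations gets probability 0.25. Half of the mass therefore lands on "many you" and "thank thanks", which are outputs that no reference contains.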
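### Formula explanation

- **Maximum Likelihood Estimation (MLE)**: MLE training minimizes the forward KL divergence from the data distribution \( p \) to the model distribution \( q \):
  \[ \text{KL}(p \parallel q)=\mathbb{E}_{y \sim p(\cdot|x)}\left[\log \frac{p(y|x)}{q(y|x)}\right]=-H(p)+\mathbb{E}_{y \sim p(\cdot|x)}\left[-\log q(y|x)\right] \]
  Here \( H(p) \) is the entropy of the data distribution. It does not depend on the model, so minimizing the KL divergence is equivalent to minimizing the cross-entropy term, which is the expected negative log-likelihood.
- **The proposed generalized divergence \( D(q \parallel p) \)**:
  \[ D(q \parallel p)=\mathbb{E}_{y' \sim q(\cdot|x)}\left[M_{p,q}(y'|x)\right] \]
  Here \( M_{p,q}(y'|x) \) measures the discrepancy between the model distribution and the data distribution for a sample \( y' \) drawn from the model. Choosing \( M_{p,q}(y'|x)=\log \frac{q(y'|x)}{p(y'|x)} \) recovers the reverse KL divergence \( \text{KL}(q \parallel p) \).
- **Contrastive learning objective**:
  \[ L_{i,j}=\max \left\{0,-\log q(y'_i|x)+\log q(y'_j|x)+(i - j)\epsilon_{LB}\right\} \]
  Here \( y'_i \) and \( y'_j \) are samples from the model, and \( \epsilon_{LB} \) is a hyperparameter representing the lower bound of the reward gap. For a pair with \( i > j \), the term is zero only when \( \log q(y'_i|x) \) exceeds \( \log q(y'_j|x) \) by at least \( (i-j)\epsilon_{LB} \). The required margin therefore grows with the rank distance between the two samples.

Through these methods, CODAT significantly outperforms previous non-autoregressive baselines on all five benchmarks and establishes new state-of-the-art results for NATs.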
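The two divergences above can be checked numerically on a toy output space. This is a minimal sketch with the following assumptions:

- `p` and `q` are hypothetical distributions over three candidate outputs.
- Unlike in real training, the data probability \( p(y|x) \) is known exactly.
- The helper names are illustrative and do not come from the paper.

The sketch verifies the forward-KL decomposition behind MLE. It also shows that \( D(q \parallel p) \) with \( M_{p,q} = \log(q/p) \) equals the reverse KL, estimated either exactly or by sampling \( y' \) from the model.

```python
import math
import random

# Hypothetical toy distributions over three candidate outputs y for a single input x.
p = {"y1": 0.5, "y2": 0.3, "y3": 0.2}  # data distribution p(y|x), assumed known here
q = {"y1": 0.6, "y2": 0.1, "y3": 0.3}  # model distribution q(y|x)


def entropy(dist):
    """H(dist) = -sum_y dist(y) log dist(y)."""
    return -sum(v * math.log(v) for v in dist.values() if v > 0)


def cross_entropy(a, b):
    """E_{y~a}[-log b(y)]: with a = p and b = q, the expected MLE loss."""
    return -sum(av * math.log(b[y]) for y, av in a.items() if av > 0)


def kl(a, b):
    """KL(a || b) = E_{y~a}[log a(y) / b(y)]."""
    return sum(av * math.log(av / b[y]) for y, av in a.items() if av > 0)


def generalized_divergence(model, m, n_samples=None, rng=None):
    """D(q || p) = E_{y'~q}[m(y')]: exact sum, or a Monte Carlo mean over y' ~ q."""
    if n_samples is None:
        return sum(qv * m(y) for y, qv in model.items())
    ys, weights = zip(*model.items())
    samples = rng.choices(ys, weights=weights, k=n_samples)  # sample from the model
    return sum(m(y) for y in samples) / n_samples


def m_reverse_kl(y):
    """M_{p,q}(y'|x) = log q(y'|x) - log p(y'|x), which turns D(q||p) into KL(q||p)."""
    return math.log(q[y]) - math.log(p[y])


# Forward KL identity behind MLE: KL(p||q) = -H(p) + E_{y~p}[-log q(y|x)].
forward = kl(p, q)
assert math.isclose(forward, -entropy(p) + cross_entropy(p, q))

reverse_exact = generalized_divergence(q, m_reverse_kl)
reverse_mc = generalized_divergence(q, m_reverse_kl, n_samples=20_000, rng=random.Random(0))
print(f"forward KL(p||q) = {forward:.4f}")
print(f"reverse KL(q||p) = {reverse_exact:.4f} (Monte Carlo: {reverse_mc:.4f})")
```

The forward and reverse values differ (about 0.157 vs 0.121 here), which is a reminder that KL is asymmetric. The Monte Carlo branch mirrors the idea of sampling from the model distribution. It still needs \( p(y'|x) \), which the toy assumes is available.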
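The pairwise objective \( L_{i,j} \) can be sketched in plain Python. This is a value-only sketch with the following assumptions:

- `contrastive_loss` is a hypothetical helper.
- The log-probabilities are made-up numbers rather than outputs of DA-Transformer.
- The pairwise terms are aggregated by summing over all pairs with \( i > j \).

A real training implementation would compute the same hinge terms on differentiable tensors.

```python
import itertools


def contrastive_loss(log_probs, eps_lb):
    """Sum over pairs i > j of
    L_{i,j} = max{0, -log q(y'_i|x) + log q(y'_j|x) + (i - j) * eps_lb}.

    Assumption: log_probs[k] holds log q(y'_k|x) for samples already indexed so
    that a larger index should receive a higher model log-probability.
    Summing over all pairs is also an assumption about aggregation.
    """
    total = 0.0
    for j, i in itertools.combinations(range(len(log_probs)), 2):  # always j < i
        total += max(0.0, -log_probs[i] + log_probs[j] + (i - j) * eps_lb)
    return total


# Hypothetical log-probabilities of four model samples, in the assumed index order.
well_ordered = [-9.0, -7.0, -5.0, -3.0]  # adjacent gaps of 2.0 exceed every margin
mis_ordered = [-3.0, -5.0, -7.0, -9.0]   # ordering fully reversed

print(contrastive_loss(well_ordered, eps_lb=1.0))  # 0.0
print(contrastive_loss(mis_ordered, eps_lb=1.0))   # 30.0
```

The well-ordered samples satisfy every margin, so the loss is zero. In the reversed case, pairs that are further apart in rank contribute larger penalties (3, 6 and 9 for rank distances 1, 2 and 3) because the margin \( (i-j)\epsilon_{LB} \) scales with the distance.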