Faster Speech-LLaMA Inference with Multi-token Prediction

Desh Raj, Gil Keren, Junteng Jia, Jay Mahadeokar, Ozlem Kalinli
2024-09-12
Abstract: Large language models (LLMs) have become proficient at solving a wide variety of tasks, including those involving multi-modal inputs. In particular, instantiating an LLM (such as LLaMA) with a speech encoder and training it on paired data imparts speech recognition (ASR) abilities to the decoder-only model, hence called Speech-LLaMA. Nevertheless, due to the sequential nature of auto-regressive inference and the relatively large decoder, Speech-LLaMA models require relatively high inference time. In this work, we propose to speed up Speech-LLaMA inference by predicting multiple tokens in the same decoding step. We explore several model architectures that enable this, and investigate their performance using threshold-based and verification-based inference strategies. We also propose a prefix-based beam search decoding method that allows efficient minimum word error rate (MWER) training for such models. We evaluate our models on a variety of public benchmarks, where they reduce the number of decoder calls by ~3.2x while maintaining or improving WER performance.
Audio and Speech Processing, Sound
What problem does this paper attempt to address?
This paper addresses the slow inference of the Speech-LLaMA model in automatic speech recognition (ASR). Because auto-regressive decoding is sequential and the decoder is relatively large, inference time is high. The authors propose to accelerate Speech-LLaMA inference through multi-token prediction.

### Specific description of the problem

1. **Long inference time**: The auto-regressive decoder of Speech-LLaMA generates tokens one at a time, which leads to a long inference time.
2. **Memory bandwidth limitation**: A large decoder (such as an LLM) must be loaded into compute memory every time a token is generated, so inference is limited by memory bandwidth.
3. **Low utilization of computational resources**: Conventional auto-regressive decoding cannot fully use the available computational resources, which lowers inference efficiency.

### Solutions

To accelerate Speech-LLaMA inference, the authors propose the following:

1. **Multi-token prediction**: Predict multiple tokens in a single decoding step, reducing the number of decoding steps required. A sequence of length \( U \), which originally required \( U \) steps, can be generated in \( \frac{U}{K} \) steps, where \( K \) is the number of tokens predicted per step.
2. **Model architecture improvement**:
   - **Independent projection heads**: Use multiple independent projection heads to compute the probabilities of multiple tokens in parallel.
   - **Latent space expansion**: Decompose each projection head into a full-rank matrix and a shared unembedding matrix, which reduces the number of additional parameters and makes the model more compact.
3. **Inference strategy**:
   - **Threshold selection**: Set a threshold \( \tau \) and accept the predicted tokens that satisfy it.
   - **Verification selection**: Combine prediction and verification steps so that the generated sequence is consistent with the result of auto-regressive decoding.
4. **Training objective**:
   - Extend the cross-entropy loss to cover all \( K \) predictions.
   - Use minimum word error rate (MWER) sequence-discriminative training to improve the robustness of the model.

### Experimental results

The authors evaluate the proposed method on multiple public benchmark datasets. Multi-token prediction reduces the number of decoder calls by about 3.2x while maintaining or improving ASR performance, thereby accelerating inference.

### Summary

The main contribution of this paper is a multi-token prediction technique that addresses the slow inference of the Speech-LLaMA model in ASR while maintaining good recognition performance.
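
The two inference strategies listed under Solutions can be illustrated with a minimal Python sketch. It is not the paper's implementation. The function names `select_by_threshold` and `verify_prefix`, the representation of each head's output as a plain list of probabilities, and the rule that the first head's token is always kept are assumptions made here for illustration.

```python
from typing import List, Sequence


def select_by_threshold(head_probs: Sequence[Sequence[float]], tau: float) -> List[int]:
    """Threshold-based selection (illustrative assumption, not the paper's code).

    head_probs[k] is the probability distribution produced by the k-th
    prediction head in one decoding step. The greedy token of head 0 is
    always kept; tokens from later heads are kept only while their
    probability is at least tau.
    """
    accepted: List[int] = []
    for k, probs in enumerate(head_probs):
        # Greedy (argmax) token for this head.
        token = max(range(len(probs)), key=probs.__getitem__)
        if k > 0 and probs[token] < tau:
            break  # stop at the first low-confidence head
        accepted.append(token)
    return accepted


def verify_prefix(proposed: Sequence[int], autoregressive: Sequence[int]) -> List[int]:
    """Verification-based selection (illustrative assumption).

    proposed holds tokens guessed by the extra heads; autoregressive holds
    the tokens the first head would emit at the same positions. Only the
    longest matching prefix is kept, so the output agrees with ordinary
    one-token-at-a-time decoding.
    """
    accepted: List[int] = []
    for guess, reference in zip(proposed, autoregressive):
        if guess != reference:
            break
        accepted.append(guess)
    return accepted


if __name__ == "__main__":
    probs = [[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]]  # K = 3 heads, vocabulary of 2
    print(select_by_threshold(probs, tau=0.7))
    print(verify_prefix([5, 7, 9], [5, 7, 2]))
```

In this sketch, a higher \( \tau \) accepts fewer tokens per step and therefore needs more decoder calls, while the verification variant trades an extra comparison for output that matches standard auto-regressive decoding.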