Nudging: Inference-time Alignment via Model Collaboration

Yu Fei, Yasaman Razeghi, Sameer Singh
2024-10-15
Abstract: Large language models (LLMs) require alignment, such as instruction-tuning or reinforcement learning from human feedback, to effectively and safely follow user instructions. This process necessitates training aligned versions for every model size in each model family, resulting in significant computational overhead. In this work, we propose nudging, a simple, plug-and-play, and training-free algorithm that aligns any base model at inference time using a small aligned model. Nudging is motivated by recent findings that alignment primarily alters the model's behavior on a small subset of stylistic tokens, such as "Sure" or "Thank". We find that base models are significantly more uncertain when generating these tokens. Leveraging this observation, nudging employs a small aligned model to generate nudging tokens to steer the large base model's output toward desired directions when the base model's uncertainty is high. We evaluate the effectiveness of nudging across 3 model families and 13 tasks, covering reasoning, general knowledge, instruction following, and safety benchmarks. Without any additional training, nudging a large base model with a 7x to 14x smaller aligned model achieves zero-shot performance comparable to, and sometimes surpassing, that of large aligned models. For example, nudging OLMo-7b with OLMo-1b-instruct, affecting less than 9% of tokens, achieves a 10% absolute improvement on GSM8K over OLMo-7b-instruct. Unlike prior inference-time tuning methods, nudging enables off-the-shelf collaboration between model families. For instance, nudging Gemma-2-27b with Llama-2-7b-chat outperforms Llama-2-70b-chat on various tasks. Overall, this work introduces a simple yet powerful approach to token-level model collaboration, offering a modular solution to LLM alignment. Our project website: https://fywalter.github.io/nudging/
Computation and Language, Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
This paper addresses how to align large language models (LLMs) at inference time, without any additional training. Traditional alignment methods, such as instruction tuning or reinforcement learning from human feedback, require training a separate aligned version for every model size in each model family, which incurs significant computational overhead. The paper proposes "nudging", a simple, plug-and-play, training-free algorithm that aligns any base model at inference time using a small aligned model. Nudging detects when the base model is uncertain about the next token and, at those positions, has the small aligned model generate "nudging tokens" that steer the large model's output. This reduces the computational cost of alignment and makes it more flexible, because an existing small aligned model can be reused with different base models.

Specifically, the main contributions of the paper are:

1. **Proposing the nudging algorithm**: A method for aligning large language models at inference time without additional training.
2. **Analyzing the impact of alignment**: Alignment mainly changes the model's behavior on a small subset of stylistic tokens, such as "Sure" or "Thank", and base models are significantly more uncertain when generating these tokens.
3. **Experimentally verifying the effectiveness**: Experiments across 3 model families and 13 tasks show that nudging a large base model with a 7x to 14x smaller aligned model achieves zero-shot performance comparable to, and sometimes surpassing, that of large aligned models.
4. **Effectiveness across model families**: Nudging works not only within a single model family but also between families. For instance, nudging Gemma-2-27b with Llama-2-7b-chat outperforms Llama-2-70b-chat on various tasks.

In conclusion, this paper provides an efficient and flexible way to align large language models at inference time, avoiding the computational overhead of training an aligned version of every model.
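The uncertainty-gated handoff described above can be illustrated with a minimal sketch. This is not the authors' implementation: the summary only says the aligned model steps in when the base model's uncertainty is high, so the uncertainty measure (top-1 probability), the `threshold` value, the one-token-at-a-time handoff, and the toy models `toy_base` and `toy_aligned` are all assumptions made for illustration.

```python
from typing import Callable, Dict, List, Tuple

# A "model" here is any function mapping a token prefix to a next-token
# probability distribution. Real LLMs would need a wrapper of this shape.
NextTokenDist = Callable[[List[str]], Dict[str, float]]


def nudged_generate(
    base: NextTokenDist,
    aligned: NextTokenDist,
    prompt: List[str],
    threshold: float = 0.4,  # assumed value, not taken from the paper
    max_new_tokens: int = 32,
    eos: str = "<eos>",
) -> Tuple[List[str], int]:
    """Greedy decoding in which the small aligned model picks the token
    whenever the base model's top-1 probability is below `threshold`.

    Returns the generated tokens and the number of nudging tokens.
    """
    tokens = list(prompt)
    nudged = 0
    for _ in range(max_new_tokens):
        base_dist = base(tokens)
        token, prob = max(base_dist.items(), key=lambda kv: kv[1])
        if prob < threshold:
            # Base model is uncertain: take the aligned model's top token.
            aligned_dist = aligned(tokens)
            token = max(aligned_dist, key=aligned_dist.get)
            nudged += 1
        tokens.append(token)
        if token == eos:
            break
    return tokens[len(prompt):], nudged


# Hypothetical toy models: the base model is unsure how to open its answer
# but confident once the opening token has been chosen.
def toy_base(prefix: List[str]) -> Dict[str, float]:
    if prefix[-1] == "?":
        return {"The": 0.3, "I": 0.3, "Sure": 0.2, "<eos>": 0.2}
    table = {
        "Sure": {",": 0.9, "<eos>": 0.1},
        ",": {"4": 0.95, "<eos>": 0.05},
        "4": {"<eos>": 0.99, ".": 0.01},
    }
    return table.get(prefix[-1], {"<eos>": 1.0})


def toy_aligned(prefix: List[str]) -> Dict[str, float]:
    return {"Sure": 0.9, "The": 0.1}


output, n_nudged = nudged_generate(toy_base, toy_aligned, ["2+2", "?"])
print(output, n_nudged)  # ['Sure', ',', '4', '<eos>'] 1
```

In this toy run only the opening token comes from the aligned model, and the base model produces the rest. That mirrors the claim that nudging touches only a small fraction of tokens, though the numbers here are contrived and say nothing about real model behavior.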