Mitigating the Influence of Distractor Tasks in LMs with Prior-Aware Decoding

Raymond Douglas, Andis Draguns, Tomáš Gavenčiak
2024-10-15
Abstract: The broad capabilities of Language Models (LMs) can be limited by their sensitivity to distractor tasks: LMs can infer secondary tasks from the prompt in addition to the intended one, leading to unwanted outputs. For example, prompt injection attacks can cause models to deviate from explicit directives. In some 'inverse scaling' cases, this unwanted behaviour actually worsens as models scale up to at least 540B parameters. We present a theoretical framework that interprets LMs as a product of experts that combine multiple data generation processes. Based on this framework, we demonstrate prior-aware decoding (PAD), a simple contrastive inference method to reduce the influence of distractor tasks. We apply PAD to eleven models, across four datasets, and find improvements in 41 out of 44 task-model combinations, with a median increase in task completion proportion of 40%. The results suggest a promising direction for further development towards more reliable language models.
Computation and Language
What problem does this paper attempt to address?
The paper addresses the vulnerability of language models (LMs) to distractor tasks. When a prompt contains both an intended (primary) task and a secondary task, the model may infer and prioritize the secondary task, producing undesired outputs. For example, prompt-injection attacks can cause a model to deviate from explicit instructions. More seriously, in some "inverse scaling" cases this unwanted behavior worsens as models grow, up to at least 540 billion parameters. The paper proposes a theoretical framework that interprets an LM as a product of experts combining multiple data-generation processes and, based on this framework, introduces Prior-Aware Decoding (PAD) to reduce the influence of distractor tasks.

### Main Problems

1. **Influence of Distractor Tasks**:
   - The model may infer secondary tasks from the prompt, resulting in undesired outputs.
   - Prompt-injection attacks can cause the model to deviate from explicit instructions.
   - In some cases the unwanted behavior intensifies as model size increases, the so-called "inverse scaling" phenomenon.
2. **Model Reliability**:
   - As language models are deployed across a wide range of tasks, their reliability and potential vulnerabilities have become a focus of attention.
   - Whether a task is complex or simple, the model may be misled by common patterns in the prompt, producing wrong outputs.

### Solutions

The paper proposes Prior-Aware Decoding (PAD), which reduces the influence of distractor tasks through the following steps:

1. **Generate Two Versions of the Prompt**:
   - Original prompt: contains the task description and data.
   - Weakened prompt: a version more likely to produce outputs consistent with the distractor task.
2. **Query the Model**:
   - Query the model with the original prompt and with the weakened prompt separately, obtaining two sets of logits.
3. **Linearly Combine the Logits**:
   - Compute the linear combination \[ L = L_O + \alpha (L_O - L_W) \]
   - \( L_O \) is the logit vector from the original prompt.
   - \( L_W \) is the logit vector from the weakened prompt.
   - \( \alpha \) is the extrapolation parameter; \( \alpha = 0 \) recovers ordinary decoding from the original prompt.
4. **Generate Output**:
   - Apply softmax to \( L \) and sample from the resulting distribution.
5. **Evaluate Performance**:
   - Compute the average performance over the entire dataset.

### Experimental Results

- Applying PAD to 11 models and 4 datasets improves 41 out of 44 task-model combinations.
- The median increase in task completion proportion is 40%.
- In particular, with \(\alpha = 2\), the task completion proportion doubles on many tasks.

### Contributions

1. **Theoretical Framework**:
   - Interprets the language model as a geometric mixture (i.e., a product of experts) to explain how distractor tasks degrade performance.
2. **Method Innovation**:
   - Proposes a general framework for extracting individual components of the geometric mixture and deriving a more desirable mixture in log space.
3. **Experimental Verification**:
   - Shows experimentally that PAD substantially improves performance across multiple language models and tasks, especially in the presence of distractor tasks.

In conclusion, by introducing Prior-Aware Decoding, the paper reduces the influence of distractor tasks on language models and improves their reliability and performance.
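The PAD update \( L = L_O + \alpha (L_O - L_W) \) can be read as a geometric mixture in log space. Assuming the standard definition of a geometric mixture, \( p(x) \propto \prod_i p_i(x)^{w_i} \), and writing \( p_O = \operatorname{softmax}(L_O) \) and \( p_W = \operatorname{softmax}(L_W) \) for the next-token distributions under the original and weakened prompts (notation introduced here for illustration), a short derivation gives:

\[
\begin{aligned}
L &= L_O + \alpha (L_O - L_W) = (1+\alpha) L_O - \alpha L_W, \\
\log p_O(x) &= L_O(x) - \log \textstyle\sum_{x'} e^{L_O(x')}, \qquad \log p_W(x) = L_W(x) - \log \textstyle\sum_{x'} e^{L_W(x')}, \\
\operatorname{softmax}(L)(x) &= \frac{p_O(x)^{1+\alpha}\, p_W(x)^{-\alpha}}{\sum_{x'} p_O(x')^{1+\alpha}\, p_W(x')^{-\alpha}}.
\end{aligned}
\]

The normalizing constants cancel in the softmax, so the weights are \(1+\alpha\) on \(p_O\) and \(-\alpha\) on \(p_W\), which sum to 1. For \(\alpha > 0\) the negative weight down-weights tokens that the weakened, distractor-prone prompt favors. This is a sketch of one consistent reading, not the paper's own derivation.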
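The decoding steps can be sketched in Python with NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: `pad_logits`, `softmax`, `pad_generate`, `toy_model`, and the `next_token_logits` callable are hypothetical names, and appending each sampled token to both the original and the weakened context is an assumption of this sketch. How the weakened prompt is constructed is task-specific and is not modeled here; the toy logits are made up for illustration.

```python
import numpy as np


def pad_logits(logits_original, logits_weakened, alpha):
    """Prior-aware combination: L = L_O + alpha * (L_O - L_W)."""
    l_o = np.asarray(logits_original, dtype=float)
    l_w = np.asarray(logits_weakened, dtype=float)
    return l_o + alpha * (l_o - l_w)


def softmax(logits):
    """Turn a 1-D logit vector into probabilities (max-shifted for stability)."""
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()


def pad_generate(next_token_logits, original_ids, weakened_ids, alpha, steps, rng):
    """Hypothetical PAD sampling loop.

    `next_token_logits` is an assumed callable mapping a list of token ids to a
    1-D array of next-token logits. Each sampled token is appended to BOTH
    contexts, which is an assumption of this sketch.
    """
    original_ids, weakened_ids = list(original_ids), list(weakened_ids)
    generated = []
    for _ in range(steps):
        # Steps 2-3: query both contexts and combine the two logit vectors.
        combined = pad_logits(next_token_logits(original_ids),
                              next_token_logits(weakened_ids), alpha)
        # Step 4: sample from softmax of the combined logits.
        token = int(rng.choice(len(combined), p=softmax(combined)))
        generated.append(token)
        original_ids.append(token)
        weakened_ids.append(token)
    return generated


# Toy 3-token vocabulary: token 0 = distractor answer, token 1 = intended answer.
L_O = np.array([2.0, 1.5, 0.0])  # original prompt: distractor slightly ahead
L_W = np.array([2.5, 0.5, 0.0])  # weakened prompt: distractor further ahead

print(int(np.argmax(L_O)))                        # 0
print(pad_logits(L_O, L_W, alpha=1.0).tolist())   # [1.5, 2.5, 0.0]
print(int(np.argmax(pad_logits(L_O, L_W, 1.0))))  # 1


def toy_model(ids):
    # Hypothetical stand-in for an LM: a context starting with id 0 is treated
    # as the "original" prompt, anything else as the "weakened" prompt.
    return L_O if ids[0] == 0 else L_W


rng = np.random.default_rng(0)
print(pad_generate(toy_model, [0], [1], alpha=1.0, steps=5, rng=rng))
```

With \( \alpha = 1 \) in the toy example, the distractor token 0 leads under the original prompt, but because the weakened prompt favors it even more strongly, the combined logits flip the preference to token 1. Setting `alpha=0.0` reproduces plain sampling from the original prompt.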