Probabilistic Inference in Language Models via Twisted Sequential Monte Carlo

Stephen Zhao, Rob Brekelmans, Alireza Makhzani, Roger Grosse
2024-04-27
Abstract: Numerous capability and safety techniques of Large Language Models (LLMs), including RLHF, automated red-teaming, prompt engineering, and infilling, can be cast as sampling from an unnormalized target distribution defined by a given reward or potential function over the full sequence. In this work, we leverage the rich toolkit of Sequential Monte Carlo (SMC) for these probabilistic inference problems. In particular, we use learned twist functions to estimate the expected future value of the potential at each timestep, which enables us to focus inference-time computation on promising partial sequences. We propose a novel contrastive method for learning the twist functions, and establish connections with the rich literature of soft reinforcement learning. As a complementary application of our twisted SMC framework, we present methods for evaluating the accuracy of language model inference techniques using novel bidirectional SMC bounds on the log partition function. These bounds can be used to estimate the KL divergence between the inference and target distributions in both directions. We apply our inference evaluation techniques to show that twisted SMC is effective for sampling undesirable outputs from a pretrained model (a useful component of harmlessness training and automated red-teaming), generating reviews with varied sentiment, and performing infilling tasks.
Machine Learning, Artificial Intelligence, Computation and Language
What problem does this paper attempt to address?
The paper addresses how to perform probabilistic inference in large language models (LLMs) effectively. In particular, it asks how to sample from an unnormalized target distribution defined by a given reward or potential function over the full sequence. To address this challenge, the paper proposes twisted Sequential Monte Carlo (SMC).

### Problem Background

Many LLM capability and safety techniques can be cast as sampling from an unnormalized target distribution defined by a reward or potential function over the entire sequence. Examples include Reinforcement Learning from Human Feedback (RLHF), automated red-teaming, prompt engineering, and infilling. The core difficulty in these tasks is guiding the model to generate text with specific properties, for example:

- **RLHF**: adjust generated content according to human feedback.
- **Automated red-teaming**: elicit undesirable outputs, which can then be used for harmlessness training.
- **Reasoning tasks**: generate outputs that a verifier judges to be valid.

### Specific Problems

1. **Sampling from non-causal target distributions**: the target \(\sigma(s_{1:T}) \propto p_0(s_{1:T})\,\phi(s_{1:T})\) is defined through a potential \(\phi\) on complete sequences. As a result, it does not factorize into next-token conditionals that can be sampled left to right. Sampling it token by token would require the target marginal \(\sigma(s_{1:t})\) at every step, and computing that marginal involves an intractable sum over all possible continuations \(s_{t+1:T}\).
2. **Evaluating inference quality**: effective methods are needed to evaluate language model inference techniques. In particular, they should estimate the KL divergence between the distribution an inference method samples from and the target distribution.

### Solutions

The paper proposes a twisted SMC framework and addresses these problems as follows:

1. **Learning twist functions**: twist functions \(\psi_t(s_{1:t})\) modulate the base model \(p_0(s_{1:t})\) so that the twisted distribution approximates the target marginal \(\sigma(s_{1:t})\). The optimal twists match it exactly.
   Each twist estimates the expected future value of the potential given the prefix. This lets SMC focus computation on promising partial sequences at each time step. The resulting intermediate targets are

   \[ \pi_t(s_{1:t})=\frac{1}{Z_t^{\psi}}\,p_0(s_{1:t})\,\psi_t(s_{1:t}) \]

2. **Contrastive Twist Learning (CTL)**: a new contrastive method learns the intermediate twist functions. It minimizes \(T\) separate KL divergences, one between each target marginal and the corresponding twisted approximation:

   \[ L_{CTL}(\theta)=\sum_{t=1}^{T} D_{KL}\big(\sigma(s_{1:t})\,\|\,\pi_t^{\theta}(s_{1:t})\big) \]

3. **Evaluating inference quality**: bidirectional SMC bounds on the log partition function are used to evaluate language model inference techniques. These bounds can estimate the KL divergence between the inference distribution and the target distribution in both directions.

### Application Examples

The paper applies twisted SMC to the following tasks:

- **Generating reviews with varied sentiment**: the potential function \(\phi(s_{1:T})\) is chosen to favor a desired sentiment.
- **Infilling**: given a prompt and a fixed later continuation, generate intermediate text consistent with both.
- **Harmlessness training and red-teaming**: sample undesirable outputs from a pretrained model, which helps improve the model's safety.

In summary, the paper uses twisted SMC to provide a general probabilistic inference framework for language models. It also provides tools for evaluating the accuracy of inference methods.
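The twisted SMC idea can be illustrated with a minimal sketch built on toy assumptions that do not come from the paper. The hypothetical "language model" \(p_0\) emits `T = 8` i.i.d. binary tokens. The potential \(\phi\) is an indicator that at least `K = 6` tokens are 1. Instead of learned twists, the sketch plugs in the exact optimal twist \(\psi_t(s_{1:t}) = \mathbb{E}_{p_0}[\phi(s_{1:T}) \mid s_{1:t}]\), which is computable in closed form here. Particles are proposed from the base model, weighted by \(\psi_t/\psi_{t-1}\), and multinomially resampled at every step. The running product of average incremental weights is an unbiased estimate of \(Z_\sigma = \mathbb{E}_{p_0}[\phi]\). The names `exact_twist` and `twisted_smc` are hypothetical, and proposing from \(p_0\) is simply the easiest choice, not necessarily the paper's.

```python
import math
import random

T = 8        # sequence length (toy assumption)
K = 6        # toy potential: phi(s_{1:T}) = 1 if at least K tokens are 1, else 0
P_ONE = 0.5  # toy base model p_0: i.i.d. Bernoulli(P_ONE) tokens


def exact_twist(prefix):
    """Optimal twist psi_t(s_{1:t}) = E_{p_0}[phi(s_{1:T}) | s_{1:t}].

    For a full-length sequence this equals phi itself (psi_T = phi).
    """
    need = K - sum(prefix)        # ones still required
    remaining = T - len(prefix)   # tokens still to be generated
    if need <= 0:
        return 1.0
    # P(Binomial(remaining, P_ONE) >= need); the sum is empty (0) if unreachable
    return sum(
        math.comb(remaining, j) * P_ONE**j * (1 - P_ONE) ** (remaining - j)
        for j in range(need, remaining + 1)
    )


def twisted_smc(num_particles, twist, rng):
    """Twisted SMC with the base model as proposal and multinomial resampling.

    Returns the final particles and the log of an unbiased estimate of Z_sigma.
    """
    particles = [[] for _ in range(num_particles)]
    prev_psi = [1.0] * num_particles  # convention: psi_0 = 1 for the empty prefix
    log_z_hat = 0.0
    for _ in range(T):
        weights = []
        for i, seq in enumerate(particles):
            seq.append(1 if rng.random() < P_ONE else 0)  # s_t ~ p_0(. | s_{<t})
            weights.append(twist(seq) / prev_psi[i])      # incremental weight psi_t / psi_{t-1}
        mean_w = sum(weights) / num_particles
        if mean_w == 0.0:
            raise RuntimeError("all particles received zero weight")
        log_z_hat += math.log(mean_w)
        # Resample ancestors in proportion to weight; zero-weight particles are never chosen.
        ancestors = rng.choices(range(num_particles), weights=weights, k=num_particles)
        particles = [list(particles[a]) for a in ancestors]
        prev_psi = [twist(seq) for seq in particles]
    return particles, log_z_hat


if __name__ == "__main__":
    samples, log_z_hat = twisted_smc(4000, exact_twist, random.Random(0))
    print("estimated Z:", math.exp(log_z_hat))
    print("exact Z:    ", exact_twist([]))  # 37/256 for T=8, K=6, P_ONE=0.5
    print("all samples satisfy phi:", all(sum(s) >= K for s in samples))
```

At the last step the weight is \(\phi/\psi_{T-1}\), which is zero for any sequence that violates the constraint. Every returned sequence therefore satisfies it, whatever twists are used. Exact twists additionally prune prefixes that can no longer succeed as early as possible. A proposal that already incorporates the twist would further reduce weight variance.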
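Differentiating each term of the CTL objective with respect to \(\theta\) makes its contrastive structure explicit. Write \(\pi_t^{\theta}\) for the twisted marginal at step \(t\) and \(Z_t^{\psi}(\theta)\) for its normalizer, and note that neither \(\sigma\) nor \(p_0\) depends on \(\theta\). The gradient then raises \(\log\psi_t^{\theta}\) on samples from the target marginal and lowers it on samples from the current twisted model. This is derived directly from the definitions of \(\pi_t\) and \(L_{CTL}\). It does not describe how the paper estimates the two expectations in practice.

\[
\begin{aligned}
\log \pi_t^{\theta}(s_{1:t}) &= \log p_0(s_{1:t}) + \log \psi_t^{\theta}(s_{1:t}) - \log Z_t^{\psi}(\theta),
\qquad
\nabla_\theta \log Z_t^{\psi}(\theta) = \mathbb{E}_{\pi_t^{\theta}(s_{1:t})}\big[\nabla_\theta \log \psi_t^{\theta}(s_{1:t})\big] \\
\nabla_\theta D_{KL}\big(\sigma(s_{1:t}) \,\|\, \pi_t^{\theta}(s_{1:t})\big)
&= -\,\mathbb{E}_{\sigma(s_{1:t})}\big[\nabla_\theta \log \pi_t^{\theta}(s_{1:t})\big] \\
&= -\,\mathbb{E}_{\sigma(s_{1:t})}\big[\nabla_\theta \log \psi_t^{\theta}(s_{1:t})\big]
 + \mathbb{E}_{\pi_t^{\theta}(s_{1:t})}\big[\nabla_\theta \log \psi_t^{\theta}(s_{1:t})\big]
\end{aligned}
\]

The gradient vanishes when expectations of \(\nabla_\theta \log\psi_t^{\theta}\) under the target marginal and the twisted model agree. This is exactly the condition that each \(\pi_t^{\theta}\) reproduces \(\sigma(s_{1:t})\) as far as the twist parameterization allows.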
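The single-particle case (no resampling) shows why bounds in both directions yield KL estimates in both directions. Let \(w = p_0\phi/q\) for any sampler \(q\). Then \(\mathbb{E}_q[\log w] = \log Z_\sigma - D_{KL}(q\|\sigma)\) is a lower bound on \(\log Z_\sigma\), and \(\mathbb{E}_\sigma[\log w] = \log Z_\sigma + D_{KL}(\sigma\|q)\) is an upper bound. The upper bound requires exact target samples. The sketch assumes a hypothetical toy in which everything is available in closed form: i.i.d. binary tokens, potential \(\phi(s) = \exp(\beta \cdot \#\text{ones})\), and an approximate sampler \(q\) with a biased token probability. It is not the paper's SMC bound, which serves the same purpose with many particles and resampling.

```python
import math
import random

T = 6       # sequence length (toy assumption)
BETA = 1.0  # toy potential: phi(s_{1:T}) = exp(BETA * number_of_ones)
P0 = 0.5    # toy base model p_0: i.i.d. Bernoulli(P0) tokens
Q = 0.6     # hypothetical inference method q: i.i.d. Bernoulli(Q) tokens

# In this toy, sigma (proportional to p_0 * phi) is itself i.i.d. Bernoulli, so log Z and
# exact target samples are available in closed form; for a real LM neither is.
P_SIGMA = P0 * math.exp(BETA) / (P0 * math.exp(BETA) + 1 - P0)
LOG_Z = T * math.log(P0 * math.exp(BETA) + 1 - P0)


def log_bernoulli(seq, p):
    """Log-probability of a binary sequence under i.i.d. Bernoulli(p)."""
    ones = sum(seq)
    return ones * math.log(p) + (len(seq) - ones) * math.log(1 - p)


def log_weight(seq):
    """log w(s) = log p_0(s) + log phi(s) - log q(s)."""
    return log_bernoulli(seq, P0) + BETA * sum(seq) - log_bernoulli(seq, Q)


def draw(p, rng):
    """Sample a length-T binary sequence with i.i.d. Bernoulli(p) tokens."""
    return [1 if rng.random() < p else 0 for _ in range(T)]


def single_sample_bounds(n, rng):
    """Monte Carlo estimates of E_q[log w] (lower) and E_sigma[log w] (upper)."""
    lower = sum(log_weight(draw(Q, rng)) for _ in range(n)) / n
    upper = sum(log_weight(draw(P_SIGMA, rng)) for _ in range(n)) / n
    return lower, upper


if __name__ == "__main__":
    lower, upper = single_sample_bounds(20000, random.Random(0))
    print("lower bound:", lower, " log Z:", LOG_Z, " upper bound:", upper)
    print("KL(q || sigma) estimate:", LOG_Z - lower)
    print("KL(sigma || q) estimate:", upper - LOG_Z)
    print("symmetric KL estimate (needs no log Z):", upper - lower)
```

When \(\log Z_\sigma\) is unknown, the gap between the two bounds still estimates the sum of both KL divergences. Tightening each bound with better estimators shrinks the slack in that estimate.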