Context-Guided Diffusion for Out-of-Distribution Molecular and Protein Design

Leo Klarner, Tim G. J. Rudner, Garrett M. Morris, Charlotte M. Deane, Yee Whye Teh
2024-07-17
Abstract: Generative models have the potential to accelerate key steps in the discovery of novel molecular therapeutics and materials. Diffusion models have recently emerged as a powerful approach, excelling at unconditional sample generation and, with data-driven guidance, conditional generation within their training domain. Reliably sampling from high-value regions beyond the training data, however, remains an open challenge, with current methods predominantly focusing on modifying the diffusion process itself. In this paper, we develop context-guided diffusion (CGD), a simple plug-and-play method that leverages unlabeled data and smoothness constraints to improve the out-of-distribution generalization of guided diffusion models. We demonstrate that this approach leads to substantial performance gains across various settings, including continuous, discrete, and graph-structured diffusion processes with applications across drug discovery, materials science, and protein design.
Biomolecules, Machine Learning
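The abstract's central idea is to regularize the guidance model on unlabeled data so that its predictions stay smooth and uncertain away from the labeled data. The minimal NumPy sketch below illustrates this with the regularizer \(R\) stated under "Formulas and technical details". It is an illustration under assumptions, not the authors' implementation. The following choices are all hypothetical and do not come from the paper:

- The two terms \(j = 1, 2\) are a predicted-mean head and a predicted log-variance head.
- \(K_t\) is an RBF kernel on the context inputs.
- The prior means \(m_j\) are constants.
- Random Gaussian vectors stand in for noised unlabeled inputs \(\hat{x}_t\).
- A single batch of context points gives a one-sample estimate of the expectation.

```python
import numpy as np


def rbf_kernel(a, b, lengthscale=1.0, scale=1.0):
    """Squared-exponential kernel matrix between the rows of a (n, d) and b (m, d).

    Hypothetical choice of K_t: the summary does not specify how K_t is built.
    """
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return scale * np.exp(-0.5 * np.maximum(sq_dists, 0.0) / lengthscale**2)


def squared_mahalanobis(f, m, K, jitter=1e-6):
    """(f - m)^T K^{-1} (f - m), computed via a Cholesky factor of K."""
    diff = f - m
    L = np.linalg.cholesky(K + jitter * np.eye(K.shape[0]))
    z = np.linalg.solve(L, diff)  # z = L^{-1}(f - m), so z @ z = diff^T K^{-1} diff
    return float(z @ z)


def context_regularizer(head_outputs, prior_means, K):
    """Sum over heads j of D_M(f_j, m_j, K)^2 on one batch of context points.

    One batch is a single-sample Monte Carlo estimate of the expectation in R.
    """
    return sum(squared_mahalanobis(f, m, K) for f, m in zip(head_outputs, prior_means))


# Usage with made-up data (hypothetical stand-ins, not the paper's setup).
rng = np.random.default_rng(0)
x_context = rng.normal(size=(16, 4))  # stand-in for noised unlabeled inputs hat{x}_t
K = rbf_kernel(x_context, x_context)

w = rng.normal(size=4)                # toy linear guidance model
pred_mean = x_context @ w             # head j = 1: predicted mean
pred_log_var = np.full(16, -2.0)      # head j = 2: predicted log-variance (overconfident)

prior_mean = np.zeros(16)             # hypothetical m_1: uninformative mean
prior_log_var = np.zeros(16)          # hypothetical m_2: higher predictive uncertainty

reg = context_regularizer([pred_mean, pred_log_var], [prior_mean, prior_log_var], K)
reg_at_prior = context_regularizer([prior_mean, prior_log_var], [prior_mean, prior_log_var], K)
```

The penalty is zero when the model's outputs on the context points match the prior. Because it uses \(K^{-1}\), deviations that change sharply between points the kernel treats as similar are penalized more than smooth deviations. In this sketch the log-variance term also pulls predicted uncertainty up toward the prior value. In training, such a term would presumably be added to the supervised loss on labeled data and differentiated with respect to \(\theta\); the sketch only evaluates it.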
What problem does this paper attempt to address?
This paper addresses how to reliably generate molecules and proteins from high-value regions that lie beyond the training data. Existing diffusion models perform well at unconditional generation and, with data-driven guidance, at conditional generation within their training domain, but they struggle to sample reliably from high-value regions outside it. Current methods mainly modify the diffusion process itself. This paper instead proposes context-guided diffusion (CGD), which uses unlabeled data and smoothness constraints to improve the out-of-distribution (OOD) generalization of guided diffusion models.

### Main problems

1. **Insufficient generalization**: On out-of-distribution inputs, the guidance models used by existing methods often make inaccurate or overconfident predictions because they are fit only to limited training data. This degrades the quality of the generated samples.
2. **Insufficient sample diversity**: Molecular and protein design requires exploring new regions of chemical space, but existing methods are limited in their ability to generate diverse, novel samples.
3. **Reliability of conditional generation**: Generating molecules and proteins with specific properties requires reliable conditional generation, so that the samples actually have the intended functions and properties.

### Solutions

CGD addresses these problems as follows:

- **Using unlabeled data**: Incorporating unlabeled data lets CGD better capture the structure of the input domain, which improves its behavior on out-of-distribution data.
- **Smoothness constraints**: Adding smoothness constraints when training the guidance model yields more stable and reliable guidance gradients and prevents overfitting to noisy data.
- **High-uncertainty estimation**: In out-of-distribution regions, CGD encourages the guidance model to make high-uncertainty predictions, which keeps guidance from steering generation toward incorrect samples.

### Experimental verification

The authors evaluate CGD on small-molecule design, materials generation, and discrete protein-sequence optimization. The results show that CGD substantially improves conditional generation across different types of diffusion processes and application domains, and that it produces more novel samples with the desired properties.

### Formulas and technical details

- **Forward process of the diffusion model**:
  \[ dX_t = f(X_t, t)\,dt + g(t)\,dB_t \]
  where \(X_0 \sim p_0\) is an initial sample drawn from the data distribution, \(f\) and \(g\) are the drift and diffusion coefficients, and \(B_t\) is a \(d\)-dimensional Brownian motion.
- **Reverse denoising process** (this form corresponds to the variance-preserving choice \(f(X_t, t) = -\tfrac{1}{2}\beta_t X_t\), \(g(t) = \sqrt{\beta_t}\)):
  \[ dX_t = -\beta_t\left(\tfrac{1}{2}X_t + \nabla\log p_t(X_t)\right)dt + \sqrt{\beta_t}\,d\bar{B}_t \]
  This SDE runs backward in time. Here \(\bar{B}_t\) is a reverse-time Brownian motion, and the score \(\nabla\log p_t(X_t)\), the gradient of the log-density, is estimated by the time-dependent score network \(s_\psi(x_t, t)\).
- **Conditional guidance**:
  \[ \nabla\log p_t(X_t \mid y) = \nabla\log p_t(y \mid X_t) + \nabla\log p_t(X_t) \]
  This follows from Bayes' rule, because \(\log p_t(y)\) does not depend on \(X_t\) and vanishes under the gradient. The first term is supplied by the guidance model.
- **Regularization term**:
  \[ R(\theta, f_t, t, p_{\hat{X}_t}) = \mathbb{E}_{p_{\hat{X}_t}}\left[\sum_{j=1}^{2} D_M\big(f_j^t(\hat{x}_t;\theta), m_j^t(\hat{x}_t), K_t(\hat{x}_t)\big)^2\right] \]
  where \(D_M(f_j^t(\hat{x}_t;\theta), m_j^t(\hat{x}_t), K_t(\hat{x}_t))^2\) is the squared Mahalanobis distance, \(D_M(f, m, K)^2 = (f - m)^\top K^{-1}(f - m)\), which is used to measure