Thermodynamics + molecular pretraining = accurate pKa prediction! A Uni-pKa inference demo
Tags: pka, English, Machine Learning
Weiliang Luo
chenx@dp.tech
zhougm@dp.tech
Updated on 2024-11-11
Recommended image: Basic Image: ubuntu:22.04-py3.10-pytorch2.0
Recommended machine type: c2_m4_cpu
Uni-pKa trained weight(v2)

©️ Copyright 2024 @ Authors
📖 Getting Started Guide
Licensing Agreement: This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
This document can be executed directly on the Bohrium Notebook. To begin, click the Connect button located at the top of the interface. We have already set up the recommended image ubuntu:22.04-py3.10-pytorch2.0 and the recommended machine type c2_m4_cpu for you.

Uni-pKa is a pKa prediction framework published in the article "Bridging Machine Learning and Thermodynamics for Accurate pKa Prediction".

Implementation of Uni-pKa model

Unfold the hidden blocks if you are interested in the implementation details. Otherwise, click the "run all" button to initialize everything on your first run.

Loading libraries

[3]
!pip install rdkit
!pip install matplotlib
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: rdkit in /opt/mamba/lib/python3.10/site-packages (2024.3.6)
Requirement already satisfied: Pillow in /opt/mamba/lib/python3.10/site-packages (from rdkit) (11.0.0)
Requirement already satisfied: numpy in /opt/mamba/lib/python3.10/site-packages (from rdkit) (1.24.2)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: matplotlib in /opt/mamba/lib/python3.10/site-packages (3.9.2)
Requirement already satisfied: numpy>=1.23 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (1.24.2)
Requirement already satisfied: contourpy>=1.0.1 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (1.3.0)
Requirement already satisfied: python-dateutil>=2.7 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (2.8.2)
Requirement already satisfied: cycler>=0.10 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (0.12.1)
Requirement already satisfied: pillow>=8 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (11.0.0)
Requirement already satisfied: packaging>=20.0 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (23.0)
Requirement already satisfied: fonttools>=4.22.0 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (4.54.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (1.4.7)
Requirement already satisfied: pyparsing>=2.3.1 in /opt/mamba/lib/python3.10/site-packages (from matplotlib) (3.2.0)
Requirement already satisfied: six>=1.5 in /opt/mamba/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[4]
import os, sys, logging, warnings, argparse
from multiprocessing import Pool
from typing import Optional
from tqdm import tqdm
import numpy as np
from scipy.spatial import distance_matrix
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings(action='ignore')
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("unimol_free_energy.inference")
logging.disable(50)

Dictionary class

(hidden cell)

Atom type and charge dictionary

(hidden cell)

3D conformer generation

(hidden cell)

Transformer backbone

(hidden cell)

Uni-Mol model

(hidden cell)

Molecular dataset

(hidden cell)

Interface for free energy inference

(hidden cell)

Micro pKa prediction

Load model weights and initialize the predictor.

Important note: the checkpoint provided here is a single model fine-tuned on the Dwar-iBonD and Novartis datasets for general-purpose inference. Its predictions may differ slightly from those reported in the article, which come from a 5-fold ensemble fine-tuned on the Dwar-iBonD dataset only.

[13]
model_path = "/bohr/uni-pka-ckpt-ancf/v2/t_dwar_v_novartis_a_b.pt"
predictor = FreeEnergyPredictor(model_path)

Protonation/deprotonation function for a molecule, given the index of the atom to be protonated or deprotonated.

[14]
from rdkit.Chem import Mol, RWMol, AddHs, SanitizeMol, MolToSmiles, MolFromSmiles
from rdkit.Chem.Draw import MolToImage
from PIL import Image

def prot(mol: Mol, idx: int, mode: str) -> Mol:
    '''
    Protonate / deprotonate a molecule at a specified site

    Params:
    ----
    `mol`: Molecule

    `idx`: Index of the reacting atom

    `mode`: `a2b` means deprotonation, with a hydrogen atom or a heavy atom at `idx`; `b2a` means protonation, with a heavy atom at `idx`

    Return:
    ----
    `mol_prot`: (De)protonated molecule
    '''
    mw = RWMol(mol)
    if mode == "a2b":
        atom_H = mw.GetAtomWithIdx(idx)
        if atom_H.GetAtomicNum() == 1:
            # `idx` is an explicit hydrogen: remove it and lower the charge of its neighbor
            atom_A = atom_H.GetNeighbors()[0]
            charge_A = atom_A.GetFormalCharge()
            atom_A.SetFormalCharge(charge_A - 1)
            mw.RemoveAtom(idx)
            mol_prot = mw.GetMol()
        else:
            # `idx` is a heavy atom: drop one of its hydrogens and lower its charge
            charge_H = atom_H.GetFormalCharge()
            numH_H = atom_H.GetTotalNumHs()
            atom_H.SetFormalCharge(charge_H - 1)
            atom_H.SetNumExplicitHs(numH_H - 1)
            atom_H.UpdatePropertyCache()
            mol_prot = AddHs(mw)
    elif mode == "b2a":
        # `idx` is a heavy atom: add one hydrogen and raise its charge
        atom_B = mw.GetAtomWithIdx(idx)
        charge_B = atom_B.GetFormalCharge()
        atom_B.SetFormalCharge(charge_B + 1)
        numH_B = atom_B.GetNumExplicitHs()
        atom_B.SetNumExplicitHs(numH_B + 1)
        mol_prot = AddHs(mw)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    SanitizeMol(mol_prot)
    mol_prot = MolFromSmiles(MolToSmiles(mol_prot, canonical=False))
    mol_prot = AddHs(mol_prot)
    return mol_prot

def draw(mol: Mol, size=(300, 300), highlightAtoms=[]) -> Image.Image:
    # Label every atom with its index so that sites can be picked by number
    for atom in mol.GetAtoms():
        atom.SetProp("atomNote", str(atom.GetIdx()))
    return MolToImage(mol, size=size, highlightAtoms=highlightAtoms, highlightColor=(0.8,0.8,0.8))

Glycine, drawn with atom indices.

[15]
from IPython.display import display

smi = "NCC(=O)O"
mol = MolFromSmiles(smi)
display(draw(mol))

Deprotonate atom 4 of glycine (the hydroxyl oxygen of the carboxyl group).

[16]
from rdkit.Chem import RemoveHs
mol_deprot = RemoveHs(prot(mol, 4, "a2b"))
smi_deprot = MolToSmiles(mol_deprot)
display(draw(mol_deprot, highlightAtoms=[4]))

Protonate atom 0 of glycine (the nitrogen of the amino group).

[17]
mol_prot = RemoveHs(prot(mol, 0, "b2a"))
smi_prot = MolToSmiles(mol_prot)
display(draw(mol_prot, highlightAtoms=[0]))

Deprotonating the carboxyl group of protonated glycine and protonating the amino group of deprotonated glycine both lead to the same zwitterionic form.

[18]
mol_zwitter = RemoveHs(prot(mol_prot, 4, "a2b"))
smi_zwitter = MolToSmiles(mol_zwitter)
display(draw(mol_zwitter))
display(draw(RemoveHs(prot(mol_deprot, 0, "b2a"))))

Micro-pKa prediction function

[19]
import math
from typing import Literal
LN10 = math.log(10)
TRANSLATE_PH = 6.504894871171601

from rdkit.Chem.rdChemReactions import ReactionFromSmarts
from rdkit.Chem.Draw import ReactionToImage

def predict_micro_pKa(smi: str, idx: int, mode: Literal["a2b", "b2a"]):
    mol = MolFromSmiles(smi)
    new_mol = RemoveHs(prot(mol, idx, mode))
    new_smi = MolToSmiles(new_mol)
    if mode == "a2b":
        smi_A = smi
        smi_B = new_smi
    elif mode == "b2a":
        smi_B = smi
        smi_A = new_smi
    DfGm = predictor.predict([smi_A, smi_B])
    pKa = (DfGm[smi_B] - DfGm[smi_A]) / LN10 + TRANSLATE_PH
    pKa_image = ReactionToImage(ReactionFromSmarts(f"{smi_A}>>{smi_B}", useSmiles=True))
    display(pKa_image)
    return pKa

Predict all micro-pKa values between the four protonation states of glycine shown above, then check the thermodynamic cycle: the two deprotonation routes, through the neutral form and through the zwitterion, should give the same sum of micro-pKa values.

[20]
from IPython.display import display_markdown, display_latex

micro_pKa_1 = predict_micro_pKa(smi_prot, 0, "a2b")
display_latex(f"1-H<sub>2</sub>A<sup>+</sup> → 1-HA, pK<sub>a1</sub>: {micro_pKa_1:.2f}", raw=True)
micro_pKa_2 = predict_micro_pKa(smi, 4, "a2b")
display_latex(f"1-HA → 1-A<sup>-</sup>, pK<sub>a2</sub>: {micro_pKa_2:.2f}", raw=True)
micro_pKa_3 = predict_micro_pKa(smi_prot, 4, "a2b")
display_latex(f"1-H<sub>2</sub>A<sup>+</sup> → 2-HA, pK<sub>a3</sub>: {micro_pKa_3:.2f}", raw=True)
micro_pKa_4 = predict_micro_pKa(smi_zwitter, 0, "a2b")
display_latex(f"2-HA → 1-A<sup>-</sup>, pK<sub>a4</sub>: {micro_pKa_4:.2f}", raw=True)
display_latex(f"pK<sub>a1</sub> + pK<sub>a2</sub> = {micro_pKa_1:.2f} + {micro_pKa_2:.2f} = {micro_pKa_1 + micro_pKa_2:.2f}", raw=True)
display_latex(f"pK<sub>a3</sub> + pK<sub>a4</sub> = {micro_pKa_3:.2f} + {micro_pKa_4:.2f} = {micro_pKa_3 + micro_pKa_4:.2f}", raw=True)
1-H2A+ → 1-HA, pKa1: 7.17
1-HA → 1-A-, pKa2: 4.64
1-H2A+ → 2-HA, pKa3: 2.27
2-HA → 1-A-, pKa4: 9.53
pKa1 + pKa2 = 7.17 + 4.64 = 11.81
pKa3 + pKa4 = 2.27 + 9.53 = 11.81
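
The two sums agree for a structural reason: each micro-pKa is a difference of predicted free energies, so the sum along any route between the same two end states telescopes to the same value. The sketch below reproduces that bookkeeping with the formula used in `predict_micro_pKa`. The free energies in `g` are invented for illustration and are not model outputs, and they are treated as dimensionless, which is what the division by ln 10 alone suggests.

```python
import math

LN10 = math.log(10)
TRANSLATE_PH = 6.504894871171601  # constant copied from the notebook


def micro_pka(g_acid: float, g_base: float) -> float:
    """Micro-pKa from the free energies of the acid and of its conjugate base."""
    return (g_base - g_acid) / LN10 + TRANSLATE_PH


# Hypothetical free energies for the four glycine microstates (not model outputs).
g = {"cation": 0.0, "neutral": 1.5, "zwitterion": -9.8, "anion": -2.8}

# Route 1: cation -> neutral form -> anion
route_neutral = micro_pka(g["cation"], g["neutral"]) + micro_pka(g["neutral"], g["anion"])
# Route 2: cation -> zwitterion -> anion
route_zwitter = micro_pka(g["cation"], g["zwitterion"]) + micro_pka(g["zwitterion"], g["anion"])

print(f"via neutral form: {route_neutral:.2f}, via zwitterion: {route_zwitter:.2f}")
```

The free energy of the intermediate cancels in each sum, so any predictor that assigns one free energy per microstate satisfies the cycle by construction.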

Microstate Enumerator

Enumeration function that starts from a single protonation state and returns two macrostates: the one containing that state, and the one obtained from it by protonation or deprotonation, given ionization site templates.

[23]
from typing import List, Tuple, Union, Callable
from collections import OrderedDict
import pandas as pd
from rdkit.Chem import CanonSmiles, MolFromSmarts


FILTER_PATTERNS = list(map(MolFromSmarts, [
    "[#6X5]",
    "[#7X5]",
    "[#8X4]",
    "[*r]=[*r]=[*r]",
    "[#1]-[*+1]~[*-1]",
    "[#1]-[*+1]=,:[*]-,:[*-1]",
    "[#1]-[*+1]-,:[*]=,:[*-1]",
    "[*+2]",
    "[*-2]",
    "[#1]-[#8+1].[#8-1,#7-1,#6-1]",
    "[#1]-[#7+1,#8+1].[#7-1,#6-1]",
    "[#1]-[#8+1].[#8-1,#6-1]",
    "[#1]-[#7+1].[#8-1]-[C](-[C,#1])(-[C,#1])",
    # "[#6;!$([#6]-,:[*]=,:[*]);!$([#6]-,:[#7,#8,#16])]=[C](-[O,N,S]-[#1])",
    # "[#6]-,=[C](-[O,N,S])(-[O,N,S]-[#1])",
    "[OX1]=[C]-[OH2+1]",
    "[NX1,NX2H1,NX3H2]=[C]-[O]-[H]",
    "[#6-1]=[*]-[*]",
    "[cX2-1]",
    "[N+1](=O)-[O]-[H]"
]))


def sanitize_checker(smi: str, filter_patterns: List[Mol], verbose: bool=False) -> bool:
    """
    Check if a SMILES can be sanitized and does not contain unreasonable chemical structures.

    Params:
    ----
    `smi`: The SMILES to be checked.

    `filter_patterns`: Unreasonable chemical structures.

    `verbose`: If True, matched unreasonable chemical structures will be printed.

    Return:
    ----
    True if the SMILES passes the check and should be kept, False if it should be filtered out.
    """
    mol = AddHs(MolFromSmiles(smi))
    for pattern in filter_patterns:
        match = mol.GetSubstructMatches(pattern)
        if match:
            if verbose:
                print(f"pattern {pattern}")
            return False
    try:
        SanitizeMol(mol)
    except Exception:
        print("cannot sanitize")
        return False
    return True


def sanitize_filter(smis: List[str], filter_patterns: List[Mol]=FILTER_PATTERNS) -> List[str]:
    """
    A filter that keeps only SMILES which can be sanitized and do not contain unreasonable chemical structures.

    Params:
    ----
    `smis`: The list of SMILES.

    `filter_patterns`: Unreasonable chemical structures.

    Return:
    ----
    The filtered list of SMILES.
    """
    def _checker(smi):
        return sanitize_checker(smi, filter_patterns)
    return list(filter(_checker, smis))


def cnt_stereo_atom(smi: str) -> int:
    """
    Count the stereo atoms in a SMILES
    """
    mol = MolFromSmiles(smi)
    return sum([str(atom.GetChiralTag()) != "CHI_UNSPECIFIED" for atom in mol.GetAtoms()])


def stereo_filter(smis: List[str]) -> List[str]:
    """
    A filter against SMILES that lost stereochemical information during structure processing.
    Among SMILES that differ only in stereochemistry, the one with the most stereo atoms is kept.
    """
    filtered_smi_dict = dict()
    for smi in smis:
        nonstereo_smi = CanonSmiles(smi, useChiral=0)
        stereo_cnt = cnt_stereo_atom(smi)
        if nonstereo_smi not in filtered_smi_dict:
            filtered_smi_dict[nonstereo_smi] = (smi, stereo_cnt)
        else:
            if stereo_cnt > filtered_smi_dict[nonstereo_smi][1]:
                filtered_smi_dict[nonstereo_smi] = (smi, stereo_cnt)
    return [value[0] for value in filtered_smi_dict.values()]


def make_filter(filter_param: OrderedDict) -> Callable:
    """
    Make a sequential SMILES filter

    Params:
    ----
    `filter_param`: A `collections.OrderedDict` whose keys are single filter functions and whose values are the corresponding keyword-argument dictionaries.

    Return:
    ----
    The sequential filter function
    """
    def seq_filter(smis):
        for single_filter, param in filter_param.items():
            smis = single_filter(smis, **param)
        return smis
    return seq_filter


def match_template(template: pd.DataFrame, mol: Mol, verbose: bool=False) -> list:
    '''
    Find protonation sites using templates

    Params:
    ----
    `template`: `pandas.DataFrame` with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags

    `mol`: Molecule

    `verbose`: Boolean flag for printing matching results

    Return:
    ----
    A deduplicated list of matched atom indices to be (de)protonated
    '''
    mol = AddHs(mol)
    matches = []
    for idx, name, smarts, index, acid_base in template.itertuples():
        pattern = MolFromSmarts(smarts)
        match = mol.GetSubstructMatches(pattern)
        if len(match) == 0:
            continue
        else:
            # `index` selects which atom of the SMARTS match is the ionization site
            index = int(index)
            for m in match:
                matches.append(m[index])
                if verbose:
                    print(f"find index {m[index]} in pattern {name} smarts {smarts}")
    return list(set(matches))


def prot_template(template: pd.DataFrame, smi: str, mode: str) -> Tuple[List[int], List[str]]:
    """
    Protonate / deprotonate a SMILES at every site found by the template

    Params:
    ----
    `template`: `pandas.DataFrame` with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags

    `smi`: The SMILES to be processed

    `mode`: `a2b` means deprotonation, with a hydrogen atom or a heavy atom at `idx`; `b2a` means protonation, with a heavy atom at `idx`

    Return:
    ----
    The matched site indices and the deduplicated SMILES of the (de)protonated products
    """
    mol = MolFromSmiles(smi)
    sites = match_template(template, mol)
    smis = []
    for site in sites:
        smis.append(CanonSmiles(MolToSmiles(RemoveHs(prot(mol, site, mode)))))
    return sites, list(set(smis))


def enumerate_template(smi: Union[str, List[str]], template_a2b: pd.DataFrame, template_b2a: pd.DataFrame, mode: str="a2b", maxiter: int=10, verbose: int=0, filter_patterns: List[Mol]=FILTER_PATTERNS) -> Tuple[List[str], List[str]]:
    """
    Enumerate all the (de)protonation results of one SMILES.

    Params:
    ----
    `smi`: The SMILES (or list of SMILES) to be processed.

    `template_a2b`: `pandas.DataFrame` with columns of substructure names, SMARTS patterns, deprotonation indices and acid flags.

    `template_b2a`: `pandas.DataFrame` with columns of substructure names, SMARTS patterns, protonation indices and base flags.

    `mode`:
        - "a2b": `smi` is an acid to be deprotonated.
        - "b2a": `smi` is a base to be protonated.

    `maxiter`: Max iteration number of template matching and microstate pool growth.

    `verbose`:
        - 0: Silent mode.
        - 1: Print the length of microstate pools in each iteration.
        - 2: Print the content of microstate pools in each iteration.

    `filter_patterns`: Unreasonable chemical structures.

    Return:
    ----
    A microstate pool and B microstate pool after enumeration.
    """
    if isinstance(smi, str):
        smis = [smi]
    else:
        smis = list(smi)

    enum_func = lambda x: [x] # TODO: Tautomerism enumeration

    if mode == "a2b":
        smis_A_pool, smis_B_pool = smis, []
    elif mode == "b2a":
        smis_A_pool, smis_B_pool = [], smis
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    filters = make_filter({
        sanitize_filter: {"filter_patterns": filter_patterns},
        stereo_filter: {}
    })
    pool_length_A = -1
    pool_length_B = -1
    i = 0
    # Alternate deprotonation and protonation until neither pool grows any more
    while (len(smis_A_pool) != pool_length_A or len(smis_B_pool) != pool_length_B) and i < maxiter:
        pool_length_A, pool_length_B = len(smis_A_pool), len(smis_B_pool)
        if verbose > 0:
            print(f"iter {i}: {pool_length_A} acid, {pool_length_B} base")
        if verbose > 1:
            print(f"iter {i}, acid: {smis_A_pool}, base: {smis_B_pool}")
        if (mode == "a2b" and (i + 1) % 2) or (mode == "b2a" and i % 2):
            smis_A_tmp_pool = []
            for smi in smis_A_pool:
                smis_B_pool += filters(prot_template(template_a2b, smi, "a2b")[1])
                smis_A_tmp_pool += filters([CanonSmiles(MolToSmiles(mol)) for mol in enum_func(MolFromSmiles(smi))])
            smis_A_pool += smis_A_tmp_pool
        elif (mode == "b2a" and (i + 1) % 2) or (mode == "a2b" and i % 2):
            smis_B_tmp_pool = []
            for smi in smis_B_pool:
                smis_A_pool += filters(prot_template(template_b2a, smi, "b2a")[1])
                smis_B_tmp_pool += filters([CanonSmiles(MolToSmiles(mol)) for mol in enum_func(MolFromSmiles(smi))])
            smis_B_pool += smis_B_tmp_pool
        smis_A_pool = filters(smis_A_pool)
        smis_B_pool = filters(smis_B_pool)
        smis_A_pool = list(set(smis_A_pool))
        smis_B_pool = list(set(smis_B_pool))
        i += 1
    if verbose > 0:
        print(f"iter {i}: {len(smis_A_pool)} acid, {len(smis_B_pool)} base")
    if verbose > 1:
        print(f"iter {i}, acid: {smis_A_pool}, base: {smis_B_pool}")
    smis_A_pool = list(map(CanonSmiles, smis_A_pool))
    smis_B_pool = list(map(CanonSmiles, smis_B_pool))
    return smis_A_pool, smis_B_pool
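
The `make_filter` helper above chains filters by calling each one with its own keyword arguments, in insertion order. The sketch below shows that composition pattern without RDKit. `max_length_filter` and `dedup_filter` are hypothetical stand-ins for `sanitize_filter` and `stereo_filter`, invented only to show the call convention.

```python
from collections import OrderedDict
from typing import Callable, List


def make_filter(filter_param: OrderedDict) -> Callable:
    """Same composition logic as in the notebook: apply each filter in order."""
    def seq_filter(smis: List[str]) -> List[str]:
        for single_filter, param in filter_param.items():
            smis = single_filter(smis, **param)
        return smis
    return seq_filter


# Hypothetical stand-in filters (not part of Uni-pKa).
def max_length_filter(smis: List[str], max_len: int = 10) -> List[str]:
    """Keep only strings of at most `max_len` characters."""
    return [s for s in smis if len(s) <= max_len]


def dedup_filter(smis: List[str]) -> List[str]:
    """Drop duplicates while preserving the original order."""
    return list(dict.fromkeys(smis))


filters = make_filter(OrderedDict([
    (max_length_filter, {"max_len": 8}),
    (dedup_filter, {}),
]))
result = filters(["NCC(=O)O", "NCC(=O)O", "NC(CCC(=O)O)C(=O)O"])
print(result)
```

Each filter takes the list as its first positional argument and returns a list, so filters can be reordered or swapped without touching the enumeration loop.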

We define a minimal ionization template for our amino acid example, which only includes the ionization of carboxylic acids and amines.

[24]
template_a2b = pd.DataFrame([
    {"substructure": "Carboxylic acid", "SMARTS": "[$([#6]=[#8]):0]-[OX2:1]-[H:2]", "Index": 1, "Acid_or_base": "A"},
    {"substructure": "Amine", "SMARTS": "[NX4+1:0]", "Index": 0, "Acid_or_base": "A"}
])
template_b2a = pd.DataFrame([
    {"substructure": "Carboxylic acid", "SMARTS": "[$([#6]=[#8]):0]-[O-1:1]", "Index": 1, "Acid_or_base": "B"},
    {"substructure": "Amine", "SMARTS": "[NX3:0]", "Index": 0, "Acid_or_base": "B"}
])

This is glutamic acid.

[25]
smi_GLU = "NC(CCC(=O)O)C(=O)O"
display(draw(MolFromSmiles(smi_GLU)))

Enumerate the macrostate of neutral glutamic acid and the macrostate of its singly deprotonated form.

[26]
macrostate_A, macrostate_B = enumerate_template(smi_GLU, template_a2b, template_b2a, mode="a2b")

Drawing a macrostate with microstate indices

[27]
from rdkit.Chem.Draw import MolsToGridImage

def draw_macrostate(macrostate: List[str], base_name: str):
    macrostate_mols = list(map(MolFromSmiles, macrostate))
    macrostate_size = len(macrostate_mols)
    legends = [f"{i+1}-{base_name}" for i in range(macrostate_size)]
    display(MolsToGridImage(macrostate_mols, legends=legends, useSVG=True))
[28]
draw_macrostate(macrostate_A, "H<sub>2</sub>A")
draw_macrostate(macrostate_B, "HA<sup>-</sup>")
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>

Continue to generate the fully protonated macrostate and the fully deprotonated one.

[29]
macrostate_AA, _ = enumerate_template(macrostate_A, template_a2b, template_b2a, mode="b2a")
draw_macrostate(macrostate_AA, "H<sub>3</sub>A<sup>+</sup>")
_, macrostate_BB = enumerate_template(macrostate_B, template_a2b, template_b2a, mode="a2b")
draw_macrostate(macrostate_BB, "A<sup>2-</sup>")
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>

Macro pKa prediction

Macro-pKa prediction function

[31]
def log_sum_exp(DfGm: List[float]) -> float:
    # Base-10 logarithm of the Boltzmann sum over the microstates of one macrostate
    return math.log10(sum([math.exp(-g) for g in DfGm]))


def predict_macro_pKa(smi: Union[str, List[str]], template_a2b: pd.DataFrame, template_b2a: pd.DataFrame, mode: Literal["a2b", "b2a"]) -> float:
    macrostate_A, macrostate_B = enumerate_template(smi, template_a2b, template_b2a, mode)
    DfGm_A = predictor.predict(macrostate_A)
    DfGm_B = predictor.predict(macrostate_B)
    draw_macrostate(macrostate_A, "A")
    draw_macrostate(macrostate_B, "B")
    return log_sum_exp(DfGm_A.values()) - log_sum_exp(DfGm_B.values()) + TRANSLATE_PH
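
The macro-pKa formula above reduces to the micro-pKa formula when each macrostate has a single microstate, and for one acid with several conjugate bases it gives a macroscopic Ka equal to the sum of the microscopic ones. The sketch below checks that relation with invented free energies. It also uses a shifted log-sum-exp, which is an optional variation and not the notebook's implementation, to avoid underflow if the free energies were ever large in magnitude.

```python
import math

LN10 = math.log(10)
TRANSLATE_PH = 6.504894871171601  # constant copied from the notebook


def log10_sum_exp_neg(free_energies):
    """log10(sum(exp(-g))), shifted by the minimum free energy for stability."""
    free_energies = list(free_energies)
    g_min = min(free_energies)
    total = sum(math.exp(-(g - g_min)) for g in free_energies)
    return (math.log(total) - g_min) / LN10


def macro_pka(g_acids, g_bases):
    """Macro-pKa between an acidic and a basic macrostate."""
    return log10_sum_exp_neg(g_acids) - log10_sum_exp_neg(g_bases) + TRANSLATE_PH


# Hypothetical free energies: one acid microstate, two base microstates.
g_acid = [0.0]
g_bases = [9.5, 11.0]

macro = macro_pka(g_acid, g_bases)
micro = [(g_b - g_acid[0]) / LN10 + TRANSLATE_PH for g_b in g_bases]

print(f"macro pKa: {macro:.2f}, micro pKa: {[round(p, 2) for p in micro]}")
```

The macro-pKa is lower than either micro-pKa, since every additional deprotonation channel makes the macrostate a stronger acid.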

Predict all macro-pKa values of glutamic acid and show the corresponding acidic and basic macrostates.

[32]
macro_pKa_1 = predict_macro_pKa(smi_GLU, template_a2b, template_b2a, "b2a")
display_latex(f"H<sub>3</sub>A<sup>+</sup> → H<sub>2</sub>A, pK<sub>a1</sub>: {macro_pKa_1:.2f}", raw=True)
macro_pKa_2 = predict_macro_pKa(smi_GLU, template_a2b, template_b2a, "a2b")
display_latex(f"H<sub>2</sub>A → HA<sup>-</sup>, pK<sub>a2</sub>: {macro_pKa_2:.2f}", raw=True)
macro_pKa_3 = predict_macro_pKa(macrostate_B, template_a2b, template_b2a, "a2b")
display_latex(f"HA<sup>-</sup> → A<sup>2-</sup>, pK<sub>a3</sub>: {macro_pKa_3:.2f}", raw=True)
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
H3A+ → H2A, pKa1: 1.96
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
H2A → HA-, pKa2: 4.54
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
HA- → A2-, pKa3: 9.19

Distribution fraction prediction

Standardized microstate name calculation function.

[36]
def calc_base_name(neutral_base_name: str, target_charge: int) -> str:
    if neutral_base_name.startswith("H"):
        if neutral_base_name[1:].startswith("<sub>"):
            num_H_end = neutral_base_name.find("</sub>", 6)
            num_H = int(neutral_base_name[6:num_H_end])
        else:
            num_H_end = 1
            num_H = 1
    else:
        num_H_end = 0
        num_H = 0
    target_num_H = num_H + target_charge
    assert target_num_H >= 0
    target_base_name = ""
    if target_num_H == 1:
        target_base_name += "H"
    elif target_num_H > 1:
        target_base_name += f"H<sub>{target_num_H}</sub>"
    target_base_name += "A"
    if target_charge < -1:
        target_base_name += f"<sup>{-target_charge}-</sup>"
    elif target_charge == -1:
        target_base_name += f"<sup>-</sup>"
    elif target_charge == 1:
        target_base_name += f"<sup>+</sup>"
    elif target_charge > 1:
        target_base_name += f"<sup>{target_charge}+</sup>"
    return target_base_name

Enumeration function that starts with one microstate and ends with the whole protonation ensemble, given the ionization templates.

[60]
from rdkit.Chem import GetFormalCharge
from typing import Dict


def get_ensemble(smi: str, template_a2b: pd.DataFrame, template_b2a: pd.DataFrame, maxiter: int=10) -> Dict[int, List[str]]:
    ensemble = dict()
    q0 = GetFormalCharge(MolFromSmiles(smi))
    ensemble[q0] = [smi]

    smis_0 = [smi]

    smis_0, smis_b1 = enumerate_template(smis_0, template_a2b, template_b2a, maxiter=maxiter, mode="a2b")
    if smis_b1:
        ensemble[q0 - 1] = smis_b1
    q = q0 - 2
    while True:
        if q + 1 in ensemble:
            _, smis_b = enumerate_template(ensemble[q + 1], template_a2b, template_b2a, maxiter=maxiter, mode="a2b")
            if smis_b:
                ensemble[q] = smis_b
        else:
            break
        q -= 1

    smis_a1, smis_0 = enumerate_template(smis_0, template_a2b, template_b2a, maxiter=maxiter, mode="b2a")
    if smis_a1:
        ensemble[q0 + 1] = smis_a1
    q = q0 + 2
    while True:
        if q - 1 in ensemble:
            smis_a, _ = enumerate_template(ensemble[q - 1], template_a2b, template_b2a, maxiter=maxiter, mode="b2a")
            if smis_a:
                ensemble[q] = smis_a
        else:
            break
        q += 1
    ensemble[q0] = smis_0
    return ensemble


def get_neutral_base_name(ensemble: Dict[int, List[str]]) -> str:
    q_list = sorted(ensemble.keys())
    min_q = -int(min(q_list))
    return "A" if min_q == 0 else f"H<sub>{min_q}</sub>A"

def draw_ensemble(ensemble: Dict[int, List[str]]) -> None:
    q_list = sorted(ensemble.keys())
    neutral_base_name = get_neutral_base_name(ensemble)
    for q in q_list:
        draw_macrostate(ensemble[q], calc_base_name(neutral_base_name, q))
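
As a rough cross-check on the shape of the dictionary returned by `get_ensemble`, one can count microstates combinatorially for a molecule whose ionizable sites are treated as independent. The sketch below does this for glutamic acid with the minimal template (one amine, two carboxyl groups). The `sites` table is a simplification introduced here, not part of Uni-pKa, and it ignores the SMARTS filters, so it should be read as an upper bound on what the enumerator keeps.

```python
from collections import defaultdict
from itertools import product

# Hypothetical site model: formal charge of each site in its protonated form.
# Amine: NH3+ (+1) -> NH2 (0); carboxyl: COOH (0) -> COO- (-1).
sites = {"amine": +1, "alpha-carboxyl": 0, "side-chain carboxyl": 0}

ensemble = defaultdict(list)
for protonated in product([True, False], repeat=len(sites)):
    # Removing a proton from a site lowers the total charge by one.
    charge = sum(
        q_prot if has_h else q_prot - 1
        for q_prot, has_h in zip(sites.values(), protonated)
    )
    ensemble[charge].append(dict(zip(sites, protonated)))

counts = {q: len(states) for q, states in sorted(ensemble.items())}
print(counts)
```

The keys of `counts` are total charges, which mirrors how `get_ensemble` indexes macrostates by formal charge.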

The protonation ensemble of glutamic acid when the ionization of its carboxyl and amino groups is considered.

[61]
GLU_ensemble = get_ensemble(smi_GLU, template_a2b, template_b2a)
draw_ensemble(GLU_ensemble)
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>

Prediction function for the fractions of microstates in the protonation ensemble at a given pH.

[64]
from collections import defaultdict
import pylab as pl


def predict_ensemble_free_energy(smi: str, template_a2b: pd.DataFrame, template_b2a: pd.DataFrame) -> Dict[int, List[Tuple[str, float]]]:
    ensemble = get_ensemble(smi, template_a2b, template_b2a)
    ensemble_free_energy = dict()
    for q, macrostate in ensemble.items():
        prediction = predictor.predict(macrostate)
        ensemble_free_energy[q] = [(microstate, prediction[microstate]) for microstate in macrostate]
    return ensemble_free_energy


def calc_distribution(ensemble_free_energy: Dict[int, List[Tuple[str, float]]], pH: float) -> Dict[int, List[Tuple[str, float]]]:
    ensemble_boltzmann_factor = defaultdict(list)
    partition_function = 0
    for q, macrostate_free_energy in ensemble_free_energy.items():
        for microstate, DfGm in macrostate_free_energy:
            # Weight of a microstate with charge q at the given pH
            boltzmann_factor = math.exp(-DfGm - q * LN10 * (pH - TRANSLATE_PH))
            partition_function += boltzmann_factor
            ensemble_boltzmann_factor[q].append((microstate, boltzmann_factor))
    return {
        q: [(microstate, boltzmann_factor / partition_function) for microstate, boltzmann_factor in macrostate_boltzmann_factor]
        for q, macrostate_boltzmann_factor in ensemble_boltzmann_factor.items()
    }


def draw_distribution_pH(ensemble_free_energy: Dict[int, List[Tuple[str, float]]]) -> None:
    pHs = np.linspace(0, 14, 1000)
    fractions = defaultdict(list)
    name_mapping = dict()
    ensemble = defaultdict(list)
    neutral_base_name = get_neutral_base_name(ensemble_free_energy)
    for q, macrostate in ensemble_free_energy.items():
        for i, (microstate, _) in enumerate(macrostate):
            name_mapping[microstate] = f"{i+1}-{calc_base_name(neutral_base_name, q)}"
            ensemble[q].append(microstate)
    for pH in pHs:
        distribution = calc_distribution(ensemble_free_energy, pH)
        for q, macrostate_fraction in distribution.items():
            for microstate, fraction in macrostate_fraction:
                fractions[name_mapping[microstate]].append(fraction)
    pl.figure(figsize=(14, 3), dpi=200)
    for base_name, fraction_curve in fractions.items():
        pl.plot(pHs, fraction_curve, label=base_name.replace("<sub>", "$_{").replace("</sub>", "}$").replace("<sup>", "$^{").replace("</sup>", "}$"))
    draw_ensemble(ensemble)
    pl.xlabel("pH")
    pl.ylabel("fraction")
    pl.legend()
    pl.show()
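
The weight used in `calc_distribution` can be sanity-checked on the simplest possible ensemble: a single acid HA and its conjugate base A-. With the same weight formula, the two fractions should be equal when the pH equals the pKa. The sketch below verifies this with invented free energies. `fractions_at_ph` and its `(charge, name)` keys are a simplified layout introduced here, not the notebook's data structure.

```python
import math

LN10 = math.log(10)
TRANSLATE_PH = 6.504894871171601  # constant copied from the notebook


def fractions_at_ph(free_energy, ph):
    """Map (charge, name) -> fraction at the given pH, using the notebook's weight."""
    weights = {
        key: math.exp(-g - key[0] * LN10 * (ph - TRANSLATE_PH))
        for key, g in free_energy.items()
    }
    partition_function = sum(weights.values())
    return {key: w / partition_function for key, w in weights.items()}


# Hypothetical two-state acid: HA (charge 0) and A- (charge -1).
g_ha, g_a = 0.0, -4.0
pka = (g_a - g_ha) / LN10 + TRANSLATE_PH
free_energy = {(0, "HA"): g_ha, (-1, "A-"): g_a}

at_pka = fractions_at_ph(free_energy, pka)
at_low_ph = fractions_at_ph(free_energy, 1.0)
print(f"pKa = {pka:.2f}", at_pka, at_low_ph)
```

States with more protons carry a higher charge `q`, so the `-q * LN10 * (pH - TRANSLATE_PH)` term favors them at low pH and penalizes them at high pH.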
[75]
GLU_ensemble_free_energy = predict_ensemble_free_energy(smi_GLU, template_a2b, template_b2a)
draw_distribution_pH(GLU_ensemble_free_energy)
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>

Play with more complete templates

Template reading function from a csv template file.

[66]
def read_template(template_file: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    '''
    Read a protonation template.

    Params:
    ----
    `template_file`: path of `.csv`-like template, with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags

    Return:
    ----
    `template_a2b`, `template_b2a`: acid to base and base to acid templates
    '''
    template = pd.read_csv(template_file, sep="\t")
    template_a2b = template[template.Acid_or_base == "A"]
    template_b2a = template[template.Acid_or_base == "B"]
    return template_a2b, template_b2a
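
For readers who want to write their own template file, the sketch below shows the tab-separated layout that `read_template` appears to expect, parsed with the standard library instead of pandas. The column names follow the DataFrame printed in this notebook and the two SMARTS rows are copied from it. The file content itself is an illustrative assumption, since the real `.tsv` files are not reproduced here.

```python
import csv
import io

# Illustrative file content (an assumption): one acid row and one base row.
tsv_text = (
    "Substructure\tSMARTS\tIndex\tAcid_or_base\n"
    "Selenol\t[SeX2:0]-[H:1]\t0\tA\n"
    "Amine\t[NX3:0]\t0\tB\n"
)

rows = list(csv.DictReader(io.StringIO(tsv_text), delimiter="\t"))

# Same split as read_template: "A" rows deprotonate, "B" rows protonate.
template_a2b = [row for row in rows if row["Acid_or_base"] == "A"]
template_b2a = [row for row in rows if row["Acid_or_base"] == "B"]

for row in rows:
    # Index is the position, within the SMARTS match, of the ionization site.
    print(row["Substructure"], row["SMARTS"], int(row["Index"]), row["Acid_or_base"])
```

Column order matters for the notebook's `match_template`, which unpacks each row positionally as name, SMARTS, index, and acid/base flag.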

More complete ionization templates are provided in the dataset. "simple_smarts_pattern.tsv" collects ionization patterns that are common in medicinal chemistry and is suitable for general-purpose use.

The more aggressive "smarts_pattern.tsv" covers all ionization patterns in our training set and was used in the original paper to evaluate macro-pKa prediction. Warning: with this template, protonation states that are very unreasonable in aqueous solution may be enumerated, and in some cases they drastically affect the predicted distribution fractions!

[67]
template_a2b_simple, template_b2a_simple = read_template("/bohr/uni-pka-ckpt-ancf/v2/simple_smarts_pattern.tsv")
template_a2b_full, template_b2a_full = read_template("/bohr/uni-pka-ckpt-ancf/v2/smarts_pattern.tsv")
[68]
template_a2b_simple
                    Substructure                                             SMARTS  Index Acid_or_base
0              Sulfate monoether      [SX4:0](=[O:1])(=[O:2])(-[O:3])-[OX2:4]-[H:5]      4            A
2                  Sulfonic acid  [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[OX2:4]-[H:5]      4            A
4                  Sulfinic acid          [SX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4]      3            A
6                 Seleninic acid         [SeX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4]      3            A
8                 Selenenic acid                             [SeX2:0]-[OX2:1]-[H:2]      1            A
10                  Arsonic acid         [AsX4:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4]      3            A
12             Thiosulfuric acid          [S:0]~[SX4:1](~[O:2])(~[O:3])-[O:4]-[H:5]      4            A
14           Phosph(o/i)nic acid  [PX4:0](=[O:1])(-[OX2:2]-[H:5])(-[#1,#6,#7,#8:...      2            A
16      Phosphate (mono/di)ether      [PX4:0](=[O:1])(-[O:2])(-[O:3])-[OX2:4]-[H:5]      4            A
18                 Carboxyl acid           [$([#6]=[#8,#7]),$(C#N):0]-[OX2:1]-[H:2]      1            A
20            Carboxyl acid enol          [C:0]=[C:1](-[OX2:2]-[H:3])-[OX2:4]-[H:5]      4            A
22          Carbo(di)thioic acid                [CX3:0](=[O,S:1])-[SX2,OX2:2]-[H:3]      2            A
24      Carboxyl acid vinylogue               [O:0]=[C:1]-[C:2]=[C:3]-[OX2:4]-[H:5]      4            A
26              Thiol/Thiophenol                            [#6,#7:0]-[SX2:1]-[H:2]      1            A
28                        Phenol                              [c,n:0]-[OX2:1]-[H:2]      1            A
30  Hydroperoxide/Hydroxyl amine                              [O,N:0]-[OX2:1]-[H:2]      1            A
32                         Azole  [#7:0]1(-[H:5])-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,...      0            A
34                 Aza-aromatics                                        [n:0]-[H:1]      0            A
36                         Oxime  [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:...      1            A
38                         Amine  [NX4+1:0](-[H:4])(-[CX4,c,#7,#8,#1,S,$(C=C),Cl...      0            A
40                         Imine                    [#6,#7,P,S:0]=[NX3+1:1](-[H:2])      1            A
42                         Amide  [$([#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX3:1]-...      1            A
44                   Amide imine         [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX2:1]-[H:2]      1            A
46                     Sulfamide              [SX4:0](=[O:1])(=[O:2])-[NX3:3]-[H:4]      3            A
48                   Phosphamide                      [PX4:0](=[O:1])-[NX3:2]-[H:3]      2            A
50                          Enol  [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),OH1:...      3            A
52              Hydrocyanic acid                                  [N:0]#[C:1]-[H:2]      1            A
54                       Selenol                                     [SeX2:0]-[H:1]      0            A
[69]
template_a2b_full
                     Substructure  \
0               Sulfate monoether   
2                   Sulfonic acid   
4                   Sulfinic acid   
6                  Seleninic acid   
8                  Selenenic acid   
10                   Arsonic acid   
12              Thiosulfuric acid   
14            Phosph(o/i)nic acid   
16       Phosphate (mono/di)ether   
18                  Carboxyl acid   
20             Carboxyl acid enol   
22           Carbo(di)thioic acid   
24       Carboxyl acid vinylogue    
26               Thiol/Thiophenol   
28                         Phenol   
30                        Alcohol   
32                Hydroxypyridine   
34                 Methylpyridine   
36   Hydroperoxide/Hydroxyl amine   
38                          Azole   
40                  Aza-aromatics   
42     N-substitute aza-aromatics   
44                          Oxime   
46                          Amine   
48                          Imine   
50                          Amide   
52                    Amide imine   
54                      Sulfamide   
56                    Phosphamide   
58                Amide vinylogue   
60                 Di Carbonyl βH   
62                    Carbonyl βH   
64                Carbonyl allene   
66                           Enol   
68                           Enol   
70                     Acyl group   
72                      Sulfoxide   
74                      Sulfoxide   
76                      Sulfoxide   
78               Hydrocyanic acid   
80               Phosphoryl group   
82                Selenonyl group   
84                  Arsenyl group   
86                 Carboxyl group   
88       Carboxyl group vinylogue   
90                 Carbonyl group   
92                    Cyano group   
94                 Hydroxyl group   
96                        Selenol   
98                         Borate   
100                  Bromomethane   
102               Cyclopentadiene   
104                     Tin alkyl   

                                                SMARTS  Index Acid_or_base  
0        [SX4:0](=[O:1])(=[O:2])(-[O:3])-[OX2:4]-[H:5]      4            A  
2    [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[OX2:4]-[H:5]      4            A  
4            [SX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4]      3            A  
6           [SeX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4]      3            A  
8                               [SeX2:0]-[OX2:1]-[H:2]      1            A  
10          [AsX4:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4]      3            A  
12           [S:0]~[SX4:1](~[O:2])(~[O:3])-[O:4]-[H:5]      4            A  
14   [PX4:0](=[O:1])(-[OX2:2]-[H:5])(-[#1,#6,#7,#8:...      2            A  
16       [PX4:0](=[O:1])(-[O:2])(-[O:3])-[OX2:4]-[H:5]      4            A  
18            [$([#6]=[#8,#7]),$(C#N):0]-[OX2:1]-[H:2]      1            A  
20           [C:0]=[C:1](-[OX2:2]-[H:3])-[OX2:4]-[H:5]      4            A  
22                 [CX3:0](=[O,S:1])-[SX2,OX2:2]-[H:3]      2            A  
24               [O:0]=[C:1]-[C:2]=[C:3]-[OX2:4]-[H:5]      4            A  
26                             [#6,#7:0]-[SX2:1]-[H:2]      1            A  
28                               [c,n:0]-[OX2:1]-[H:2]      1            A  
30   [$([CX4]-[$([#6]=,:[#7,#8]),$([#6]=,:[#6]-,:[#...      1            A  
32                         [n:0]:[c:1]-[OH2+1:2]-[H:3]      2            A  
34   [n:0](-[C:1]=[O:2]):[c:3]:[c:4]:[c:5]-[CX4:6]-...      6            A  
36                               [O,N:0]-[OX2:1]-[H:2]      1            A  
38   [#7:0]1(-[H:5])-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,...      0            A  
40                                         [n:0]-[H:1]      0            A  
42                               [n+1:0]-[CX4:1]-[H:2]      1            A  
44   [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:...      1            A  
46   [NX4+1:0](-[H:4])(-[CX4,c,#7,#8,#1,S,$(C=C),Cl...      0            A  
48                     [#6,#7,P,S:0]=[NX3+1:1](-[H:2])      1            A  
50   [$([#6]=,:[O,S,#7:0]),$([#7]=[#7,#8]),$([#6]:,...      1            A  
52          [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX2:1]-[H:2]      1            A  
54               [SX4:0](=[O:1])(=[O:2])-[NX3:3]-[H:4]      3            A  
56                       [PX4:0](=[O:1])-[NX3:2]-[H:3]      2            A  
58   [NX3:0](-[H:5])-,:[#6:1]=,:[#6:2]-,:[$([#6]=,:...      0            A  
60   [$([#6,#7]=,:[#7,#8]),$(C#N),$([#6]=,:[#6]-,:[...      1            A  
62   [$([#6](=O)(-,:[#7+1,#6,#1])(-,:[#6,#1])),$([N...      1            A  
64               [O:0]=[C:1]-[C:2]=[C:3]=[CX3:4]-[H:5]      4            A  
66   [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),c,$(...      3            A  
68   [#6:0]=[#6:1](-[$(C=O),$(C(=C)-[OH1]):2])-[OX2...      3            A  
70                    [#6:0](-[O,N:1])=[OX2+1:2]-[H:3]      2            A  
72           [S+1:0](-[OX2:1]-[H:4])(-[#6:2])(-[#6:3])      1            A  
74           [S:0](=[OX2+1:1]-[H:4])(-[#6:2])(-[#6:3])      1            A  
76                    [S:0](=[OX2+1:1]-[H:3])(=[#6:2])      1            A  
78                                   [N:0]#[C:1]-[H:2]      1            A  
80                             [PX4:0]=[OX2+1:1]-[H:3]      1            A  
82                              [Se:0]=[OX2+1:1]-[H:3]      1            A  
84                            [AsX4:0]=[OX2+1:1]-[H:2]      1            A  
86       [#6X3:0](:,-[O,#7,S:1])=[OX2+1,SX2+1:2]-[H:3]      2            A  
88   [#6X3:0](:,-[#6:1]:,=[#6:2]:,-[O,#7,S:3])=[OX2...      4            A  
90   [#6X3:0](:,-[#1,#6:1])(:,-[#1,#6:2])=[OX2+1:3]...      3            A  
92                                   [C:0]#[N:1]-[H:2]      1            A  
94   [CX4:0](-[#6,#1:1])(-[#6,#1:2])(-[#6,#1:3])-[O...      4            A  
96                                      [SeX2:0]-[H:1]      0            A  
98                               [BX3:0]-[OX2:1]-[H:2]      1            A  
100                               [Br:0]-[CH3:1]-[H:2]      1            A  
102  [#6X4:0](-[#1:5])1-,:[#6:1]=,:[#6:2]-,:[#6:3]=...      0            A  
104                    [N+:0]-[CX4:1](-[H:3])-[SnX4:2]      1            A  

Here we try out the drug molecule Amoxicillin.

[73]
smi_drug = "CC1([C@@H](N2[C@H](S1)[C@@H](C2=O)NC(=O)[C@@H](C3=CC=C(C=C3)O)N)C(=O)O)C"
drug_ensemble_free_energy = predict_ensemble_free_energy(smi_drug, template_a2b_simple, template_b2a_simple)
draw_distribution_pH(drug_ensemble_free_energy)
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
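
The pH-dependent distribution drawn above can be read as a Boltzmann weighting of microstates. The following is a minimal sketch of that idea, not the notebook's `draw_distribution_pH`. It assumes that each microstate free energy is expressed in pK units (already divided by RT ln 10) and that a microstate with net charge q is weighted by an extra factor 10^(-q (pH - offset)). The function name `microstate_fractions`, the `ph_offset` parameter, and the toy free energies are all hypothetical.

```python
def microstate_fractions(states, pH, ph_offset=0.0):
    """Return the equilibrium fraction of each microstate at a given pH.

    states: dict mapping a label to (net_charge, free_energy), with the
            free energy in pK units (an assumption of this sketch).
    ph_offset: hypothetical reference constant for the proton free energy.
    """
    # log10 of the unnormalised weight of each microstate
    log_w = {name: -g - q * (pH - ph_offset) for name, (q, g) in states.items()}
    # Subtract the largest exponent before exponentiating, for numerical stability
    top = max(log_w.values())
    w = {name: 10.0 ** (lw - top) for name, lw in log_w.items()}
    total = sum(w.values())
    return {name: wi / total for name, wi in w.items()}


# Toy monoprotic acid whose two forms cross at pH 4 (made-up numbers)
toy = {"HA": (0, 0.0), "A-": (-1, 4.0)}
for pH in (2.0, 4.0, 6.0):
    print(pH, microstate_fractions(toy, pH))
```

In this toy system the two forms are equally populated at pH 4, which is the pKa; below it the protonated form dominates and above it the deprotonated form does.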
[77]
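def predict_macro_pKa_from_macrostate(macrostate_A, macrostate_B, A_name, B_name) -> float:
    # Predict a free energy for every microstate in the acid (A) and base (B) macrostates
    DfGm_A = predictor.predict(macrostate_A)
    DfGm_B = predictor.predict(macrostate_B)
    # Draw both macrostates for visual inspection
    draw_macrostate(macrostate_A, A_name)
    draw_macrostate(macrostate_B, B_name)
    # Combine each ensemble with log_sum_exp, take the difference, and add the TRANSLATE_PH offset
    return log_sum_exp(DfGm_A.values()) - log_sum_exp(DfGm_B.values()) + TRANSLATE_PH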
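
The return expression above combines the microstates of each macrostate into one ensemble term before taking the difference. The notebook's own `log_sum_exp` and `TRANSLATE_PH` are defined in earlier cells that are not shown here, so the following is only a sketch of the general idea. It assumes free energies in pK units and a base-10 partition sum; `log10_partition`, `macro_pka`, and `ph_offset` are hypothetical names, and the numbers are made up.

```python
import math


def log10_partition(free_energies):
    """log10 of sum_i 10**(-G_i), computed stably by factoring out min(G)."""
    g = list(free_energies)
    g_min = min(g)
    return -g_min + math.log10(sum(10.0 ** (-(gi - g_min)) for gi in g))


def macro_pka(acid_free_energies, base_free_energies, ph_offset=0.0):
    """Macro-pKa between an acid macrostate and its conjugate base macrostate.

    ph_offset plays the role of a reference constant for the proton;
    its value here (0.0) is a placeholder, not the notebook's constant.
    """
    return (log10_partition(acid_free_energies)
            - log10_partition(base_free_energies)
            + ph_offset)


# Toy example: two equivalent acid microstates and one base microstate
print(macro_pka([0.0, 0.0], [3.0]))
```

With these toy inputs the result is 3 + log10(2), about 3.30: the two degenerate acid microstates raise the macro-pKa by a statistical factor of log10(2) relative to a single microstate.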
[78]
drug_ensemble = get_ensemble(smi_drug, template_a2b_simple, template_b2a_simple)
q_min = min(drug_ensemble.keys())
q_max = max(drug_ensemble.keys())
neutral_base_name = get_neutral_base_name(drug_ensemble)

# Step through adjacent charge states: q is the base form (B), q + 1 the acid form (A)
for i, q in enumerate(range(q_min, q_max)):
    B_name = calc_base_name(neutral_base_name, q)
    A_name = calc_base_name(neutral_base_name, q + 1)
    macro_pKa = predict_macro_pKa_from_macrostate(drug_ensemble[q+1], drug_ensemble[q], A_name, B_name)
    display_latex(f"{A_name}{B_name}, pK<sub>a{i}</sub>: {macro_pKa:.2f}", raw=True)
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
HA- →A2-, pKa0: 9.87
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
H2A →HA-, pKa1: 7.47
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
H3A+ →H2A, pKa2: 2.20
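
The labels printed above (A2-, HA-, H2A, H3A+) follow the usual convention of counting protons relative to the most deprotonated form. As an illustration only, the sketch below shows one way such labels could be generated from the net charge. It is not the notebook's `calc_base_name` or `get_neutral_base_name`; the function `macrostate_label` is hypothetical, and the charge range of -2 to +1 is inferred from the printed labels.

```python
def macrostate_label(q, q_min):
    """Label a macrostate of net charge q, given the lowest charge q_min.

    The number of dissociable protons is counted from the most
    deprotonated macrostate, which is written as plain 'A'.
    """
    n_h = q - q_min
    if n_h < 0:
        raise ValueError("q must not be below q_min")
    protons = "" if n_h == 0 else ("H" if n_h == 1 else f"H{n_h}")
    if q == 0:
        charge = ""
    else:
        sign = "+" if q > 0 else "-"
        magnitude = "" if abs(q) == 1 else str(abs(q))
        charge = f"{magnitude}{sign}"
    return f"{protons}A{charge}"


# Charge states -2 .. +1, as inferred from the labels in the output above
q_min, q_max = -2, 1
for q in range(q_min, q_max):
    acid = macrostate_label(q + 1, q_min)
    base = macrostate_label(q, q_min)
    print(f"{acid} -> {base}")
```

Run as written, this prints the same three acid/base pairs that appear in the macro-pKa output, in the same order.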

Switching from the simple templates to the full template set gives a much more aggressive enumeration of protonation states:

[76]
drug_ensemble = get_ensemble(smi_drug, template_a2b_full, template_b2a_full)
draw_ensemble(drug_ensemble)
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>
<IPython.core.display.SVG object>