HPLC retention time prediction
bohr6ef000
Updated 2024-09-09
Recommended image: mfy:02
Recommended machine type: c2_m4_cpu
HPLC_dataset(v2)

Install required packages

[101]
!pip install mordred
!pip install torch_geometric
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: mordred in /opt/conda/lib/python3.8/site-packages (1.2.0)
Requirement already satisfied: networkx==2.* in /opt/conda/lib/python3.8/site-packages (from mordred) (2.8.8)
Requirement already satisfied: six==1.* in /opt/conda/lib/python3.8/site-packages (from mordred) (1.16.0)
Requirement already satisfied: numpy==1.* in /opt/conda/lib/python3.8/site-packages (from mordred) (1.22.4)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: torch_geometric in /opt/conda/lib/python3.8/site-packages (2.5.3)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (4.64.1)
Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (2.28.2)
Requirement already satisfied: fsspec in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (2023.1.0)
Requirement already satisfied: scipy in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (1.7.3)
Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (1.22.4)
Requirement already satisfied: aiohttp in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (3.8.4)
Requirement already satisfied: pyparsing in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (3.0.9)
Requirement already satisfied: psutil>=5.8.0 in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (5.9.0)
Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (1.0.2)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (3.1.2)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (4.0.2)
Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (22.1.0)
Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (1.3.1)
Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (1.3.3)
Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (6.0.4)
Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (1.8.2)
Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (3.3.2)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.8/site-packages (from jinja2->torch_geometric) (2.1.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->torch_geometric) (1.26.14)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->torch_geometric) (3.4)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->torch_geometric) (2022.12.7)
Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->torch_geometric) (1.2.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->torch_geometric) (3.1.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[112]
import torch
from torch_geometric.nn import MessagePassing
from rdkit.Chem import Descriptors
from torch_geometric.data import Data
import argparse
import warnings
from rdkit.Chem.Descriptors import rdMolDescriptors
import pandas as pd
import os
from mordred import Calculator, descriptors, is_missing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdchem
from torch_geometric.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import matplotlib.pyplot as plt

Define functions to compute molecular features and convert molecules into graphs

[103]
device=torch.device("cpu")
DAY_LIGHT_FG_SMARTS_LIST = [
# C
"[CX4]",
"[$([CX2](=C)=C)]",
"[$([CX3]=[CX3])]",
"[$([CX2]#C)]",
# C & O
"[CX3]=[OX1]",
"[$([CX3]=[OX1]),$([CX3+]-[OX1-])]",
"[CX3](=[OX1])C",
"[OX1]=CN",
"[CX3](=[OX1])O",
"[CX3](=[OX1])[F,Cl,Br,I]",
"[CX3H1](=O)[#6]",
"[CX3](=[OX1])[OX2][CX3](=[OX1])",
"[NX3][CX3](=[OX1])[#6]",
"[NX3][CX3]=[NX3+]",
"[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]",
"[NX3][CX3](=[OX1])[OX2H0]",
"[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]",
"[CX3](=O)[O-]",
"[CX3](=[OX1])(O)O",
"[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]",
"C[OX2][CX3](=[OX1])[OX2]C",
"[CX3](=O)[OX2H1]",
"[CX3](=O)[OX1H0-,OX2H1]",
"[NX3][CX2]#[NX1]",
"[#6][CX3](=O)[OX2H0][#6]",
"[#6][CX3](=O)[#6]",
"[OD2]([#6])[#6]",
# H
"[H]",
"[!#1]",
"[H+]",
"[+H]",
"[!H]",
# N
"[NX3;H2,H1;!$(NC=O)]",
"[NX3][CX3]=[CX3]",
"[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
"[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]",
"[NX3][$(C=C),$(cc)]",
"[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]",
"[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]",
"[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]",
"[CH3X4]",
"[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]",
"[CH2X4][CX3](=[OX1])[NX3H2]",
"[CH2X4][CX3](=[OX1])[OH0-,OH]",
"[CH2X4][SX2H,SX1H0-]",
"[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]",
"[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]",
"[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\
[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1",
"[CHX4]([CH3X4])[CH2X4][CH3X4]",
"[CH2X4][CHX4]([CH3X4])[CH3X4]",
"[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]",
"[CH2X4][CH2X4][SX2][CH3X4]",
"[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1",
"[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]",
"[CH2X4][OX2H]",
"[NX3][CX3]=[SX1]",
"[CHX4]([CH3X4])[OX2H]",
"[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12",
"[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1",
"[CHX4]([CH3X4])[CH3X4]",
"N[CX4H2][CX3](=[OX1])[O,N]",
"N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]",
"[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]",
"[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]",
"[#7]",
"[NX2]=N",
"[NX2]=[NX2]",
"[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]",
"[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]",
"[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]",
"[NX3][NX3]",
"[NX3][NX2]=[*]",
"[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]",
"[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]",
"[NX3+]=[CX3]",
"[CX3](=[OX1])[NX3H][CX3](=[OX1])",
"[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])",
"[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])",
"[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]",
"[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]",
"[NX1]#[CX2]",
"[CX1-]#[NX2+]",
"[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
"[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
"[NX2]=[OX1]",
"[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]",
# O
"[OX2H]",
"[#6][OX2H]",
"[OX2H][CX3]=[OX1]",
"[OX2H]P",
"[OX2H][#6X3]=[#6]",
"[OX2H][cX3]:[c]",
"[OX2H][$(C=C),$(cc)]",
"[$([OH]-*=[!#6])]",
"[OX2,OX1-][OX2,OX1-]",
# P
"[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\
$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\
,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]",
"[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\
$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\
$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]",
# S
"[S-][CX3](=S)[#6]",
"[#6X3](=[SX1])([!N])[!N]",
"[SX2]",
"[#16X2H]",
"[#16!H0]",
"[#16X2H0]",
"[#16X2H0][!#16]",
"[#16X2H0][#16X2H0]",
"[#16X2H0][!#16].[#16X2H0][!#16]",
"[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]",
"[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]",
"[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]",
"[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]",
"[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]",
"[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]",
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]",
"[SX4](C)(C)(=O)=N",
"[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]",
"[$([#16X3]=[OX1]),$([#16X3+][OX1-])]",
"[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]",
"[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]",
"[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]",
"[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]",
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]",
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]",
"[#16X2][OX2H,OX1H0-]",
"[#16X2][OX2H0]",
# X
"[#6][F,Cl,Br,I]",
"[F,Cl,Br,I]",
"[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]",
]
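# Each SMARTS pattern above encodes a Daylight-style functional group;
# CompoundKit.get_daylight_functional_group_counts() below counts the substructure
# matches of every pattern for a given molecule.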


def get_gasteiger_partial_charges(mol, n_iter=12):
"""
    Calculates the list of Gasteiger partial charges for each atom in the mol object.
Args:
mol: rdkit mol object.
n_iter(int): number of iterations. Default 12.
Returns:
list of computed partial charges for each atom.
"""
Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter,
throwOnParamFailure=True)
partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in
mol.GetAtoms()]
return partial_charges


def create_standardized_mol_id(smiles):
"""
Args:
smiles: smiles sequence.
Returns:
inchi.
"""
if check_smiles_validity(smiles):
# remove stereochemistry
smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
isomericSmiles=False)
mol = AllChem.MolFromSmiles(smiles)
if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
if '.' in smiles: # if multiple species, pick largest molecule
mol_species_list = split_rdkit_mol_obj(mol)
largest_mol = get_largest_mol(mol_species_list)
inchi = AllChem.MolToInchi(largest_mol)
else:
inchi = AllChem.MolToInchi(mol)
return inchi
else:
return
else:
return


def check_smiles_validity(smiles):
"""
    Check whether the SMILES string can be converted to an RDKit mol object; returns True if it can.
"""
try:
m = Chem.MolFromSmiles(smiles)
if m:
return True
else:
return False
except Exception as e:
return False


def split_rdkit_mol_obj(mol):
"""
Split rdkit mol object containing multiple species or one species into a
list of mol objects or a list containing a single object respectively.
Args:
mol: rdkit mol object.
"""
smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
smiles_list = smiles.split('.')
mol_species_list = []
for s in smiles_list:
if check_smiles_validity(s):
mol_species_list.append(AllChem.MolFromSmiles(s))
return mol_species_list


def get_largest_mol(mol_list):
"""
    Given a list of RDKit mol objects, returns the mol object containing the
    largest number of atoms. If several are tied for the largest, picks the
    first one.
Args:
mol_list(list): a list of rdkit mol object.
Returns:
the largest mol.
"""
num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
largest_mol_idx = num_atoms_list.index(max(num_atoms_list))
return mol_list[largest_mol_idx]


def rdchem_enum_to_list(values):
"""values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}
"""
return [values[i] for i in range(len(values))]


def safe_index(alist, elem):
"""
    Return the index of elem in alist. If elem is not present, return the last index.
"""
try:
return alist.index(elem)
except ValueError:
return len(alist) - 1


def get_atom_feature_dims(list_acquired_feature_names):
""" tbd
"""
return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names]))


def get_bond_feature_dims(list_acquired_feature_names):
""" tbd
"""
list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names]))
# +1 for self loop edges
return [_l + 1 for _l in list_bond_feat_dim]


class CompoundKit(object):
"""
CompoundKit
"""
atom_vocab_dict = {
"atomic_num": list(range(1, 119)) + ['misc'],
"chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
"degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
"explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
"formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
"hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values),
"implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
"is_aromatic": [0, 1],
"total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'num_radical_e': [0, 1, 2, 3, 4, 'misc'],
'atom_is_in_ring': [0, 1],
'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
}
bond_vocab_dict = {
"bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
"bond_type": rdchem_enum_to_list(rdchem.BondType.values),
"is_in_ring": [0, 1],

'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values),
'is_conjugated': [0, 1],
}
# float features
atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
# bond_float_feats= ["bond_length", "bond_angle"] # optional

### functional groups
day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

morgan_fp_N = 200
morgan2048_fp_N = 2048
maccs_fp_N = 167

period_table = Chem.GetPeriodicTable()

### atom

@staticmethod
def get_atom_value(atom, name):
"""get atom values"""
if name == 'atomic_num':
return atom.GetAtomicNum()
elif name == 'chiral_tag':
return atom.GetChiralTag()
elif name == 'degree':
return atom.GetDegree()
elif name == 'explicit_valence':
return atom.GetExplicitValence()
elif name == 'formal_charge':
return atom.GetFormalCharge()
elif name == 'hybridization':
return atom.GetHybridization()
elif name == 'implicit_valence':
return atom.GetImplicitValence()
elif name == 'is_aromatic':
return int(atom.GetIsAromatic())
elif name == 'mass':
return int(atom.GetMass())
elif name == 'total_numHs':
return atom.GetTotalNumHs()
elif name == 'num_radical_e':
return atom.GetNumRadicalElectrons()
elif name == 'atom_is_in_ring':
return int(atom.IsInRing())
elif name == 'valence_out_shell':
return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
else:
raise ValueError(name)

@staticmethod
def get_atom_feature_id(atom, name):
"""get atom features id"""
assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

@staticmethod
def get_atom_feature_size(name):
"""get atom features size"""
assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
return len(CompoundKit.atom_vocab_dict[name])

### bond

@staticmethod
def get_bond_value(bond, name):
"""get bond values"""
if name == 'bond_dir':
return bond.GetBondDir()
elif name == 'bond_type':
return bond.GetBondType()
elif name == 'is_in_ring':
return int(bond.IsInRing())
elif name == 'is_conjugated':
return int(bond.GetIsConjugated())
elif name == 'bond_stereo':
return bond.GetStereo()
else:
raise ValueError(name)

@staticmethod
def get_bond_feature_id(bond, name):
"""get bond features id"""
assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

@staticmethod
def get_bond_feature_size(name):
"""get bond features size"""
assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
return len(CompoundKit.bond_vocab_dict[name])

### fingerprint

@staticmethod
def get_morgan_fingerprint(mol, radius=2):
"""get morgan fingerprint"""
nBits = CompoundKit.morgan_fp_N
mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
return [int(b) for b in mfp.ToBitString()]

@staticmethod
def get_morgan2048_fingerprint(mol, radius=2):
"""get morgan2048 fingerprint"""
nBits = CompoundKit.morgan2048_fp_N
mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
return [int(b) for b in mfp.ToBitString()]

@staticmethod
def get_maccs_fingerprint(mol):
"""get maccs fingerprint"""
fp = AllChem.GetMACCSKeysFingerprint(mol)
return [int(b) for b in fp.ToBitString()]

### functional groups

@staticmethod
def get_daylight_functional_group_counts(mol):
"""get daylight functional group counts"""
fg_counts = []
for fg_mol in CompoundKit.day_light_fg_mo_list:
sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
fg_counts.append(len(sub_structs))
return fg_counts

@staticmethod
def get_ring_size(mol):
"""return (N,6) list"""
rings = mol.GetRingInfo()
rings_info = []
for r in rings.AtomRings():
rings_info.append(r)
ring_list = []
for atom in mol.GetAtoms():
atom_result = []
for ringsize in range(3, 9):
num_of_ring_at_ringsize = 0
for r in rings_info:
if len(r) == ringsize and atom.GetIdx() in r:
num_of_ring_at_ringsize += 1
if num_of_ring_at_ringsize > 8:
num_of_ring_at_ringsize = 9
atom_result.append(num_of_ring_at_ringsize)

ring_list.append(atom_result)
return ring_list

@staticmethod
def atom_to_feat_vector(atom):
""" tbd """
atom_names = {
"atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
"chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()),
"degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()),
"explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()),
"formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()),
"hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()),
"implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()),
"is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())),
"total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()),
'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], atom.GetNumRadicalElectrons()),
'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())),
'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'],
CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())),
'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()),
'partial_charge': CompoundKit.check_partial_charge(atom),
'mass': atom.GetMass(),
}
return atom_names

@staticmethod
def get_atom_names(mol):
"""get atom name list
        TODO: to be removed in the future
"""
atom_features_dicts = []
Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
for i, atom in enumerate(mol.GetAtoms()):
atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom))

ring_list = CompoundKit.get_ring_size(mol)
for i, atom in enumerate(mol.GetAtoms()):
atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0])
atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1])
atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2])
atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3])
atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4])
atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5])

return atom_features_dicts

@staticmethod
def check_partial_charge(atom):
"""tbd"""
pc = atom.GetDoubleProp('_GasteigerCharge')
if pc != pc:
# unsupported atom, replace nan with 0
pc = 0
if pc == float('inf'):
            # partial charges normally stay below ~4; cap at 10 when inf is returned
pc = 10
return pc


class Compound3DKit(object):
"""the 3Dkit of Compound"""

@staticmethod
def get_atom_poses(mol, conf):
"""tbd"""
atom_poses = []
for i, atom in enumerate(mol.GetAtoms()):
if atom.GetAtomicNum() == 0:
return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms())
pos = conf.GetAtomPosition(i)
atom_poses.append([pos.x, pos.y, pos.z])
return atom_poses

@staticmethod
def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
"""the atoms of mol will be changed in some cases."""
conf = mol.GetConformer()
atom_poses = Compound3DKit.get_atom_poses(mol, conf)
return mol,atom_poses
# try:
# new_mol = Chem.AddHs(mol)
# res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
# ### MMFF generates multiple conformations
# res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
# new_mol = Chem.RemoveHs(new_mol)
# index = np.argmin([x[1] for x in res])
# energy = res[index][1]
# conf = new_mol.GetConformer(id=int(index))
# except:
# new_mol = mol
# AllChem.Compute2DCoords(new_mol)
# energy = 0
# conf = new_mol.GetConformer()
#
# atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
# if return_energy:
# return new_mol, atom_poses, energy
# else:
# return new_mol, atom_poses

@staticmethod
def get_2d_atom_poses(mol):
"""get 2d atom poses"""
AllChem.Compute2DCoords(mol)
conf = mol.GetConformer()
atom_poses = Compound3DKit.get_atom_poses(mol, conf)
return atom_poses

@staticmethod
def get_bond_lengths(edges, atom_poses):
"""get bond lengths"""
bond_lengths = []
for src_node_i, tar_node_j in edges:
bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i]))
bond_lengths = np.array(bond_lengths, 'float32')
return bond_lengths

@staticmethod
def get_superedge_angles(edges, atom_poses, dir_type='HT'):
"""get superedge angles"""

def _get_vec(atom_poses, edge):
return atom_poses[edge[1]] - atom_poses[edge[0]]

def _get_angle(vec1, vec2):
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
if norm1 == 0 or norm2 == 0:
return 0
vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors
vec2 = vec2 / (norm2 + 1e-5)
angle = np.arccos(np.dot(vec1, vec2))
return angle

E = len(edges)
edge_indices = np.arange(E)
super_edges = []
bond_angles = []
bond_angle_dirs = []
for tar_edge_i in range(E):
tar_edge = edges[tar_edge_i]
if dir_type == 'HT':
src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
elif dir_type == 'HH':
src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
else:
raise ValueError(dir_type)
for src_edge_i in src_edge_indices:
if src_edge_i == tar_edge_i:
continue
src_edge = edges[src_edge_i]
src_vec = _get_vec(atom_poses, src_edge)
tar_vec = _get_vec(atom_poses, tar_edge)
super_edges.append([src_edge_i, tar_edge_i])
angle = _get_angle(src_vec, tar_vec)
bond_angles.append(angle)
bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T

if len(super_edges) == 0:
super_edges = np.zeros([0, 2], 'int64')
bond_angles = np.zeros([0, ], 'float32')
else:
super_edges = np.array(super_edges, 'int64')
bond_angles = np.array(bond_angles, 'float32')
return super_edges, bond_angles, bond_angle_dirs


def new_smiles_to_graph_data(smiles, **kwargs):
"""
Convert smiles to graph data.
"""
mol = AllChem.MolFromSmiles(smiles)
if mol is None:
return None
data = new_mol_to_graph_data(mol)
return data


def new_mol_to_graph_data(mol):
"""
mol_to_graph_data
Args:
atom_features: Atom features.
edge_features: Edge features.
morgan_fingerprint: Morgan fingerprint.
functional_groups: Functional groups.
"""
if len(mol.GetAtoms()) == 0:
return None

atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
bond_id_names = list(CompoundKit.bond_vocab_dict.keys())

data = {}

### atom features
data = {name: [] for name in atom_id_names}

raw_atom_feat_dicts = CompoundKit.get_atom_names(mol)
for atom_feat in raw_atom_feat_dicts:
for name in atom_id_names:
data[name].append(atom_feat[name])

### bond and bond features
for name in bond_id_names:
data[name] = []
data['edges'] = []

for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
# i->j and j->i
data['edges'] += [(i, j), (j, i)]
for name in bond_id_names:
bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
data[name] += [bond_feature_id] * 2

#### self loop
N = len(data[atom_id_names[0]])
for i in range(N):
data['edges'] += [(i, i)]
for name in bond_id_names:
bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1
data[name] += [bond_feature_id] * N

### make ndarray and check length
for name in list(CompoundKit.atom_vocab_dict.keys()):
data[name] = np.array(data[name], 'int64')
for name in CompoundKit.atom_float_names:
data[name] = np.array(data[name], 'float32')
for name in bond_id_names:
data[name] = np.array(data[name], 'int64')
data['edges'] = np.array(data['edges'], 'int64')

### morgan fingerprint
data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
# data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
return data


def mol_to_graph_data(mol):
"""
mol_to_graph_data
Args:
atom_features: Atom features.
edge_features: Edge features.
morgan_fingerprint: Morgan fingerprint.
functional_groups: Functional groups.
"""
if len(mol.GetAtoms()) == 0:
return None

atom_id_names = [
"atomic_num", "chiral_tag", "degree", "explicit_valence",
"formal_charge", "hybridization", "implicit_valence",
"is_aromatic", "total_numHs",
]
bond_id_names = [
"bond_dir", "bond_type", "is_in_ring",
]

data = {}
for name in atom_id_names:
data[name] = []
data['mass'] = []
for name in bond_id_names:
data[name] = []
data['edges'] = []

### atom features
for i, atom in enumerate(mol.GetAtoms()):
if atom.GetAtomicNum() == 0:
return None
for name in atom_id_names:
data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV
data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)

### bond features
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
# i->j and j->i
data['edges'] += [(i, j), (j, i)]
for name in bond_id_names:
bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV
data[name] += [bond_feature_id] * 2

### self loop (+2)
N = len(data[atom_id_names[0]])
for i in range(N):
data['edges'] += [(i, i)]
for name in bond_id_names:
bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop
data[name] += [bond_feature_id] * N

### check whether edge exists
if len(data['edges']) == 0: # mol has no bonds
for name in bond_id_names:
data[name] = np.zeros((0,), dtype="int64")
data['edges'] = np.zeros((0, 2), dtype="int64")

### make ndarray and check length
for name in atom_id_names:
data[name] = np.array(data[name], 'int64')
data['mass'] = np.array(data['mass'], 'float32')
for name in bond_id_names:
data[name] = np.array(data[name], 'int64')
data['edges'] = np.array(data['edges'], 'int64')

### morgan fingerprint
data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
# data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
return data


def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
"""
mol: rdkit molecule
    dir_type: direction type for the bond-angle graph
"""
if len(mol.GetAtoms()) == 0:
return None

data = mol_to_graph_data(mol)

data['atom_pos'] = np.array(atom_poses, 'float32')
data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
BondAngleGraph_edges, bond_angles, bond_angle_dirs = \
Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'])
data['BondAngleGraph_edges'] = BondAngleGraph_edges
data['bond_angle'] = np.array(bond_angles, 'float32')
return data


def mol_to_geognn_graph_data_MMFF3d(mol):
"""tbd"""
if len(mol.GetAtoms()) <= 400:
mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
else:
atom_poses = Compound3DKit.get_2d_atom_poses(mol)
return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')


def mol_to_geognn_graph_data_raw3d(mol):
"""tbd"""
atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer())
return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')

def obtain_3D_mol(smiles,name):
mol = AllChem.MolFromSmiles(smiles)
new_mol = Chem.AddHs(mol)
res = AllChem.EmbedMultipleConfs(new_mol)
### MMFF generates multiple conformations
res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
new_mol = Chem.RemoveHs(new_mol)
Chem.MolToMolFile(new_mol, name+'.mol')
return new_mol
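
# A minimal usage sketch for the helpers above (the SMILES below is an arbitrary
# example molecule, not taken from the dataset):
example_smiles = 'CC(=O)OC1=CC=CC=C1C(=O)O'
example_graph = new_smiles_to_graph_data(example_smiles)
print(example_graph['edges'].shape)       # (2 * num_bonds + num_atoms, 2), including self-loop edges
print(example_graph['morgan_fp'].shape)   # (200,) Morgan fingerprint bits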

Parameter settings

[104]
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
warnings.filterwarnings('ignore')
Use_column_info=False
Use_geometry_enhanced=True #default:True
atom_id_names = [
"atomic_num", "chiral_tag", "degree", "explicit_valence",
"formal_charge", "hybridization", "implicit_valence",
"is_aromatic", "total_numHs",
]
bond_id_names = [
"bond_dir", "bond_type", "is_in_ring"]
condition_name=['silica_surface','replace_basis']
condition_float_name=['eluent','grain_radian']
if Use_geometry_enhanced==True:
bond_float_names = ["bond_length",'prop']

if Use_geometry_enhanced==False:
bond_float_names=['prop']

bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS']

column_specify={'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4],
'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9],
'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14],
'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19],
'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24]}
column_smile=['O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4']
column_name=['ADH','ODH','IC','IA','OJH','ASH','IC3','IE','ID','OD3', 'IB','AD','AD3',
'IF','OD','AS','OJ3','IG','AZ','IAH','OJ','ICH','OZ3','IF3','IAU']
full_atom_feature_dims = get_atom_feature_dims(atom_id_names)
full_bond_feature_dims = get_bond_feature_dims(bond_id_names)


if Use_column_info==True:
bond_id_names.extend(['coated', 'immobilized'])
bond_float_names.extend(['diameter'])
if Use_geometry_enhanced==True:
bond_angle_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
else:
bond_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
full_bond_feature_dims.extend([2,2])

calc = Calculator(descriptors, ignore_3D=False)
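# `calc` computes the full Mordred descriptor set; mord(), defined in a later cell, uses it to featurize whole molecules.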

Define the graph neural network

[105]
class AtomEncoder(torch.nn.Module):
"""该类用于对原子属性做嵌入。
记`N`为原子属性的维度,则原子属性表示为`[x1, x2, ..., xi, xN]`,其中任意的一维度`xi`都是类别型数据。full_atom_feature_dims[i]存储了原子属性`xi`的类别数量。
该类将任意的原子属性`[x1, x2, ..., xi, xN]`转换为原子的嵌入`x_embedding`(维度为emb_dim)。
"""

def __init__(self, emb_dim):
super(AtomEncoder, self).__init__()

self.atom_embedding_list = torch.nn.ModuleList()

for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim + 5, emb_dim)  # each attribute dimension gets its own embedding table
torch.nn.init.xavier_uniform_(emb.weight.data)
self.atom_embedding_list.append(emb)

def forward(self, x):
x_embedding = 0
for i in range(x.shape[1]):
x_embedding += self.atom_embedding_list[i](x[:, i])

return x_embedding

class BondEncoder(torch.nn.Module):

def __init__(self, emb_dim):
super(BondEncoder, self).__init__()

self.bond_embedding_list = torch.nn.ModuleList()

for i, dim in enumerate(full_bond_feature_dims):
emb = torch.nn.Embedding(dim + 5, emb_dim)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.bond_embedding_list.append(emb)

def forward(self, edge_attr):
bond_embedding = 0
for i in range(edge_attr.shape[1]):
bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])

return bond_embedding

class RBF(torch.nn.Module):
"""
Radial Basis Function
"""

def __init__(self, centers, gamma, dtype='float32'):
super(RBF, self).__init__()
self.centers = centers.reshape([1, -1])
self.gamma = gamma

def forward(self, x):
"""
Args:
x(tensor): (-1, 1).
Returns:
y(tensor): (-1, n_centers)
"""
x = x.reshape([-1, 1])
return torch.exp(-self.gamma * torch.square(x - self.centers))

class BondFloatRBF(torch.nn.Module):
"""
Bond Float Encoder using Radial Basis Functions
"""

def __init__(self, bond_float_names, embed_dim, rbf_params=None):
super(BondFloatRBF, self).__init__()
self.bond_float_names = bond_float_names

if rbf_params is None:
self.rbf_params = {
'bond_length': (nn.Parameter(torch.arange(0, 2, 0.1)), nn.Parameter(torch.Tensor([10.0]))),
# (centers, gamma)
'prop': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
# 'TPSA':(nn.Parameter(torch.arange(0, 100, 5).to(torch.float32)), nn.Parameter(torch.Tensor([5.0]))),
# 'RASA': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
# 'RPSA': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0])))
}
else:
self.rbf_params = rbf_params

self.linear_list = torch.nn.ModuleList()
self.rbf_list = torch.nn.ModuleList()
for name in self.bond_float_names:
centers, gamma = self.rbf_params[name]
rbf = RBF(centers.to(device), gamma.to(device))
self.rbf_list.append(rbf)
linear = torch.nn.Linear(len(centers), embed_dim).to(device)
self.linear_list.append(linear)

def forward(self, bond_float_features):
"""
Args:
            bond_float_features(tensor): bond float features.
"""
out_embed = 0
for i, name in enumerate(self.bond_float_names):
x = bond_float_features[:, i].reshape(-1, 1)
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
return out_embed

class BondAngleFloatRBF(torch.nn.Module):
"""
Bond Angle Float Encoder using Radial Basis Functions
"""

def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
super(BondAngleFloatRBF, self).__init__()
self.bond_angle_float_names = bond_angle_float_names

if rbf_params is None:
self.rbf_params = {
'bond_angle': (nn.Parameter(torch.arange(0, torch.pi, 0.1)), nn.Parameter(torch.Tensor([10.0]))),
# (centers, gamma)
}
else:
self.rbf_params = rbf_params

self.linear_list = torch.nn.ModuleList()
self.rbf_list = torch.nn.ModuleList()
for name in self.bond_angle_float_names:
if name == 'bond_angle':
centers, gamma = self.rbf_params[name]
rbf = RBF(centers.to(device), gamma.to(device))
self.rbf_list.append(rbf)
linear = nn.Linear(len(centers), embed_dim)
self.linear_list.append(linear)
else:
linear = nn.Linear(len(self.bond_angle_float_names) - 1, embed_dim)
self.linear_list.append(linear)
break

def forward(self, bond_angle_float_features):
"""
Args:
            bond_angle_float_features(tensor): bond angle float features.
"""
out_embed = 0
for i, name in enumerate(self.bond_angle_float_names):
if name == 'bond_angle':
x = bond_angle_float_features[:, i].reshape(-1, 1)
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
else:
x = bond_angle_float_features[:, 1:]
out_embed += self.linear_list[i](x)
break
return out_embed

class ConditionEmbeding(torch.nn.Module):
"""
Not used in single_column prediction
"""

def __init__(self, condition_names,condition_float_names, embed_dim, rbf_params=None):
super(ConditionEmbeding, self).__init__()
self.condition_names = condition_names
self.condition_float_names=condition_float_names

if rbf_params is None:
self.rbf_params = {
'eluent': (nn.Parameter(torch.arange(0,1,0.1)), nn.Parameter(torch.Tensor([10.0]))),
'grain_radian': (nn.Parameter(torch.arange(0,10,0.1)), nn.Parameter(torch.Tensor([10.0])))# (centers, gamma)
}
else:
self.rbf_params = rbf_params

self.linear_list = torch.nn.ModuleList()
self.rbf_list = torch.nn.ModuleList()
self.embedding_list=torch.nn.ModuleList()
for name in self.condition_float_names:
centers, gamma = self.rbf_params[name]
rbf = RBF(centers.to(device), gamma.to(device))
self.rbf_list.append(rbf)
linear = nn.Linear(len(centers), embed_dim).to(device)
self.linear_list.append(linear)
for name in self.condition_names:
if name=='silica_surface':
emb = torch.nn.Embedding(2 + 5, embed_dim).to(device)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.embedding_list.append(emb)
elif name=='replace_basis':
emb = torch.nn.Embedding(6 + 5, embed_dim).to(device)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.embedding_list.append(emb)

def forward(self, condition):
"""
Args:
            condition(tensor): chromatographic condition features (categorical and float columns interleaved).
"""
out_embed = 0
for i, name in enumerate(self.condition_float_names):
x = condition[:,2*i+1]
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
for i, name in enumerate(self.condition_names):
x = self.embedding_list[i](condition[:,2*i].to(torch.int64))
out_embed += x
return out_embed

class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''

super(GINConv, self).__init__(aggr="add")

self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
nn.Linear(emb_dim, emb_dim))
self.eps = nn.Parameter(torch.Tensor([0]))

def forward(self, x, edge_index, edge_attr):
edge_embedding = edge_attr
out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
return out

def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)

def update(self, aggr_out):
return aggr_out

class GINNodeEmbedding(torch.nn.Module):
"""
Node embedding
"""

def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
"""GIN Node Embedding Module
"""

super(GINNodeEmbedding, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
# add residual connection or not
self.residual = residual

if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")

self.atom_encoder = AtomEncoder(emb_dim)
self.bond_encoder=BondEncoder(emb_dim)
self.bond_float_encoder=BondFloatRBF(bond_float_names,emb_dim)
self.bond_angle_encoder=BondAngleFloatRBF(bond_angle_float_names,emb_dim)
self.condition_encoder=ConditionEmbeding(condition_name,condition_float_name,emb_dim) #Not used in single_column prediction
# List of GNNs
self.convs = torch.nn.ModuleList()
self.convs_bond_angle=torch.nn.ModuleList()
self.convs_bond_float=torch.nn.ModuleList()
self.convs_bond_embeding=torch.nn.ModuleList()
self.convs_angle_float=torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
self.batch_norms_ba = torch.nn.ModuleList()
self.convs_condition=torch.nn.ModuleList()
for layer in range(num_layers):
self.convs.append(GINConv(emb_dim))
self.convs_bond_angle.append(GINConv(emb_dim))
self.convs_bond_embeding.append(BondEncoder(emb_dim))
self.convs_bond_float.append(BondFloatRBF(bond_float_names,emb_dim))
self.convs_angle_float.append(BondAngleFloatRBF(bond_angle_float_names,emb_dim))
self.convs_condition.append(ConditionEmbeding(condition_name,condition_float_name,emb_dim)) #Not used in single_column prediction
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
self.batch_norms_ba.append(torch.nn.BatchNorm1d(emb_dim))

def forward(self, batched_atom_bond,batched_bond_angle):
x, edge_index, edge_attr = batched_atom_bond.x, batched_atom_bond.edge_index, batched_atom_bond.edge_attr
edge_index_ba,edge_attr_ba= batched_bond_angle.edge_index, batched_bond_angle.edge_attr
# computing input node embedding
        h_list = [self.atom_encoder(x)]  # first convert the categorical atom attributes into atom embeddings
h_list_ba=[self.bond_float_encoder(edge_attr[:,3:edge_attr.shape[1]+1].to(torch.float32))+self.bond_encoder(edge_attr[:,0:3].to(torch.int64))]
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, h_list_ba[layer])
cur_h_ba=self.convs_bond_embeding[layer](edge_attr[:,0:3].to(torch.int64))+self.convs_bond_float[layer](edge_attr[:,3:edge_attr.shape[1]+1].to(torch.float32))
cur_angle_hidden=self.convs_angle_float[layer](edge_attr_ba)
h_ba=self.convs_bond_angle[layer](cur_h_ba, edge_index_ba, cur_angle_hidden)

if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
h_ba = F.dropout(h_ba, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
h_ba = F.dropout(F.relu(h_ba), self.drop_ratio, training=self.training)
if self.residual:
h += h_list[layer]
h_ba+=h_list_ba[layer]
h_list.append(h)
h_list_ba.append(h_ba)


# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
edge_representation = h_list_ba[-1]
elif self.JK == "sum":
node_representation = 0
edge_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
edge_representation += h_list_ba[layer]

return node_representation,edge_representation

class GINGraphPooling(nn.Module):

def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="attention",
descriptor_dim=1781):
"""GIN Graph Pooling Module
Args:
num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (dimension of graph representation).
num_layers (int, optional): number of GINConv layers. Defaults to 5.
emb_dim (int, optional): dimension of node embedding. Defaults to 300.
residual (bool, optional): adding residual connection or not. Defaults to False.
drop_ratio (float, optional): dropout rate. Defaults to 0.
JK (str, optional): Defaults to "last".
            graph_pooling (str, optional): pooling method over node embeddings. Defaults to "attention".

Out:
graph representation
"""
super(GINGraphPooling, self).__init__()

self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
self.descriptor_dim=descriptor_dim
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")

self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

# Pooling function to generate whole-graph embeddings
if graph_pooling == "sum":
self.pool = global_add_pool
elif graph_pooling == "mean":
self.pool = global_mean_pool
elif graph_pooling == "max":
self.pool = global_max_pool
elif graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn=nn.Sequential(
nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))

elif graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps=2)
else:
raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            # Set2Set returns a 2 * emb_dim graph representation
            self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

self.NN_descriptor = nn.Sequential(nn.Linear(self.descriptor_dim, self.emb_dim),
nn.Sigmoid(),
nn.Linear(self.emb_dim, self.emb_dim))

self.sigmoid = nn.Sigmoid()

def forward(self, batched_atom_bond,batched_bond_angle):
h_node,h_node_ba= self.gnn_node(batched_atom_bond,batched_bond_angle)
h_graph = self.pool(h_node, batched_atom_bond.batch)
output = self.graph_pred_linear(h_graph)

if self.training:
return output,h_graph
else:
            # At inference time, the output is clamped to [0, 1e8] so predictions are non-negative
return torch.clamp(output, min=0, max=1e8),h_graph

def mord(mol, nBits=1826, errors_as_zeros=True):
try:
result = calc(mol)
desc_list = [r if not is_missing(r) else 0 for r in result]
np_arr = np.array(desc_list)
return np_arr
except:
return np.NaN if not errors_as_zeros else np.zeros((nBits,), dtype=np.float32)

def load_3D_mol():
dir = 'mol_save/'
for root, dirs, files in os.walk(dir):
file_names = files
    file_names.sort(key=lambda x: int(x[x.find('_') + 5:x.find(".")]))  # sort by the numeric index embedded in the file name
mol_save = []
for file_name in file_names:
mol_save.append(Chem.MolFromMolFile(dir + file_name))
return mol_save

def parse_args():
    parser = argparse.ArgumentParser(description='Graph data mining with GNN')
parser.add_argument('--task_name', type=str, default='GINGraphPooling',
help='task name')
parser.add_argument('--device', type=str, default='cpu',
                        help='device to use (default: cpu)')
parser.add_argument('--num_layers', type=int, default=5,
help='number of GNN message passing layers (default: 5)')
parser.add_argument('--graph_pooling', type=str, default='sum',
help='graph pooling strategy mean or sum (default: sum)')
parser.add_argument('--emb_dim', type=int, default=128,
                        help='dimensionality of hidden units in GNNs (default: 128)')
parser.add_argument('--drop_ratio', type=float, default=0.,
help='dropout ratio (default: 0.)')
parser.add_argument('--save_test', action='store_true')
parser.add_argument('--batch_size', type=int, default=256,
                        help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=1000,
                        help='number of epochs to train (default: 1000)')
parser.add_argument('--weight_decay', type=float, default=0.00001,
help='weight decay')
parser.add_argument('--early_stop', type=int, default=10,
help='early stop (default: 10)')
parser.add_argument('--num_workers', type=int, default=0,
help='number of workers (default: 0)')
parser.add_argument('--dataset_root', type=str, default="dataset",
help='dataset root')
args = parser.parse_args(args=[])

return args

def calc_dragon_type_desc(mol):
compound_mol = mol
compound_MolWt = Descriptors.ExactMolWt(compound_mol)
compound_TPSA = Chem.rdMolDescriptors.CalcTPSA(compound_mol)
    compound_nRotB = Descriptors.NumRotatableBonds(compound_mol)  # Number of rotatable bonds
compound_HBD = Descriptors.NumHDonors(compound_mol) # Number of H bond donors
compound_HBA = Descriptors.NumHAcceptors(compound_mol) # Number of H bond acceptors
compound_LogP = Descriptors.MolLogP(compound_mol) # LogP
return rdMolDescriptors.CalcAUTOCORR3D(mol) + rdMolDescriptors.CalcMORSE(mol) + \
rdMolDescriptors.CalcRDF(mol) + rdMolDescriptors.CalcWHIM(mol) + \
[compound_MolWt, compound_TPSA, compound_nRotB, compound_HBD, compound_HBA, compound_LogP]
def prepartion(args):
save_dir = os.path.join('saves', args.task_name)
args.save_dir = save_dir
os.makedirs(args.save_dir, exist_ok=True)
args.device = torch.device("cpu")
args.output_file = open(os.path.join(args.save_dir, 'output'), 'a')
print(args, file=args.output_file, flush=True)
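
# q_loss below is the pinball (quantile) loss: with e = y_true - y_pred, under-prediction costs q*e
# and over-prediction costs (1-q)*|e|, so minimizing it estimates the q-th conditional quantile of the target.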

def q_loss(q,y_true,y_pred):
e = (y_true-y_pred)
return torch.mean(torch.maximum(q*e, (q-1)*e))

Define the function that constructs the graph dataset

[106]
def Construct_dataset(dataset,data_index, T1, speed, eluent,column):
graph_atom_bond = []
graph_bond_angle = []
big_index = []
if column=='ODH':
all_descriptor=np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_ODH_morder.npy')
if column=='ADH':
all_descriptor = np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_ADH_charity_morder_0606.npy')
if column=='IC':
all_descriptor=np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_IC_charity_morder_0823.npy')
if column == 'IA':
all_descriptor = np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_IA_charity_morder_0823.npy')

for i in range(len(dataset)):
data = dataset[i]
atom_feature = []
bond_feature = []
for name in atom_id_names:
atom_feature.append(data[name])
for name in bond_id_names:
bond_feature.append(data[name])
atom_feature = torch.from_numpy(np.array(atom_feature).T).to(torch.int64)
bond_feature = torch.from_numpy(np.array(bond_feature).T).to(torch.int64)
bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
y = torch.Tensor([float(T1[i]) * float(speed[i])])
edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)
data_index_int=torch.from_numpy(np.array(data_index[i])).to(torch.int64)

prop=torch.ones([bond_feature.shape[0]])*eluent[i]

TPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820]/100
RASA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821]
RPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822]
MDEC=torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568]
MATS=torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457]

bond_feature=torch.cat([bond_feature,bond_float_feature.reshape(-1,1)],dim=1)
bond_feature = torch.cat([bond_feature, prop.reshape(-1, 1)], dim=1)
bond_angle_feature=bond_angle_feature.reshape(-1,1)
bond_angle_feature = torch.cat([bond_angle_feature.reshape(-1, 1), TPSA.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, RASA.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, RPSA.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, MDEC.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, MATS.reshape(-1, 1)], dim=1)

if y[0]>60:
big_index.append(i)
continue

data_atom_bond = Data(atom_feature, edge_index, bond_feature, y,data_index=data_index_int)
data_bond_angle= Data(edge_index=bond_index, edge_attr=bond_angle_feature)
graph_atom_bond.append(data_atom_bond)
graph_bond_angle.append(data_bond_angle)
return graph_atom_bond,graph_bond_angle,big_index

Set the control parameters; change Use_column to select which column model to train (options: ADH, ODH, IC, IA)

[107]
#============Parameter setting===============
test_mode='fixed' # fixed, random, or enantiomer (extract enantiomers)
Use_column='ODH' # chiral column to train on (ADH, ODH, IC, or IA)

Load the data

[108]
#-------------load data----------------
'''
The graph construction is prepared and saved beforehand to accelerate the process, using code such as:
for smile in smiles:
mol = obtain_3D_mol(smile, 'trail')
mol = Chem.MolFromMolFile(f"trail.mol")
all_descriptor.append(mord(mol))
dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))
'''

if Use_column=='ODH':
HPLC_ODH=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/ODH_charity_0616.csv')
HPLC_ODH=HPLC_ODH.drop(4231) #conformer error
all_smile_ODH = HPLC_ODH['SMILES'].values
T1_ODH = HPLC_ODH['RT'].values
Speed_ODH = HPLC_ODH['Speed'].values
Prop_ODH = HPLC_ODH['i-PrOH_proportion'].values
dataset_ODH=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_ODH.npy',allow_pickle=True).tolist()
index_ODH=HPLC_ODH['Unnamed: 0'].values
print("data_num:",len(dataset_ODH))

if Use_column=='ADH':
HPLC_ADH=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/ADH_charity_0606.csv')
all_smile_ADH = HPLC_ADH['SMILES'].values
T1_ADH = HPLC_ADH['RT'].values
Speed_ADH = HPLC_ADH['Speed'].values
Prop_ADH = HPLC_ADH['i-PrOH_proportion'].values
dataset_ADH=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_ADH_charity_0606.npy',allow_pickle=True).tolist()
index_ADH=HPLC_ADH['Unnamed: 0'].values
print("data_num:",len(dataset_ADH))

if Use_column=='IC':
HPLC_IC=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/IC_charity_0823.csv')
bad_IC_index=np.load(r'/bohr/HPLC-dataset-lfyc/v2/bad_IC.npy') #Some compounds that cannot get 3D conformer by RDKit
HPLC_IC=HPLC_IC.drop(bad_IC_index) #conformer error
all_smile_IC = HPLC_IC['SMILES'].values
T1_IC = HPLC_IC['RT'].values
Speed_IC = HPLC_IC['Speed'].values
Prop_IC = HPLC_IC['i-PrOH_proportion'].values
dataset_IC=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_IC_charity_0823.npy',allow_pickle=True).tolist()
index_IC=HPLC_IC['Unnamed: 0'].values
print("data_num:",len(dataset_IC))

if Use_column=='IA':
HPLC_IA=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/IA_charity_0823.csv')
bad_IA_index=np.load(r'/bohr/HPLC-dataset-lfyc/v2/bad_IA.npy')
HPLC_IA=HPLC_IA.drop(bad_IA_index) #conformer error
all_smile_IA = HPLC_IA['SMILES'].values
T1_IA = HPLC_IA['RT'].values
Speed_IA = HPLC_IA['Speed'].values
Prop_IA = HPLC_IA['i-PrOH_proportion'].values
dataset_IA=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_IA_charity_0823.npy',allow_pickle=True).tolist()
index_IA=HPLC_IA['Unnamed: 0'].values
print("data_num:",len(dataset_IA))

print("Dataset has been loaded!")


data_num: 4971
Dataset has been loaded!

Build the training, validation, and test datasets

[109]
#===========Construct dataset==============

if Use_column == 'ADH':
    dataset_graph_atom_bond, dataset_graph_bond_angle, big_index = Construct_dataset(dataset_ADH, index_ADH, T1_ADH, Speed_ADH, Prop_ADH, column=Use_column)
if Use_column == 'IC':
    dataset_graph_atom_bond, dataset_graph_bond_angle, big_index = Construct_dataset(dataset_IC, index_IC, T1_IC, Speed_IC, Prop_IC, column=Use_column)
if Use_column == 'IA':
    dataset_graph_atom_bond, dataset_graph_bond_angle, big_index = Construct_dataset(dataset_IA, index_IA, T1_IA, Speed_IA, Prop_IA, column=Use_column)
if Use_column == 'ODH':
    dataset_graph_atom_bond, dataset_graph_bond_angle, big_index = Construct_dataset(dataset_ODH, index_ODH, T1_ODH, Speed_ODH, Prop_ODH, column=Use_column)


total_num = len(dataset_graph_atom_bond)
print('data num:', total_num)

train_ratio = 0.90
validate_ratio = 0.05
test_ratio = 0.05
args = parse_args()
prepartion(args)
nn_params = {
    'num_tasks': 3,                     # three outputs: 10% quantile, point prediction, 90% quantile
    'num_layers': args.num_layers,
    'emb_dim': args.emb_dim,
    'drop_ratio': args.drop_ratio,
    'graph_pooling': args.graph_pooling,
    'descriptor_dim': 1827
}


# fixed random seed per column so the split is reproducible
if Use_column == 'ODH':
    np.random.seed(388)
if Use_column == 'ADH':
    np.random.seed(505)
if Use_column == 'IC':
    np.random.seed(526)
if Use_column == 'IA':
    np.random.seed(388)


# automatic dataloading and splitting
data_array = np.arange(0, total_num, 1)
np.random.shuffle(data_array)
torch.random.manual_seed(525)

train_data_atom_bond = []
valid_data_atom_bond = []
test_data_atom_bond = []
train_data_bond_angle = []
valid_data_bond_angle = []
test_data_bond_angle = []

train_num = int(len(data_array) * train_ratio)
test_num = int(len(data_array) * test_ratio)
val_num = int(len(data_array) * validate_ratio)

print("training data num:", train_num, "\n validating data num:", val_num, "\ntesting data num:", test_num)
train_index = data_array[0:train_num]
valid_index = data_array[train_num:train_num + val_num]
if test_mode == 'fixed':
    test_index = data_array[total_num - test_num:]
if test_mode == 'random':
    test_index = data_array[train_num + val_num:train_num + val_num + test_num]


for i in test_index:
    test_data_atom_bond.append(dataset_graph_atom_bond[i])
    test_data_bond_angle.append(dataset_graph_bond_angle[i])
for i in valid_index:
    valid_data_atom_bond.append(dataset_graph_atom_bond[i])
    valid_data_bond_angle.append(dataset_graph_bond_angle[i])
for i in train_index:
    train_data_atom_bond.append(dataset_graph_atom_bond[i])
    train_data_bond_angle.append(dataset_graph_bond_angle[i])


train_loader_atom_bond = DataLoader(train_data_atom_bond, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
valid_loader_atom_bond = DataLoader(valid_data_atom_bond, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader_atom_bond = DataLoader(test_data_atom_bond, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
train_loader_bond_angle = DataLoader(train_data_bond_angle, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
valid_loader_bond_angle = DataLoader(valid_data_bond_angle, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader_bond_angle = DataLoader(test_data_bond_angle, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

device = args.device
criterion_fn = torch.nn.MSELoss()
model = GINGraphPooling(**nn_params).to(device)
num_params = sum(p.numel() for p in model.parameters())

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
print('===========Data Prepared================')
data num: 4942
training data num: 4447 
 validating data num: 247 
testing data num: 247
===========Data Prepared================
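Because the 'fixed' test indices are taken from the end of the shuffled array while the train and validation indices come from the front, it is worth confirming that the three splits do not overlap. A minimal sketch of that check (an addition to the notebook, reusing the index arrays from the cell above):

# Hypothetical overlap check between the three index sets.
train_set, valid_set, test_set = set(train_index), set(valid_index), set(test_index)
print("train/valid overlap:", len(train_set & valid_set),
      "| train/test overlap:", len(train_set & test_set),
      "| valid/test overlap:", len(valid_set & test_set))
print("covered samples:", len(train_set | valid_set | test_set), "of", total_num)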

Train the model

[129]
def train(model, device, loader_atom_bond, loader_bond_angle, optimizer, criterion_fn):
    model.train()
    loss_accum = 0

    for step, batch in enumerate(zip(loader_atom_bond, loader_bond_angle)):
        batch_atom_bond = batch[0]
        batch_bond_angle = batch[1]
        batch_atom_bond = batch_atom_bond.to(device)
        batch_bond_angle = batch_bond_angle.to(device)
        pred = model(batch_atom_bond, batch_bond_angle)[0]
        true = batch_atom_bond.y
        optimizer.zero_grad()
        # pinball losses for the 10% and 90% quantile outputs, MSE for the point prediction,
        # plus penalties that keep the three outputs ordered and discourage predictions below 2
        loss = q_loss(0.1, true, pred[:, 0]) + torch.mean((true - pred[:, 1]) ** 2) + q_loss(0.9, true, pred[:, 2]) \
               + torch.mean(torch.relu(pred[:, 0] - pred[:, 1])) + torch.mean(torch.relu(pred[:, 1] - pred[:, 2])) + torch.mean(torch.relu(2 - pred))
        loss.backward()
        optimizer.step()
        loss_accum += loss.detach().cpu().item()

    return loss_accum / (step + 1)

dir_name = os.path.join(os.getcwd(), 'model_save', Use_column)  # original mixed '\' and '/', which breaks on Linux
try:
    os.makedirs(dir_name)
except OSError:
    pass

for epoch in tqdm(range(1500)):
    train_mae = train(model, device, train_loader_atom_bond, train_loader_bond_angle, optimizer, criterion_fn)
    if (epoch + 1) % 100 == 0:
        valid_mae = eval(model, device, valid_loader_atom_bond, valid_loader_bond_angle)
        print(train_mae, valid_mae)
        torch.save(model.state_dict(), f'{dir_name}/model_save_{epoch + 1}.pth')
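q_loss is defined earlier in the notebook. For reference, a minimal pinball (quantile) loss consistent with how it is called above would look like the sketch below; this is an assumption about its implementation, not a copy of the original.

import torch

def pinball_loss(q, y_true, y_pred):
    # Pinball/quantile loss: under-predictions are weighted by q,
    # over-predictions by (1 - q), so minimizing it targets the q-th quantile.
    e = y_true - y_pred
    return torch.mean(torch.maximum(q * e, (q - 1) * e))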

Test the model

[116]
def test(model, device, loader_atom_bond, loader_bond_angle):
    model.eval()
    y_pred = []
    y_true = []
    y_pred_10 = []
    y_pred_90 = []
    with torch.no_grad():
        for _, batch in enumerate(zip(loader_atom_bond, loader_bond_angle)):
            batch_atom_bond = batch[0]
            batch_bond_angle = batch[1]
            batch_atom_bond = batch_atom_bond.to(device)
            batch_bond_angle = batch_bond_angle.to(device)
            pred = model(batch_atom_bond, batch_bond_angle)[0]
            y_true.append(batch_atom_bond.y.detach().cpu().reshape(-1,))
            y_pred.append(pred[:, 1].detach().cpu())
            y_pred_10.append(pred[:, 0].detach().cpu())
            y_pred_90.append(pred[:, 2].detach().cpu())
    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    y_pred_10 = torch.cat(y_pred_10, dim=0)
    y_pred_90 = torch.cat(y_pred_90, dim=0)

    R_square = 1 - (((y_true - y_pred) ** 2).sum() / ((y_true - y_true.mean()) ** 2).sum())  # note: original used y_pred.mean() in the denominator
    test_mae = torch.mean(torch.abs(y_true - y_pred))  # mean absolute error (the original line computed MSE here)
    print(R_square)
    return y_pred, y_true, R_square, test_mae, y_pred_10, y_pred_90

# for k, v in torch.load(f'/bohr/HPLC-dataset-lfyc/v2/model_ODH_388/model_save_1500.pth', map_location=torch.device('cpu')).items():
#     print(k)

# load the pre-trained checkpoint that matches the selected column
checkpoint_paths = {
    'ODH': '/bohr/HPLC-dataset-lfyc/v2/model_ODH_388/model_save_1500.pth',
    'ADH': '/bohr/HPLC-dataset-lfyc/v2/model_ADH_505/model_save_1500.pth',
    'IC': '/bohr/HPLC-dataset-lfyc/v2/model_IC_526/model_save_1500.pth',
    'IA': '/bohr/HPLC-dataset-lfyc/v2/model_IA_388/model_save_1500.pth',
}
checkpoint = torch.load(checkpoint_paths[Use_column], map_location=torch.device('cpu'))
new_pth = model.state_dict()   # state dict of the model that will receive the parameters
pretrained_dict = {}           # parameters shared between the checkpoint and the model
for k, v in checkpoint.items():
    for kk in new_pth.keys():
        if kk in k:
            pretrained_dict[kk] = v
            break
new_pth.update(pretrained_dict)
model.load_state_dict(new_pth)

y_pred, y_true, R_square, test_mae, y_pred_10, y_pred_90 = test(model, device, test_loader_atom_bond, test_loader_bond_angle)
y_pred = y_pred.cpu().data.numpy()
y_true = y_true.cpu().data.numpy()
y_pred_10 = y_pred_10.cpu().data.numpy()
y_pred_90 = y_pred_90.cpu().data.numpy()
print('relative_error', np.sqrt(np.sum((y_true - y_pred) ** 2) / np.sum(y_true ** 2)))
print('MAE', np.mean(np.abs(y_true - y_pred) / y_true))  # note: this quantity is the mean relative error
print('RMSE', np.sqrt(np.mean((y_true - y_pred) ** 2)))
R_square = 1 - (((y_true - y_pred) ** 2).sum() / ((y_true - y_true.mean()) ** 2).sum())
print(R_square)
plt.figure(1, figsize=(2.5, 2.5), dpi=300)
plt.style.use('ggplot')
plt.scatter(y_true, y_pred, c='#8983BF', s=15, alpha=0.4)
plt.plot(np.arange(0, 60), np.arange(0, 60), linewidth=1.5, linestyle='--', color='black')  # y = x reference line
plt.yticks(np.arange(0, 66, 10), np.arange(0, 66, 10), size=8)
plt.xticks(np.arange(0, 66, 10), np.arange(0, 66, 10), size=8)
plt.xlabel('Observed data', size=8)
plt.ylabel('Predicted data', size=8)
plt.show()
tensor(0.7777)
relative_error 0.24293557
MAE 0.20141043
RMSE 3.918548
0.7776966840028763
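Since the model also returns 10% and 90% quantile predictions, a natural follow-up is to check how often the observed retention time falls inside that interval. A minimal sketch of such a coverage check (an addition, reusing y_true, y_pred_10 and y_pred_90 from the cell above):

# Hypothetical coverage check for the [10%, 90%] prediction interval (nominally about 80%).
inside = (y_true >= y_pred_10) & (y_true <= y_pred_90)
print('empirical coverage of the 10%-90% interval:', inside.mean())
print('mean interval width:', (y_pred_90 - y_pred_10).mean())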

Determine whether the two enantiomers can be separated

[128]
def cal_prob(prediction):
    '''
    Calculate the separation probability Sp.
    Input:  prediction = [pred_1, pred_2], each a (1, 3) array of [10%, 50%, 90%] quantile predictions
    Output: Sp (1 if the two intervals do not overlap, otherwise 1 - overlap/union)
    '''
    a = prediction[0][0]
    b = prediction[1][0]
    if a[2] < b[0]:
        return 1
    elif a[0] > b[2]:
        return 1
    else:
        length = min(a[2], b[2]) - max(a[0], b[0])   # overlap of the two intervals
        all = max(a[2], b[2]) - min(a[0], b[0])      # union of the two intervals
        return 1 - length / (all)

transfer_target = 'ODH'
draw_picture = True
smiles = ['CCCCC[C@H](F)C(=O)c1nccn1c2ccccc2', 'CCCCC[C@@H](F)C(=O)c1nccn1c2ccccc2']
y_pred = []

speed = [1.0, 1.0]      # flow rate for each enantiomer
eluent = [0.02, 0.02]   # i-PrOH proportion for each enantiomer

mols = []
all_descriptor = []
dataset = []
for smile in smiles:
    mol = Chem.MolFromSmiles(smile)
    mols.append(mol)

from rdkit.Chem import Draw
if draw_picture:
    index = 0
    for mol in mols:
        smiles_pic = Draw.MolToImage(mol, size=(200, 100), kekulize=True)
        plt.imshow(smiles_pic)
        plt.axis('off')
        index += 1
        plt.show()
    plt.clf()

for smile in smiles:
    mol = obtain_3D_mol(smile, 'trail')
    mol = Chem.MolFromMolFile(f"trail.mol")
    all_descriptor.append(mord(mol))
    dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))


for i in range(0, len(dataset)):
    data = dataset[i]
    atom_feature = []
    bond_feature = []
    for name in atom_id_names:
        atom_feature.append(data[name])
    for name in bond_id_names:
        bond_feature.append(data[name])
    atom_feature = torch.from_numpy(np.array(atom_feature).T).to(torch.int64)
    bond_feature = torch.from_numpy(np.array(bond_feature).T).to(torch.int64)
    bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
    bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
    y = torch.Tensor([float(speed[i])])
    edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
    bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)

    prop = torch.ones([bond_feature.shape[0]]) * eluent[i]

    # molecular descriptors broadcast onto every bond angle (same descriptor indices as in Construct_dataset)
    TPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][820] / 100
    RASA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][821]
    RPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][822]
    MDEC = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][1568]
    MATS = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][457]

    bond_feature = torch.cat([bond_feature, bond_float_feature.reshape(-1, 1)], dim=1)
    bond_feature = torch.cat([bond_feature, prop.reshape(-1, 1)], dim=1)

    bond_angle_feature = torch.cat([bond_angle_feature.reshape(-1, 1), TPSA.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, RASA.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, RPSA.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, MDEC.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, MATS.reshape(-1, 1)], dim=1)

    data_atom_bond = Data(atom_feature, edge_index, bond_feature, y)
    data_bond_angle = Data(edge_index=bond_index, edge_attr=bond_angle_feature)

    # load the pre-trained checkpoint that matches the selected column (same lookup as in the test cell)
    checkpoint_paths = {
        'ODH': '/bohr/HPLC-dataset-lfyc/v2/model_ODH_388/model_save_1500.pth',
        'ADH': '/bohr/HPLC-dataset-lfyc/v2/model_ADH_505/model_save_1500.pth',
        'IC': '/bohr/HPLC-dataset-lfyc/v2/model_IC_526/model_save_1500.pth',
        'IA': '/bohr/HPLC-dataset-lfyc/v2/model_IA_388/model_save_1500.pth',
    }
    checkpoint = torch.load(checkpoint_paths[Use_column], map_location=torch.device('cpu'))
    new_pth = model.state_dict()   # state dict of the model that will receive the parameters
    pretrained_dict = {}           # parameters shared between the checkpoint and the model
    for k, v in checkpoint.items():
        for kk in new_pth.keys():
            if kk in k:
                pretrained_dict[kk] = v
                break
    new_pth.update(pretrained_dict)
    model.load_state_dict(new_pth)
    model.eval()

    pred, h_graph = model(data_atom_bond.to(device), data_bond_angle.to(device))
    y_pred.append(pred.detach().cpu().data.numpy() / speed[i])

print(y_pred)
print("separation probability:", cal_prob(y_pred))
[array([[10.237834, 10.40861 , 10.870056]], dtype=float32), array([[9.105642, 9.192403, 9.570723]], dtype=float32)]
separation probability: 1
<Figure size 640x480 with 0 Axes>
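To make the separation-probability rule concrete, here is a small standalone example (toy numbers, not model output) of how cal_prob behaves for non-overlapping and overlapping prediction intervals:

import numpy as np

# Toy [10%, 50%, 90%] retention-time predictions for two enantiomers.
no_overlap = [np.array([[8.0, 8.5, 9.0]]), np.array([[10.0, 10.5, 11.0]])]
overlap = [np.array([[8.0, 9.0, 10.0]]), np.array([[9.5, 10.5, 11.5]])]

print(cal_prob(no_overlap))  # 1 -> intervals are disjoint, separation expected
print(cal_prob(overlap))     # 1 - 0.5/3.5 ~ 0.857 -> partial overlap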