HPLC retention time prediction
bohr6ef000
Recommended image: mfy:02
Recommended machine type: c2_m4_cpu
Dataset
HPLC_dataset(v2)
Install the required packages
[101]
!pip install mordred
!pip install torch_geometric
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Requirement already satisfied: mordred in /opt/conda/lib/python3.8/site-packages (1.2.0) Requirement already satisfied: networkx==2.* in /opt/conda/lib/python3.8/site-packages (from mordred) (2.8.8) Requirement already satisfied: six==1.* in /opt/conda/lib/python3.8/site-packages (from mordred) (1.16.0) Requirement already satisfied: numpy==1.* in /opt/conda/lib/python3.8/site-packages (from mordred) (1.22.4) WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Requirement already satisfied: torch_geometric in /opt/conda/lib/python3.8/site-packages (2.5.3) Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (4.64.1) Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (2.28.2) Requirement already satisfied: fsspec in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (2023.1.0) Requirement already satisfied: scipy in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (1.7.3) Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (1.22.4) Requirement already satisfied: aiohttp in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (3.8.4) Requirement already satisfied: pyparsing in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (3.0.9) Requirement already satisfied: psutil>=5.8.0 in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (5.9.0) Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (1.0.2) Requirement already satisfied: jinja2 in /opt/conda/lib/python3.8/site-packages (from torch_geometric) (3.1.2) Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (4.0.2) Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (22.1.0) Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (1.3.1) Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (1.3.3) Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (6.0.4) Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (1.8.2) Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->torch_geometric) (3.3.2) Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.8/site-packages (from jinja2->torch_geometric) (2.1.1) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->torch_geometric) (1.26.14) Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->torch_geometric) (3.4) Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->torch_geometric) (2022.12.7) Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.8/site-packages (from 
scikit-learn->torch_geometric) (1.2.0) Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->torch_geometric) (3.1.0) WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[112]
import torch
from torch_geometric.nn import MessagePassing
from rdkit.Chem import Descriptors
from torch_geometric.data import Data
import argparse
import warnings
from rdkit.Chem.Descriptors import rdMolDescriptors
import pandas as pd
import os
from mordred import Calculator, descriptors, is_missing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdchem
from torch_geometric.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import matplotlib.pyplot as plt
Define functions to featurize molecules and convert them into graphs (a short usage sketch follows the code cell)
[103]
device=torch.device("cpu")
DAY_LIGHT_FG_SMARTS_LIST = [
# C
"[CX4]",
"[$([CX2](=C)=C)]",
"[$([CX3]=[CX3])]",
"[$([CX2]#C)]",
# C & O
"[CX3]=[OX1]",
"[$([CX3]=[OX1]),$([CX3+]-[OX1-])]",
"[CX3](=[OX1])C",
"[OX1]=CN",
"[CX3](=[OX1])O",
"[CX3](=[OX1])[F,Cl,Br,I]",
"[CX3H1](=O)[#6]",
"[CX3](=[OX1])[OX2][CX3](=[OX1])",
"[NX3][CX3](=[OX1])[#6]",
"[NX3][CX3]=[NX3+]",
"[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]",
"[NX3][CX3](=[OX1])[OX2H0]",
"[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]",
"[CX3](=O)[O-]",
"[CX3](=[OX1])(O)O",
"[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]",
"C[OX2][CX3](=[OX1])[OX2]C",
"[CX3](=O)[OX2H1]",
"[CX3](=O)[OX1H0-,OX2H1]",
"[NX3][CX2]#[NX1]",
"[#6][CX3](=O)[OX2H0][#6]",
"[#6][CX3](=O)[#6]",
"[OD2]([#6])[#6]",
# H
"[H]",
"[!#1]",
"[H+]",
"[+H]",
"[!H]",
# N
"[NX3;H2,H1;!$(NC=O)]",
"[NX3][CX3]=[CX3]",
"[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
"[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]",
"[NX3][$(C=C),$(cc)]",
"[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]",
"[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]",
"[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]",
"[CH3X4]",
"[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]",
"[CH2X4][CX3](=[OX1])[NX3H2]",
"[CH2X4][CX3](=[OX1])[OH0-,OH]",
"[CH2X4][SX2H,SX1H0-]",
"[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]",
"[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]",
"[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\
[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1",
"[CHX4]([CH3X4])[CH2X4][CH3X4]",
"[CH2X4][CHX4]([CH3X4])[CH3X4]",
"[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]",
"[CH2X4][CH2X4][SX2][CH3X4]",
"[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1",
"[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]",
"[CH2X4][OX2H]",
"[NX3][CX3]=[SX1]",
"[CHX4]([CH3X4])[OX2H]",
"[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12",
"[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1",
"[CHX4]([CH3X4])[CH3X4]",
"N[CX4H2][CX3](=[OX1])[O,N]",
"N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]",
"[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]",
"[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]",
"[#7]",
"[NX2]=N",
"[NX2]=[NX2]",
"[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]",
"[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]",
"[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]",
"[NX3][NX3]",
"[NX3][NX2]=[*]",
"[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]",
"[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]",
"[NX3+]=[CX3]",
"[CX3](=[OX1])[NX3H][CX3](=[OX1])",
"[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])",
"[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])",
"[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]",
"[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]",
"[NX1]#[CX2]",
"[CX1-]#[NX2+]",
"[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
"[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
"[NX2]=[OX1]",
"[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]",
# O
"[OX2H]",
"[#6][OX2H]",
"[OX2H][CX3]=[OX1]",
"[OX2H]P",
"[OX2H][#6X3]=[#6]",
"[OX2H][cX3]:[c]",
"[OX2H][$(C=C),$(cc)]",
"[$([OH]-*=[!#6])]",
"[OX2,OX1-][OX2,OX1-]",
# P
"[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\
$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\
,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]",
"[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\
$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\
$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]",
# S
"[S-][CX3](=S)[#6]",
"[#6X3](=[SX1])([!N])[!N]",
"[SX2]",
"[#16X2H]",
"[#16!H0]",
"[#16X2H0]",
"[#16X2H0][!#16]",
"[#16X2H0][#16X2H0]",
"[#16X2H0][!#16].[#16X2H0][!#16]",
"[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]",
"[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]",
"[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]",
"[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]",
"[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]",
"[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]",
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]",
"[SX4](C)(C)(=O)=N",
"[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]",
"[$([#16X3]=[OX1]),$([#16X3+][OX1-])]",
"[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]",
"[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]",
"[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]",
"[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]",
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]",
"[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]",
"[#16X2][OX2H,OX1H0-]",
"[#16X2][OX2H0]",
# X
"[#6][F,Cl,Br,I]",
"[F,Cl,Br,I]",
"[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]",
]
def get_gasteiger_partial_charges(mol, n_iter=12):
"""
Calculates list of gasteiger partial charges for each atom in mol object.
Args:
mol: rdkit mol object.
n_iter(int): number of iterations. Default 12.
Returns:
list of computed partial charges for each atom.
"""
Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter,
throwOnParamFailure=True)
partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in
mol.GetAtoms()]
return partial_charges
def create_standardized_mol_id(smiles):
"""
Args:
smiles: smiles sequence.
Returns:
inchi.
"""
if check_smiles_validity(smiles):
# remove stereochemistry
smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
isomericSmiles=False)
mol = AllChem.MolFromSmiles(smiles)
if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
if '.' in smiles: # if multiple species, pick largest molecule
mol_species_list = split_rdkit_mol_obj(mol)
largest_mol = get_largest_mol(mol_species_list)
inchi = AllChem.MolToInchi(largest_mol)
else:
inchi = AllChem.MolToInchi(mol)
return inchi
else:
return
else:
return
def check_smiles_validity(smiles):
"""
Check whether a SMILES string can be converted to an RDKit mol object.
"""
try:
m = Chem.MolFromSmiles(smiles)
if m:
return True
else:
return False
except Exception as e:
return False
def split_rdkit_mol_obj(mol):
"""
Split rdkit mol object containing multiple species or one species into a
list of mol objects or a list containing a single object respectively.
Args:
mol: rdkit mol object.
"""
smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
smiles_list = smiles.split('.')
mol_species_list = []
for s in smiles_list:
if check_smiles_validity(s):
mol_species_list.append(AllChem.MolFromSmiles(s))
return mol_species_list
def get_largest_mol(mol_list):
"""
Given a list of rdkit mol objects, returns mol object containing the
largest num of atoms. If multiple containing largest num of atoms,
picks the first one.
Args:
mol_list(list): a list of rdkit mol object.
Returns:
the largest mol.
"""
num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
largest_mol_idx = num_atoms_list.index(max(num_atoms_list))
return mol_list[largest_mol_idx]
def rdchem_enum_to_list(values):
"""values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}
"""
return [values[i] for i in range(len(values))]
def safe_index(alist, elem):
"""
Return the index of elem in alist. If elem is not present, return the last index.
"""
try:
return alist.index(elem)
except ValueError:
return len(alist) - 1
def get_atom_feature_dims(list_acquired_feature_names):
""" tbd
"""
return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names]))
def get_bond_feature_dims(list_acquired_feature_names):
""" tbd
"""
list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names]))
# +1 for self loop edges
return [_l + 1 for _l in list_bond_feat_dim]
class CompoundKit(object):
"""
CompoundKit
"""
atom_vocab_dict = {
"atomic_num": list(range(1, 119)) + ['misc'],
"chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
"degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
"explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
"formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
"hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values),
"implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
"is_aromatic": [0, 1],
"total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'num_radical_e': [0, 1, 2, 3, 4, 'misc'],
'atom_is_in_ring': [0, 1],
'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
}
bond_vocab_dict = {
"bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
"bond_type": rdchem_enum_to_list(rdchem.BondType.values),
"is_in_ring": [0, 1],
'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values),
'is_conjugated': [0, 1],
}
# float features
atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
# bond_float_feats= ["bond_length", "bond_angle"] # optional
### functional groups
day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]
morgan_fp_N = 200
morgan2048_fp_N = 2048
maccs_fp_N = 167
period_table = Chem.GetPeriodicTable()
### atom
@staticmethod
def get_atom_value(atom, name):
"""get atom values"""
if name == 'atomic_num':
return atom.GetAtomicNum()
elif name == 'chiral_tag':
return atom.GetChiralTag()
elif name == 'degree':
return atom.GetDegree()
elif name == 'explicit_valence':
return atom.GetExplicitValence()
elif name == 'formal_charge':
return atom.GetFormalCharge()
elif name == 'hybridization':
return atom.GetHybridization()
elif name == 'implicit_valence':
return atom.GetImplicitValence()
elif name == 'is_aromatic':
return int(atom.GetIsAromatic())
elif name == 'mass':
return int(atom.GetMass())
elif name == 'total_numHs':
return atom.GetTotalNumHs()
elif name == 'num_radical_e':
return atom.GetNumRadicalElectrons()
elif name == 'atom_is_in_ring':
return int(atom.IsInRing())
elif name == 'valence_out_shell':
return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
else:
raise ValueError(name)
@staticmethod
def get_atom_feature_id(atom, name):
"""get atom features id"""
assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))
@staticmethod
def get_atom_feature_size(name):
"""get atom features size"""
assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
return len(CompoundKit.atom_vocab_dict[name])
### bond
@staticmethod
def get_bond_value(bond, name):
"""get bond values"""
if name == 'bond_dir':
return bond.GetBondDir()
elif name == 'bond_type':
return bond.GetBondType()
elif name == 'is_in_ring':
return int(bond.IsInRing())
elif name == 'is_conjugated':
return int(bond.GetIsConjugated())
elif name == 'bond_stereo':
return bond.GetStereo()
else:
raise ValueError(name)
@staticmethod
def get_bond_feature_id(bond, name):
"""get bond features id"""
assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))
@staticmethod
def get_bond_feature_size(name):
"""get bond features size"""
assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
return len(CompoundKit.bond_vocab_dict[name])
### fingerprint
@staticmethod
def get_morgan_fingerprint(mol, radius=2):
"""get morgan fingerprint"""
nBits = CompoundKit.morgan_fp_N
mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
return [int(b) for b in mfp.ToBitString()]
@staticmethod
def get_morgan2048_fingerprint(mol, radius=2):
"""get morgan2048 fingerprint"""
nBits = CompoundKit.morgan2048_fp_N
mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
return [int(b) for b in mfp.ToBitString()]
@staticmethod
def get_maccs_fingerprint(mol):
"""get maccs fingerprint"""
fp = AllChem.GetMACCSKeysFingerprint(mol)
return [int(b) for b in fp.ToBitString()]
### functional groups
@staticmethod
def get_daylight_functional_group_counts(mol):
"""get daylight functional group counts"""
fg_counts = []
for fg_mol in CompoundKit.day_light_fg_mo_list:
sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
fg_counts.append(len(sub_structs))
return fg_counts
@staticmethod
def get_ring_size(mol):
"""return (N,6) list"""
rings = mol.GetRingInfo()
rings_info = []
for r in rings.AtomRings():
rings_info.append(r)
ring_list = []
for atom in mol.GetAtoms():
atom_result = []
for ringsize in range(3, 9):
num_of_ring_at_ringsize = 0
for r in rings_info:
if len(r) == ringsize and atom.GetIdx() in r:
num_of_ring_at_ringsize += 1
if num_of_ring_at_ringsize > 8:
num_of_ring_at_ringsize = 9
atom_result.append(num_of_ring_at_ringsize)
ring_list.append(atom_result)
return ring_list
@staticmethod
def atom_to_feat_vector(atom):
""" tbd """
atom_names = {
"atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
"chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()),
"degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()),
"explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()),
"formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()),
"hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()),
"implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()),
"is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())),
"total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()),
'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], atom.GetNumRadicalElectrons()),
'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())),
'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'],
CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())),
'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()),
'partial_charge': CompoundKit.check_partial_charge(atom),
'mass': atom.GetMass(),
}
return atom_names
@staticmethod
def get_atom_names(mol):
"""get atom name list
TODO: to be removed in the future
"""
atom_features_dicts = []
Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
for i, atom in enumerate(mol.GetAtoms()):
atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom))
ring_list = CompoundKit.get_ring_size(mol)
for i, atom in enumerate(mol.GetAtoms()):
atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0])
atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1])
atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2])
atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3])
atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4])
atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index(
CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5])
return atom_features_dicts
@staticmethod
def check_partial_charge(atom):
"""tbd"""
pc = atom.GetDoubleProp('_GasteigerCharge')
if pc != pc:
# unsupported atom, replace nan with 0
pc = 0
if pc == float('inf'):
# partial charges rarely exceed ~4; cap at 10 if inf is obtained
pc = 10
return pc
class Compound3DKit(object):
"""the 3Dkit of Compound"""
@staticmethod
def get_atom_poses(mol, conf):
"""tbd"""
atom_poses = []
for i, atom in enumerate(mol.GetAtoms()):
if atom.GetAtomicNum() == 0:
return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms())
pos = conf.GetAtomPosition(i)
atom_poses.append([pos.x, pos.y, pos.z])
return atom_poses
@staticmethod
def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
"""the atoms of mol will be changed in some cases."""
conf = mol.GetConformer()
atom_poses = Compound3DKit.get_atom_poses(mol, conf)
return mol,atom_poses
# try:
# new_mol = Chem.AddHs(mol)
# res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
# ### MMFF generates multiple conformations
# res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
# new_mol = Chem.RemoveHs(new_mol)
# index = np.argmin([x[1] for x in res])
# energy = res[index][1]
# conf = new_mol.GetConformer(id=int(index))
# except:
# new_mol = mol
# AllChem.Compute2DCoords(new_mol)
# energy = 0
# conf = new_mol.GetConformer()
#
# atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
# if return_energy:
# return new_mol, atom_poses, energy
# else:
# return new_mol, atom_poses
@staticmethod
def get_2d_atom_poses(mol):
"""get 2d atom poses"""
AllChem.Compute2DCoords(mol)
conf = mol.GetConformer()
atom_poses = Compound3DKit.get_atom_poses(mol, conf)
return atom_poses
@staticmethod
def get_bond_lengths(edges, atom_poses):
"""get bond lengths"""
bond_lengths = []
for src_node_i, tar_node_j in edges:
bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i]))
bond_lengths = np.array(bond_lengths, 'float32')
return bond_lengths
@staticmethod
def get_superedge_angles(edges, atom_poses, dir_type='HT'):
"""get superedge angles"""
def _get_vec(atom_poses, edge):
return atom_poses[edge[1]] - atom_poses[edge[0]]
def _get_angle(vec1, vec2):
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
if norm1 == 0 or norm2 == 0:
return 0
vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors
vec2 = vec2 / (norm2 + 1e-5)
angle = np.arccos(np.dot(vec1, vec2))
return angle
E = len(edges)
edge_indices = np.arange(E)
super_edges = []
bond_angles = []
bond_angle_dirs = []
for tar_edge_i in range(E):
tar_edge = edges[tar_edge_i]
if dir_type == 'HT':
src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
elif dir_type == 'HH':
src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
else:
raise ValueError(dir_type)
for src_edge_i in src_edge_indices:
if src_edge_i == tar_edge_i:
continue
src_edge = edges[src_edge_i]
src_vec = _get_vec(atom_poses, src_edge)
tar_vec = _get_vec(atom_poses, tar_edge)
super_edges.append([src_edge_i, tar_edge_i])
angle = _get_angle(src_vec, tar_vec)
bond_angles.append(angle)
bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T
if len(super_edges) == 0:
super_edges = np.zeros([0, 2], 'int64')
bond_angles = np.zeros([0, ], 'float32')
else:
super_edges = np.array(super_edges, 'int64')
bond_angles = np.array(bond_angles, 'float32')
return super_edges, bond_angles, bond_angle_dirs
def new_smiles_to_graph_data(smiles, **kwargs):
"""
Convert smiles to graph data.
"""
mol = AllChem.MolFromSmiles(smiles)
if mol is None:
return None
data = new_mol_to_graph_data(mol)
return data
def new_mol_to_graph_data(mol):
"""
mol_to_graph_data
Args:
atom_features: Atom features.
edge_features: Edge features.
morgan_fingerprint: Morgan fingerprint.
functional_groups: Functional groups.
"""
if len(mol.GetAtoms()) == 0:
return None
atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
bond_id_names = list(CompoundKit.bond_vocab_dict.keys())
data = {}
### atom features
data = {name: [] for name in atom_id_names}
raw_atom_feat_dicts = CompoundKit.get_atom_names(mol)
for atom_feat in raw_atom_feat_dicts:
for name in atom_id_names:
data[name].append(atom_feat[name])
### bond and bond features
for name in bond_id_names:
data[name] = []
data['edges'] = []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
# i->j and j->i
data['edges'] += [(i, j), (j, i)]
for name in bond_id_names:
bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
data[name] += [bond_feature_id] * 2
#### self loop
N = len(data[atom_id_names[0]])
for i in range(N):
data['edges'] += [(i, i)]
for name in bond_id_names:
bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1
data[name] += [bond_feature_id] * N
### make ndarray and check length
for name in list(CompoundKit.atom_vocab_dict.keys()):
data[name] = np.array(data[name], 'int64')
for name in CompoundKit.atom_float_names:
data[name] = np.array(data[name], 'float32')
for name in bond_id_names:
data[name] = np.array(data[name], 'int64')
data['edges'] = np.array(data['edges'], 'int64')
### morgan fingerprint
data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
# data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
return data
def mol_to_graph_data(mol):
"""
mol_to_graph_data
Args:
atom_features: Atom features.
edge_features: Edge features.
morgan_fingerprint: Morgan fingerprint.
functional_groups: Functional groups.
"""
if len(mol.GetAtoms()) == 0:
return None
atom_id_names = [
"atomic_num", "chiral_tag", "degree", "explicit_valence",
"formal_charge", "hybridization", "implicit_valence",
"is_aromatic", "total_numHs",
]
bond_id_names = [
"bond_dir", "bond_type", "is_in_ring",
]
data = {}
for name in atom_id_names:
data[name] = []
data['mass'] = []
for name in bond_id_names:
data[name] = []
data['edges'] = []
### atom features
for i, atom in enumerate(mol.GetAtoms()):
if atom.GetAtomicNum() == 0:
return None
for name in atom_id_names:
data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV
data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)
### bond features
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
# i->j and j->i
data['edges'] += [(i, j), (j, i)]
for name in bond_id_names:
bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV
data[name] += [bond_feature_id] * 2
### self loop (+2)
N = len(data[atom_id_names[0]])
for i in range(N):
data['edges'] += [(i, i)]
for name in bond_id_names:
bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop
data[name] += [bond_feature_id] * N
### check whether edge exists
if len(data['edges']) == 0: # mol has no bonds
for name in bond_id_names:
data[name] = np.zeros((0,), dtype="int64")
data['edges'] = np.zeros((0, 2), dtype="int64")
### make ndarray and check length
for name in atom_id_names:
data[name] = np.array(data[name], 'int64')
data['mass'] = np.array(data['mass'], 'float32')
for name in bond_id_names:
data[name] = np.array(data[name], 'int64')
data['edges'] = np.array(data['edges'], 'int64')
### morgan fingerprint
data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
# data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
return data
def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
"""
mol: rdkit molecule
dir_type: direction type for the bond-angle graph
"""
if len(mol.GetAtoms()) == 0:
return None
data = mol_to_graph_data(mol)
data['atom_pos'] = np.array(atom_poses, 'float32')
data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
BondAngleGraph_edges, bond_angles, bond_angle_dirs = \
Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'])
data['BondAngleGraph_edges'] = BondAngleGraph_edges
data['bond_angle'] = np.array(bond_angles, 'float32')
return data
def mol_to_geognn_graph_data_MMFF3d(mol):
"""tbd"""
if len(mol.GetAtoms()) <= 400:
mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
else:
atom_poses = Compound3DKit.get_2d_atom_poses(mol)
return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
def mol_to_geognn_graph_data_raw3d(mol):
"""tbd"""
atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer())
return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
def obtain_3D_mol(smiles,name):
mol = AllChem.MolFromSmiles(smiles)
new_mol = Chem.AddHs(mol)
res = AllChem.EmbedMultipleConfs(new_mol)
### MMFF generates multiple conformations
res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
new_mol = Chem.RemoveHs(new_mol)
Chem.MolToMolFile(new_mol, name+'.mol')
return new_mol
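The cell above defines the whole featurization pipeline. As a quick, hypothetical usage sketch (not part of the original workflow; the SMILES and the 'demo' file prefix are made up for illustration), one molecule can be turned into GeoGNN-style graph data as follows, mirroring the precomputation described in the data-loading section below:
# Illustrative sketch only: embed a 3D conformer, reload it, and build the
# atom-bond / bond-angle graph features for a single molecule.
demo_smiles = 'CC(C)Oc1ccccc1'                         # hypothetical example molecule
demo_mol = obtain_3D_mol(demo_smiles, 'demo')          # MMFF-optimized conformer, written to demo.mol
demo_mol = Chem.MolFromMolFile('demo.mol')             # reload the molecule with its 3D coordinates
demo_graph = mol_to_geognn_graph_data_MMFF3d(demo_mol)
print(demo_graph['edges'].shape)       # two directed edges per bond plus one self loop per atom
print(demo_graph['bond_angle'].shape)  # one angle per pair of edges sharing an atom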
Parameter settings (a small Mordred descriptor example follows the cell)
[104]
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
warnings.filterwarnings('ignore')
Use_column_info=False
Use_geometry_enhanced=True #default:True
atom_id_names = [
"atomic_num", "chiral_tag", "degree", "explicit_valence",
"formal_charge", "hybridization", "implicit_valence",
"is_aromatic", "total_numHs",
]
bond_id_names = [
"bond_dir", "bond_type", "is_in_ring"]
condition_name=['silica_surface','replace_basis']
condition_float_name=['eluent','grain_radian']
if Use_geometry_enhanced==True:
bond_float_names = ["bond_length",'prop']
if Use_geometry_enhanced==False:
bond_float_names=['prop']
bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS']
column_specify={'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4],
'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9],
'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14],
'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19],
'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24]}
column_smile=['O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4']
column_name=['ADH','ODH','IC','IA','OJH','ASH','IC3','IE','ID','OD3', 'IB','AD','AD3',
'IF','OD','AS','OJ3','IG','AZ','IAH','OJ','ICH','OZ3','IF3','IAU']
full_atom_feature_dims = get_atom_feature_dims(atom_id_names)
full_bond_feature_dims = get_bond_feature_dims(bond_id_names)
if Use_column_info==True:
bond_id_names.extend(['coated', 'immobilized'])
bond_float_names.extend(['diameter'])
if Use_geometry_enhanced==True:
bond_angle_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
else:
bond_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
full_bond_feature_dims.extend([2,2])
calc = Calculator(descriptors, ignore_3D=False)
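For reference, `calc` computes the Mordred descriptor set on which the precomputed `*_morder.npy` files are based. A minimal, hypothetical example (the aspirin SMILES is for illustration only; without a 3D conformer the 3D descriptors come back as missing values):
# Illustrative sketch only: compute Mordred descriptors for one molecule and
# replace missing values with 0, just as the mord() helper defined later does.
example_mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')
example_result = calc(example_mol)
example_vec = np.array([v if not is_missing(v) else 0 for v in example_result], dtype=np.float32)
print(example_vec.shape)  # ~1826 descriptors with ignore_3D=False (mord() below assumes nBits=1826)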
Define the graph neural network (a short illustration of q_loss follows the cell)
[105]
class AtomEncoder(torch.nn.Module):
"""该类用于对原子属性做嵌入。
记`N`为原子属性的维度,则原子属性表示为`[x1, x2, ..., xi, xN]`,其中任意的一维度`xi`都是类别型数据。full_atom_feature_dims[i]存储了原子属性`xi`的类别数量。
该类将任意的原子属性`[x1, x2, ..., xi, xN]`转换为原子的嵌入`x_embedding`(维度为emb_dim)。
"""
def __init__(self, emb_dim):
super(AtomEncoder, self).__init__()
self.atom_embedding_list = torch.nn.ModuleList()
for i, dim in enumerate(full_atom_feature_dims):
emb = torch.nn.Embedding(dim + 5, emb_dim) # each attribute dimension gets its own embedding table
torch.nn.init.xavier_uniform_(emb.weight.data)
self.atom_embedding_list.append(emb)
def forward(self, x):
x_embedding = 0
for i in range(x.shape[1]):
x_embedding += self.atom_embedding_list[i](x[:, i])
return x_embedding
class BondEncoder(torch.nn.Module):
def __init__(self, emb_dim):
super(BondEncoder, self).__init__()
self.bond_embedding_list = torch.nn.ModuleList()
for i, dim in enumerate(full_bond_feature_dims):
emb = torch.nn.Embedding(dim + 5, emb_dim)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.bond_embedding_list.append(emb)
def forward(self, edge_attr):
bond_embedding = 0
for i in range(edge_attr.shape[1]):
bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])
return bond_embedding
class RBF(torch.nn.Module):
"""
Radial Basis Function
"""
def __init__(self, centers, gamma, dtype='float32'):
super(RBF, self).__init__()
self.centers = centers.reshape([1, -1])
self.gamma = gamma
def forward(self, x):
"""
Args:
x(tensor): (-1, 1).
Returns:
y(tensor): (-1, n_centers)
"""
x = x.reshape([-1, 1])
return torch.exp(-self.gamma * torch.square(x - self.centers))
class BondFloatRBF(torch.nn.Module):
"""
Bond Float Encoder using Radial Basis Functions
"""
def __init__(self, bond_float_names, embed_dim, rbf_params=None):
super(BondFloatRBF, self).__init__()
self.bond_float_names = bond_float_names
if rbf_params is None:
self.rbf_params = {
'bond_length': (nn.Parameter(torch.arange(0, 2, 0.1)), nn.Parameter(torch.Tensor([10.0]))),
# (centers, gamma)
'prop': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
# 'TPSA':(nn.Parameter(torch.arange(0, 100, 5).to(torch.float32)), nn.Parameter(torch.Tensor([5.0]))),
# 'RASA': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
# 'RPSA': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0])))
}
else:
self.rbf_params = rbf_params
self.linear_list = torch.nn.ModuleList()
self.rbf_list = torch.nn.ModuleList()
for name in self.bond_float_names:
centers, gamma = self.rbf_params[name]
rbf = RBF(centers.to(device), gamma.to(device))
self.rbf_list.append(rbf)
linear = torch.nn.Linear(len(centers), embed_dim).to(device)
self.linear_list.append(linear)
def forward(self, bond_float_features):
"""
Args:
bond_float_features(tensor): bond float features.
"""
out_embed = 0
for i, name in enumerate(self.bond_float_names):
x = bond_float_features[:, i].reshape(-1, 1)
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
return out_embed
class BondAngleFloatRBF(torch.nn.Module):
"""
Bond Angle Float Encoder using Radial Basis Functions
"""
def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
super(BondAngleFloatRBF, self).__init__()
self.bond_angle_float_names = bond_angle_float_names
if rbf_params is None:
self.rbf_params = {
'bond_angle': (nn.Parameter(torch.arange(0, torch.pi, 0.1)), nn.Parameter(torch.Tensor([10.0]))),
# (centers, gamma)
}
else:
self.rbf_params = rbf_params
self.linear_list = torch.nn.ModuleList()
self.rbf_list = torch.nn.ModuleList()
for name in self.bond_angle_float_names:
if name == 'bond_angle':
centers, gamma = self.rbf_params[name]
rbf = RBF(centers.to(device), gamma.to(device))
self.rbf_list.append(rbf)
linear = nn.Linear(len(centers), embed_dim)
self.linear_list.append(linear)
else:
linear = nn.Linear(len(self.bond_angle_float_names) - 1, embed_dim)
self.linear_list.append(linear)
break
def forward(self, bond_angle_float_features):
"""
Args:
bond_angle_float_features(tensor): bond angle float features.
"""
out_embed = 0
for i, name in enumerate(self.bond_angle_float_names):
if name == 'bond_angle':
x = bond_angle_float_features[:, i].reshape(-1, 1)
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
else:
x = bond_angle_float_features[:, 1:]
out_embed += self.linear_list[i](x)
break
return out_embed
class ConditionEmbeding(torch.nn.Module):
"""
Not used in single_column prediction
"""
def __init__(self, condition_names,condition_float_names, embed_dim, rbf_params=None):
super(ConditionEmbeding, self).__init__()
self.condition_names = condition_names
self.condition_float_names=condition_float_names
if rbf_params is None:
self.rbf_params = {
'eluent': (nn.Parameter(torch.arange(0,1,0.1)), nn.Parameter(torch.Tensor([10.0]))),
'grain_radian': (nn.Parameter(torch.arange(0,10,0.1)), nn.Parameter(torch.Tensor([10.0])))# (centers, gamma)
}
else:
self.rbf_params = rbf_params
self.linear_list = torch.nn.ModuleList()
self.rbf_list = torch.nn.ModuleList()
self.embedding_list=torch.nn.ModuleList()
for name in self.condition_float_names:
centers, gamma = self.rbf_params[name]
rbf = RBF(centers.to(device), gamma.to(device))
self.rbf_list.append(rbf)
linear = nn.Linear(len(centers), embed_dim).to(device)
self.linear_list.append(linear)
for name in self.condition_names:
if name=='silica_surface':
emb = torch.nn.Embedding(2 + 5, embed_dim).to(device)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.embedding_list.append(emb)
elif name=='replace_basis':
emb = torch.nn.Embedding(6 + 5, embed_dim).to(device)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.embedding_list.append(emb)
def forward(self, condition):
"""
Args:
condition(tensor): column condition features (categorical and float columns interleaved).
"""
out_embed = 0
for i, name in enumerate(self.condition_float_names):
x = condition[:,2*i+1]
rbf_x = self.rbf_list[i](x)
out_embed += self.linear_list[i](rbf_x)
for i, name in enumerate(self.condition_names):
x = self.embedding_list[i](condition[:,2*i].to(torch.int64))
out_embed += x
return out_embed
class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
super(GINConv, self).__init__(aggr="add")
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
nn.Linear(emb_dim, emb_dim))
self.eps = nn.Parameter(torch.Tensor([0]))
def forward(self, x, edge_index, edge_attr):
edge_embedding = edge_attr
out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
return out
def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
class GINNodeEmbedding(torch.nn.Module):
"""
Node embedding
"""
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
"""GIN Node Embedding Module
"""
super(GINNodeEmbedding, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
# add residual connection or not
self.residual = residual
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.atom_encoder = AtomEncoder(emb_dim)
self.bond_encoder=BondEncoder(emb_dim)
self.bond_float_encoder=BondFloatRBF(bond_float_names,emb_dim)
self.bond_angle_encoder=BondAngleFloatRBF(bond_angle_float_names,emb_dim)
self.condition_encoder=ConditionEmbeding(condition_name,condition_float_name,emb_dim) #Not used in single_column prediction
# List of GNNs
self.convs = torch.nn.ModuleList()
self.convs_bond_angle=torch.nn.ModuleList()
self.convs_bond_float=torch.nn.ModuleList()
self.convs_bond_embeding=torch.nn.ModuleList()
self.convs_angle_float=torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
self.batch_norms_ba = torch.nn.ModuleList()
self.convs_condition=torch.nn.ModuleList()
for layer in range(num_layers):
self.convs.append(GINConv(emb_dim))
self.convs_bond_angle.append(GINConv(emb_dim))
self.convs_bond_embeding.append(BondEncoder(emb_dim))
self.convs_bond_float.append(BondFloatRBF(bond_float_names,emb_dim))
self.convs_angle_float.append(BondAngleFloatRBF(bond_angle_float_names,emb_dim))
self.convs_condition.append(ConditionEmbeding(condition_name,condition_float_name,emb_dim)) #Not used in single_column prediction
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
self.batch_norms_ba.append(torch.nn.BatchNorm1d(emb_dim))
def forward(self, batched_atom_bond,batched_bond_angle):
x, edge_index, edge_attr = batched_atom_bond.x, batched_atom_bond.edge_index, batched_atom_bond.edge_attr
edge_index_ba,edge_attr_ba= batched_bond_angle.edge_index, batched_bond_angle.edge_attr
# computing input node embedding
h_list = [self.atom_encoder(x)] # first convert the categorical atom attributes into atom embeddings
h_list_ba=[self.bond_float_encoder(edge_attr[:,3:edge_attr.shape[1]+1].to(torch.float32))+self.bond_encoder(edge_attr[:,0:3].to(torch.int64))]
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, h_list_ba[layer])
cur_h_ba=self.convs_bond_embeding[layer](edge_attr[:,0:3].to(torch.int64))+self.convs_bond_float[layer](edge_attr[:,3:edge_attr.shape[1]+1].to(torch.float32))
cur_angle_hidden=self.convs_angle_float[layer](edge_attr_ba)
h_ba=self.convs_bond_angle[layer](cur_h_ba, edge_index_ba, cur_angle_hidden)
if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
h_ba = F.dropout(h_ba, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
h_ba = F.dropout(F.relu(h_ba), self.drop_ratio, training=self.training)
if self.residual:
h += h_list[layer]
h_ba+=h_list_ba[layer]
h_list.append(h)
h_list_ba.append(h_ba)
# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
edge_representation = h_list_ba[-1]
elif self.JK == "sum":
node_representation = 0
edge_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
edge_representation += h_list_ba[layer]
return node_representation,edge_representation
class GINGraphPooling(nn.Module):
def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="attention",
descriptor_dim=1781):
"""GIN Graph Pooling Module
Args:
num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (dimension of graph representation).
num_layers (int, optional): number of GINConv layers. Defaults to 5.
emb_dim (int, optional): dimension of node embedding. Defaults to 300.
residual (bool, optional): adding residual connection or not. Defaults to False.
drop_ratio (float, optional): dropout rate. Defaults to 0.
JK (str, optional): Defaults to "last".
graph_pooling (str, optional): pooling method over node embeddings. Defaults to "attention".
Out:
graph representation
"""
super(GINGraphPooling, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
self.descriptor_dim=descriptor_dim
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
# Pooling function to generate whole-graph embeddings
if graph_pooling == "sum":
self.pool = global_add_pool
elif graph_pooling == "mean":
self.pool = global_mean_pool
elif graph_pooling == "max":
self.pool = global_max_pool
elif graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn=nn.Sequential(
nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
elif graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps=2)
else:
raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set":
self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks) # Set2Set doubles the embedding dimension
else:
self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
self.NN_descriptor = nn.Sequential(nn.Linear(self.descriptor_dim, self.emb_dim),
nn.Sigmoid(),
nn.Linear(self.emb_dim, self.emb_dim))
self.sigmoid = nn.Sigmoid()
def forward(self, batched_atom_bond,batched_bond_angle):
h_node,h_node_ba= self.gnn_node(batched_atom_bond,batched_bond_angle)
h_graph = self.pool(h_node, batched_atom_bond.batch)
output = self.graph_pred_linear(h_graph)
if self.training:
return output,h_graph
else:
# At inference time, the output is clamped to [0, 1e8] to ensure positivity
return torch.clamp(output, min=0, max=1e8),h_graph
def mord(mol, nBits=1826, errors_as_zeros=True):
try:
result = calc(mol)
desc_list = [r if not is_missing(r) else 0 for r in result]
np_arr = np.array(desc_list)
return np_arr
except:
return np.NaN if not errors_as_zeros else np.zeros((nBits,), dtype=np.float32)
def load_3D_mol():
dir = 'mol_save/'
for root, dirs, files in os.walk(dir):
file_names = files
file_names.sort(key=lambda x: int(x[x.find('_') + 5:x.find(".")])) # sort files by the numeric part of their names
mol_save = []
for file_name in file_names:
mol_save.append(Chem.MolFromMolFile(dir + file_name))
return mol_save
def parse_args():
parser = argparse.ArgumentParser(description='Graph data mining with GNN')
parser.add_argument('--task_name', type=str, default='GINGraphPooling',
help='task name')
parser.add_argument('--device', type=str, default='cpu',
help='device to use (default: cpu)')
parser.add_argument('--num_layers', type=int, default=5,
help='number of GNN message passing layers (default: 5)')
parser.add_argument('--graph_pooling', type=str, default='sum',
help='graph pooling strategy mean or sum (default: sum)')
parser.add_argument('--emb_dim', type=int, default=128,
help='dimensionality of hidden units in GNNs (default: 128)')
parser.add_argument('--drop_ratio', type=float, default=0.,
help='dropout ratio (default: 0.)')
parser.add_argument('--save_test', action='store_true')
parser.add_argument('--batch_size', type=int, default=256,
help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=1000,
help='number of epochs to train (default: 1000)')
parser.add_argument('--weight_decay', type=float, default=0.00001,
help='weight decay')
parser.add_argument('--early_stop', type=int, default=10,
help='early stop (default: 10)')
parser.add_argument('--num_workers', type=int, default=0,
help='number of workers (default: 0)')
parser.add_argument('--dataset_root', type=str, default="dataset",
help='dataset root')
args = parser.parse_args(args=[])
return args
def calc_dragon_type_desc(mol):
compound_mol = mol
compound_MolWt = Descriptors.ExactMolWt(compound_mol)
compound_TPSA = Chem.rdMolDescriptors.CalcTPSA(compound_mol)
compound_nRotB = Descriptors.NumRotatableBonds(compound_mol) # Number of rotatable bonds
compound_HBD = Descriptors.NumHDonors(compound_mol) # Number of H bond donors
compound_HBA = Descriptors.NumHAcceptors(compound_mol) # Number of H bond acceptors
compound_LogP = Descriptors.MolLogP(compound_mol) # LogP
return rdMolDescriptors.CalcAUTOCORR3D(mol) + rdMolDescriptors.CalcMORSE(mol) + \
rdMolDescriptors.CalcRDF(mol) + rdMolDescriptors.CalcWHIM(mol) + \
[compound_MolWt, compound_TPSA, compound_nRotB, compound_HBD, compound_HBA, compound_LogP]
def prepartion(args):
save_dir = os.path.join('saves', args.task_name)
args.save_dir = save_dir
os.makedirs(args.save_dir, exist_ok=True)
args.device = torch.device("cpu")
args.output_file = open(os.path.join(args.save_dir, 'output'), 'a')
print(args, file=args.output_file, flush=True)
def q_loss(q,y_true,y_pred):
e = (y_true-y_pred)
return torch.mean(torch.maximum(q*e, (q-1)*e))
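The quantile (pinball) loss q_loss defined above is what turns the model's three outputs into a lower bound, a point estimate, and an upper bound during training (see the loss in the training cell below). A tiny illustrative check with made-up numbers:
# Illustrative sketch only: the asymmetric penalties of the pinball loss.
# With q=0.1, over-prediction is weighted by 1-q=0.9 and under-prediction by q=0.1,
# pushing pred[:, 0] toward a lower quantile; q=0.9 does the opposite for pred[:, 2].
y_true_demo = torch.tensor([10.0, 20.0])
y_pred_demo = torch.tensor([12.0, 15.0])       # one over-prediction, one under-prediction
print(q_loss(0.1, y_true_demo, y_pred_demo))   # tensor(1.1500)
print(q_loss(0.9, y_true_demo, y_pred_demo))   # tensor(2.3500)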
Helper function for constructing the graph dataset
[106]
def Construct_dataset(dataset,data_index, T1, speed, eluent,column):
graph_atom_bond = []
graph_bond_angle = []
big_index = []
if column=='ODH':
all_descriptor=np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_ODH_morder.npy')
if column=='ADH':
all_descriptor = np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_ADH_charity_morder_0606.npy')
if column=='IC':
all_descriptor=np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_IC_charity_morder_0823.npy')
if column == 'IA':
all_descriptor = np.load('/bohr/HPLC-dataset-lfyc/v2/dataset_IA_charity_morder_0823.npy')
for i in range(len(dataset)):
data = dataset[i]
atom_feature = []
bond_feature = []
for name in atom_id_names:
atom_feature.append(data[name])
for name in bond_id_names:
bond_feature.append(data[name])
atom_feature = torch.from_numpy(np.array(atom_feature).T).to(torch.int64)
bond_feature = torch.from_numpy(np.array(bond_feature).T).to(torch.int64)
bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
y = torch.Tensor([float(T1[i]) * float(speed[i])])
edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)
data_index_int=torch.from_numpy(np.array(data_index[i])).to(torch.int64)
prop=torch.ones([bond_feature.shape[0]])*eluent[i]
TPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820]/100
RASA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821]
RPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822]
MDEC=torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568]
MATS=torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457]
bond_feature=torch.cat([bond_feature,bond_float_feature.reshape(-1,1)],dim=1)
bond_feature = torch.cat([bond_feature, prop.reshape(-1, 1)], dim=1)
bond_angle_feature=bond_angle_feature.reshape(-1,1)
bond_angle_feature = torch.cat([bond_angle_feature.reshape(-1, 1), TPSA.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, RASA.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, RPSA.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, MDEC.reshape(-1, 1)], dim=1)
bond_angle_feature = torch.cat([bond_angle_feature, MATS.reshape(-1, 1)], dim=1)
if y[0]>60:
big_index.append(i)
continue
data_atom_bond = Data(atom_feature, edge_index, bond_feature, y,data_index=data_index_int)
data_bond_angle= Data(edge_index=bond_index, edge_attr=bond_angle_feature)
graph_atom_bond.append(data_atom_bond)
graph_bond_angle.append(data_bond_angle)
return graph_atom_bond,graph_bond_angle,big_index
Set the run-control parameters; change Use_column to switch which column model is trained (options: ADH, ODH, IC, IA)
[107]
#============Parameter setting===============
test_mode='fixed' # fixed, random, or enantiomer (extract enantiomers)
Use_column='ODH' # column to train: ADH, ODH, IC, or IA
Load the data
[108]
#-------------load data----------------
'''
The graph construction is precomputed and saved beforehand to speed up loading, using code along these lines:
for smile in smiles:
mol = obtain_3D_mol(smile, 'trail')
mol = Chem.MolFromMolFile(f"trail.mol")
all_descriptor.append(mord(mol))
dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))
'''
if Use_column=='ODH':
HPLC_ODH=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/ODH_charity_0616.csv')
HPLC_ODH=HPLC_ODH.drop(4231) #conformer error
all_smile_ODH = HPLC_ODH['SMILES'].values
T1_ODH = HPLC_ODH['RT'].values
Speed_ODH = HPLC_ODH['Speed'].values
Prop_ODH = HPLC_ODH['i-PrOH_proportion'].values
dataset_ODH=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_ODH.npy',allow_pickle=True).tolist()
index_ODH=HPLC_ODH['Unnamed: 0'].values
print("data_num:",len(dataset_ODH))
if Use_column=='ADH':
HPLC_ADH=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/ADH_charity_0606.csv')
all_smile_ADH = HPLC_ADH['SMILES'].values
T1_ADH = HPLC_ADH['RT'].values
Speed_ADH = HPLC_ADH['Speed'].values
Prop_ADH = HPLC_ADH['i-PrOH_proportion'].values
dataset_ADH=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_ADH_charity_0606.npy',allow_pickle=True).tolist()
index_ADH=HPLC_ADH['Unnamed: 0'].values
print("data_num:",len(dataset_ADH))
if Use_column=='IC':
HPLC_IC=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/IC_charity_0823.csv')
bad_IC_index=np.load(r'/bohr/HPLC-dataset-lfyc/v2/bad_IC.npy') #Some compounds that cannot get 3D conformer by RDKit
HPLC_IC=HPLC_IC.drop(bad_IC_index) #conformer error
all_smile_IC = HPLC_IC['SMILES'].values
T1_IC = HPLC_IC['RT'].values
Speed_IC = HPLC_IC['Speed'].values
Prop_IC = HPLC_IC['i-PrOH_proportion'].values
dataset_IC=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_IC_charity_0823.npy',allow_pickle=True).tolist()
index_IC=HPLC_IC['Unnamed: 0'].values
print("data_num:",len(dataset_IC))
if Use_column=='IA':
HPLC_IA=pd.read_csv(r'/bohr/HPLC-dataset-lfyc/v2/IA_charity_0823.csv')
bad_IA_index=np.load(r'/bohr/HPLC-dataset-lfyc/v2/bad_IA.npy')
HPLC_IA=HPLC_IA.drop(bad_IA_index) #conformer error
all_smile_IA = HPLC_IA['SMILES'].values
T1_IA = HPLC_IA['RT'].values
Speed_IA = HPLC_IA['Speed'].values
Prop_IA = HPLC_IA['i-PrOH_proportion'].values
dataset_IA=np.load(r'/bohr/HPLC-dataset-lfyc/v2/dataset_IA_charity_0823.npy',allow_pickle=True).tolist()
index_IA=HPLC_IA['Unnamed: 0'].values
print("data_num:",len(dataset_IA))
print("Dataset has been loaded!")
data_num: 4971
Dataset has been loaded!
Build the training, validation, and test datasets (a sanity-check example follows the cell output)
[109]
#===========Construct dataset==============
if Use_column=='ADH':
dataset_graph_atom_bond,dataset_graph_bond_angle,big_index = Construct_dataset(dataset_ADH,index_ADH, T1_ADH, Speed_ADH, Prop_ADH,column=Use_column)
if Use_column=='IC':
dataset_graph_atom_bond,dataset_graph_bond_angle,big_index = Construct_dataset(dataset_IC,index_IC, T1_IC, Speed_IC, Prop_IC,column=Use_column)
if Use_column=='IA':
dataset_graph_atom_bond,dataset_graph_bond_angle,big_index = Construct_dataset(dataset_IA,index_IA, T1_IA, Speed_IA, Prop_IA,column=Use_column)
if Use_column=='ODH':
dataset_graph_atom_bond,dataset_graph_bond_angle,big_index = Construct_dataset(dataset_ODH,index_ODH, T1_ODH, Speed_ODH, Prop_ODH,column=Use_column)
total_num = len(dataset_graph_atom_bond)
print('data num:',total_num)
train_ratio = 0.90
validate_ratio = 0.05
test_ratio = 0.05
args = parse_args()
prepartion(args)
nn_params = {
    'num_tasks': 3,
    'num_layers': args.num_layers,
    'emb_dim': args.emb_dim,
    'drop_ratio': args.drop_ratio,
    'graph_pooling': args.graph_pooling,
    'descriptor_dim': 1827
}
# fixed random seed for each column (reproducible splits)
if Use_column == 'ODH':
    np.random.seed(388)
if Use_column == 'ADH':
    np.random.seed(505)
if Use_column == 'IC':
    np.random.seed(526)
if Use_column == 'IA':
    np.random.seed(388)
# automatic dataloading and splitting
data_array = np.arange(0, total_num, 1)
np.random.shuffle(data_array)
torch.random.manual_seed(525)
train_data_atom_bond = []
valid_data_atom_bond = []
test_data_atom_bond = []
train_data_bond_angle = []
valid_data_bond_angle = []
test_data_bond_angle = []
train_num = int(len(data_array) * train_ratio)
test_num = int(len(data_array) * test_ratio)
val_num = int(len(data_array) * validate_ratio)
print("training data num:",train_num,"\n validating data num:",val_num,"\ntesting data num:",test_num)
train_index = data_array[0:train_num]
valid_index = data_array[train_num:train_num + val_num]
if test_mode == 'fixed':
    test_index = data_array[total_num - test_num:]
if test_mode == 'random':
    test_index = data_array[train_num + val_num:train_num + val_num + test_num]
for i in test_index:
    test_data_atom_bond.append(dataset_graph_atom_bond[i])
    test_data_bond_angle.append(dataset_graph_bond_angle[i])
for i in valid_index:
    valid_data_atom_bond.append(dataset_graph_atom_bond[i])
    valid_data_bond_angle.append(dataset_graph_bond_angle[i])
for i in train_index:
    train_data_atom_bond.append(dataset_graph_atom_bond[i])
    train_data_bond_angle.append(dataset_graph_bond_angle[i])
train_loader_atom_bond = DataLoader(train_data_atom_bond, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
valid_loader_atom_bond = DataLoader(valid_data_atom_bond, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader_atom_bond = DataLoader(test_data_atom_bond, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
train_loader_bond_angle = DataLoader(train_data_bond_angle, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
valid_loader_bond_angle = DataLoader(valid_data_bond_angle, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader_bond_angle = DataLoader(test_data_bond_angle, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
device = args.device
criterion_fn = torch.nn.MSELoss()
model = GINGraphPooling(**nn_params).to(device)
num_params = sum(p.numel() for p in model.parameters())
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
print('===========Data Prepared================')
data num: 4942
training data num: 4447
validating data num: 247
testing data num: 247
===========Data Prepared================
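One detail worth noting: every DataLoader above uses shuffle=False. Each molecule is represented by two graphs (atom-bond and bond-angle) served by two separate loaders that are zipped together during training and evaluation, so both loaders must iterate in exactly the same order. A minimal alignment check, sketched with the objects defined above:
# Sketch: the i-th atom-bond batch must correspond to the i-th bond-angle batch,
# which holds as long as both loaders keep shuffle=False.
for batch_ab, batch_ba in zip(train_loader_atom_bond, train_loader_bond_angle):
    assert batch_ab.num_graphs == batch_ba.num_graphs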
Train the model
[129]
def train(model, device, loader_atom_bond, loader_bond_angle, optimizer, criterion_fn):
    model.train()
    loss_accum = 0
    for step, batch in enumerate(zip(loader_atom_bond, loader_bond_angle)):
        batch_atom_bond = batch[0].to(device)
        batch_bond_angle = batch[1].to(device)
        pred = model(batch_atom_bond, batch_bond_angle)[0]
        true = batch_atom_bond.y
        optimizer.zero_grad()
        # pred[:, 0] / pred[:, 1] / pred[:, 2] are trained as the 10% quantile, mean and 90% quantile;
        # the relu terms keep the three outputs ordered and penalize predictions below 2
        loss = q_loss(0.1, true, pred[:, 0]) + torch.mean((true - pred[:, 1]) ** 2) + q_loss(0.9, true, pred[:, 2]) \
               + torch.mean(torch.relu(pred[:, 0] - pred[:, 1])) + torch.mean(torch.relu(pred[:, 1] - pred[:, 2])) + torch.mean(torch.relu(2 - pred))
        loss.backward()
        optimizer.step()
        loss_accum += loss.detach().cpu().item()
    return loss_accum / (step + 1)

dir_name = os.path.join(os.getcwd(), 'model_save', Use_column)  # os.path.join instead of '\model_save' string concatenation
os.makedirs(dir_name, exist_ok=True)
for epoch in tqdm(range(1500)):
    train_mae = train(model, device, train_loader_atom_bond, train_loader_bond_angle, optimizer, criterion_fn)
    if (epoch + 1) % 100 == 0:
        valid_mae = eval(model, device, valid_loader_atom_bond, valid_loader_bond_angle)
        print(train_mae, valid_mae)
        torch.save(model.state_dict(), f'{dir_name}/model_save_{epoch + 1}.pth')
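The objective above trains three outputs at once: pred[:, 0], pred[:, 1] and pred[:, 2] are fitted as the 10% quantile, the mean and the 90% quantile of the retention time, the two middle relu terms keep the quantiles ordered, and the last relu term penalizes predictions below 2. q_loss is defined earlier in this notebook; assuming it is the standard pinball (quantile) loss, it behaves like the sketch below (illustrative, not the notebook's own definition):
def pinball_loss(q, y_true, y_pred):
    # Pinball loss for quantile level q: under-prediction is weighted by q,
    # over-prediction by (1 - q).
    e = y_true - y_pred
    return torch.mean(torch.maximum(q * e, (q - 1) * e))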
Test the model
[116]
def test(model, device, loader_atom_bond, loader_bond_angle):
    model.eval()
    y_pred = []
    y_true = []
    y_pred_10 = []
    y_pred_90 = []
    with torch.no_grad():
        for _, batch in enumerate(zip(loader_atom_bond, loader_bond_angle)):
            batch_atom_bond = batch[0].to(device)
            batch_bond_angle = batch[1].to(device)
            pred = model(batch_atom_bond, batch_bond_angle)[0]
            y_true.append(batch_atom_bond.y.detach().cpu().reshape(-1,))
            y_pred.append(pred[:, 1].detach().cpu())     # mean prediction
            y_pred_10.append(pred[:, 0].detach().cpu())  # 10% quantile
            y_pred_90.append(pred[:, 2].detach().cpu())  # 90% quantile
    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    y_pred_10 = torch.cat(y_pred_10, dim=0)
    y_pred_90 = torch.cat(y_pred_90, dim=0)
    # note: the denominator uses y_pred.mean(); the standard R^2 (with y_true.mean()) is recomputed after the call
    R_square = 1 - (((y_true - y_pred) ** 2).sum() / ((y_true - y_pred.mean()) ** 2).sum())
    test_mae = torch.mean((y_true - y_pred) ** 2)  # note: mean squared error, despite the variable name
    print(R_square)
    return y_pred, y_true, R_square, test_mae, y_pred_10, y_pred_90
# for k, v in torch.load('/bohr/HPLC-dataset-lfyc/v2/model_ODH_388/model_save_1500.pth', map_location=torch.device('cpu')).items():
#     print(k)
if Use_column == 'ODH':
    checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_ODH_388/model_save_1500.pth', map_location=torch.device('cpu'))
    new_pth = model.state_dict()   # state dict of the model to be loaded
    pretrained_dict = {}           # parameters shared by the checkpoint and the model
    for k, v in checkpoint.items():
        for kk in new_pth.keys():
            if kk in k:
                pretrained_dict[kk] = v
                break
    new_pth.update(pretrained_dict)
    model.load_state_dict(new_pth)
if Use_column == 'ADH':
    checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_ADH_505/model_save_1500.pth', map_location=torch.device('cpu'))
    new_pth = model.state_dict()   # state dict of the model to be loaded
    pretrained_dict = {}           # parameters shared by the checkpoint and the model
    for k, v in checkpoint.items():
        for kk in new_pth.keys():
            if kk in k:
                pretrained_dict[kk] = v
                break
    new_pth.update(pretrained_dict)
    model.load_state_dict(new_pth)
if Use_column == 'IC':
    checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_IC_526/model_save_1500.pth', map_location=torch.device('cpu'))
    new_pth = model.state_dict()   # state dict of the model to be loaded
    pretrained_dict = {}           # parameters shared by the checkpoint and the model
    for k, v in checkpoint.items():
        for kk in new_pth.keys():
            if kk in k:
                pretrained_dict[kk] = v
                break
    new_pth.update(pretrained_dict)
    model.load_state_dict(new_pth)
if Use_column == 'IA':
    checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_IA_388/model_save_1500.pth', map_location=torch.device('cpu'))
    new_pth = model.state_dict()   # state dict of the model to be loaded
    pretrained_dict = {}           # parameters shared by the checkpoint and the model
    for k, v in checkpoint.items():
        for kk in new_pth.keys():
            if kk in k:
                pretrained_dict[kk] = v
                break
    new_pth.update(pretrained_dict)
    model.load_state_dict(new_pth)
y_pred, y_true, R_square, test_mae, y_pred_10, y_pred_90 = test(model, device, test_loader_atom_bond, test_loader_bond_angle)
y_pred = y_pred.cpu().data.numpy()
y_true = y_true.cpu().data.numpy()
y_pred_10 = y_pred_10.cpu().data.numpy()
y_pred_90 = y_pred_90.cpu().data.numpy()
print('relative_error', np.sqrt(np.sum((y_true - y_pred) ** 2) / np.sum(y_true ** 2)))  # relative L2 error
print('MAE', np.mean(np.abs(y_true - y_pred) / y_true))  # note: this is the mean relative error
print('RMSE', np.sqrt(np.mean((y_true - y_pred) ** 2)))
R_square = 1 - (((y_true - y_pred) ** 2).sum() / ((y_true - y_true.mean()) ** 2).sum())
print(R_square)
plt.figure(1,figsize=(2.5,2.5),dpi=300)
plt.style.use('ggplot')
plt.scatter(y_true, y_pred, c='#8983BF',s=15,alpha=0.4)
plt.plot(np.arange(0, 60), np.arange(0, 60),linewidth=1.5,linestyle='--',color='black')
plt.yticks(np.arange(0,66,10),np.arange(0,66,10), size=8)
plt.xticks(np.arange(0,66,10),np.arange(0,66,10), size=8)
plt.xlabel('Observed data', size=8)
plt.ylabel('Predicted data', size=8)
plt.show()
tensor(0.7777)
relative_error 0.24293557
MAE 0.20141043
RMSE 3.918548
0.7776966840028763
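The four branches above (and the nearly identical ones in the next cell) repeat the same key-matching logic: a checkpoint tensor is copied whenever a model key appears as a substring of a checkpoint key. A hypothetical helper (not part of the notebook) could factor this out:
def load_matched_state_dict(model, ckpt_path):
    # Copy checkpoint tensors into the model wherever a model key is a
    # substring of a checkpoint key, mirroring the matching rule above.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    state = model.state_dict()
    matched = {}
    for k, v in checkpoint.items():
        for kk in state.keys():
            if kk in k:
                matched[kk] = v
                break
    state.update(matched)
    model.load_state_dict(state)
    return model

# usage sketch: load_matched_state_dict(model, '/bohr/HPLC-dataset-lfyc/v2/model_ODH_388/model_save_1500.pth')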
Determine whether the two enantiomers can be separated
[128]
def cal_prob(prediction):
    '''
    Calculate the separation probability Sp.
    Input:  prediction = [pred_1, pred_2], each of shape (1, 3) holding the
            10% quantile, mean and 90% quantile of the predicted retention time.
    Output: Sp
    '''
    a = prediction[0][0]
    b = prediction[1][0]
    if a[2] < b[0]:
        return 1
    elif a[0] > b[2]:
        return 1
    else:
        length = min(a[2], b[2]) - max(a[0], b[0])  # overlap of the two intervals
        all = max(a[2], b[2]) - min(a[0], b[0])     # union of the two intervals
        return 1 - length / all
transfer_target = 'ODH'
draw_picture = True
smiles = ['CCCCC[C@H](F)C(=O)c1nccn1c2ccccc2', 'CCCCC[C@@H](F)C(=O)c1nccn1c2ccccc2']
y_pred = []
speed = [1.0, 1.0]
eluent = [0.02, 0.02]
mols = []
all_descriptor = []
dataset = []
for smile in smiles:
    mol = Chem.MolFromSmiles(smile)
    mols.append(mol)
from rdkit.Chem import Draw
if draw_picture:
    index = 0
    for mol in mols:
        smiles_pic = Draw.MolToImage(mol, size=(200, 100), kekulize=True)
        plt.imshow(smiles_pic)
        plt.axis('off')
        index += 1
        plt.show()
        plt.clf()
for smile in smiles:
    mol = obtain_3D_mol(smile, 'trail')
    mol = Chem.MolFromMolFile("trail.mol")
    all_descriptor.append(mord(mol))
    dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))
for i in range(0, len(dataset)):
    data = dataset[i]
    atom_feature = []
    bond_feature = []
    for name in atom_id_names:
        atom_feature.append(data[name])
    for name in bond_id_names:
        bond_feature.append(data[name])
    atom_feature = torch.from_numpy(np.array(atom_feature).T).to(torch.int64)
    bond_feature = torch.from_numpy(np.array(bond_feature).T).to(torch.int64)
    bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
    bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
    y = torch.Tensor([float(speed[i])])
    edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
    bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)
    prop = torch.ones([bond_feature.shape[0]]) * eluent[i]
    # selected Mordred descriptors (indexed in the full descriptor vector), broadcast over the bond-angle graph
    TPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][820] / 100
    RASA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][821]
    RPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][822]
    MDEC = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][1568]
    MATS = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][457]
    bond_feature = torch.cat([bond_feature, bond_float_feature.reshape(-1, 1)], dim=1)
    bond_feature = torch.cat([bond_feature, prop.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature.reshape(-1, 1), TPSA.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, RASA.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, RPSA.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, MDEC.reshape(-1, 1)], dim=1)
    bond_angle_feature = torch.cat([bond_angle_feature, MATS.reshape(-1, 1)], dim=1)
    data_atom_bond = Data(atom_feature, edge_index, bond_feature, y)
    data_bond_angle = Data(edge_index=bond_index, edge_attr=bond_angle_feature)
    # (the checkpoint is reloaded on every iteration; this mirrors the loading cell above)
    if Use_column == 'ODH':
        checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_ODH_388/model_save_1500.pth', map_location=torch.device('cpu'))
        new_pth = model.state_dict()   # state dict of the model to be loaded
        pretrained_dict = {}           # parameters shared by the checkpoint and the model
        for k, v in checkpoint.items():
            for kk in new_pth.keys():
                if kk in k:
                    pretrained_dict[kk] = v
                    break
        new_pth.update(pretrained_dict)
        model.load_state_dict(new_pth)
    if Use_column == 'ADH':
        checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_ADH_505/model_save_1500.pth', map_location=torch.device('cpu'))
        new_pth = model.state_dict()   # state dict of the model to be loaded
        pretrained_dict = {}           # parameters shared by the checkpoint and the model
        for k, v in checkpoint.items():
            for kk in new_pth.keys():
                if kk in k:
                    pretrained_dict[kk] = v
                    break
        new_pth.update(pretrained_dict)
        model.load_state_dict(new_pth)
    if Use_column == 'IC':
        checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_IC_526/model_save_1500.pth', map_location=torch.device('cpu'))
        new_pth = model.state_dict()   # state dict of the model to be loaded
        pretrained_dict = {}           # parameters shared by the checkpoint and the model
        for k, v in checkpoint.items():
            for kk in new_pth.keys():
                if kk in k:
                    pretrained_dict[kk] = v
                    break
        new_pth.update(pretrained_dict)
        model.load_state_dict(new_pth)
    if Use_column == 'IA':
        checkpoint = torch.load('/bohr/HPLC-dataset-lfyc/v2/model_IA_388/model_save_1500.pth', map_location=torch.device('cpu'))
        new_pth = model.state_dict()   # state dict of the model to be loaded
        pretrained_dict = {}           # parameters shared by the checkpoint and the model
        for k, v in checkpoint.items():
            for kk in new_pth.keys():
                if kk in k:
                    pretrained_dict[kk] = v
                    break
        new_pth.update(pretrained_dict)
        model.load_state_dict(new_pth)
    model.eval()
    pred, h_graph = model(data_atom_bond.to(device), data_bond_angle.to(device))
    y_pred.append(pred.detach().cpu().data.numpy() / speed[i])
print(y_pred)
print("separation probability:", cal_prob(y_pred))
[array([[10.237834, 10.40861 , 10.870056]], dtype=float32), array([[9.105642, 9.192403, 9.570723]], dtype=float32)]
separation probability: 1
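Reading the result above: the first enantiomer's predicted 10% quantile (about 10.24 min) lies above the second's 90% quantile (about 9.57 min), so the two uncertainty intervals do not overlap and cal_prob returns 1. When the intervals do overlap, the function returns one minus the overlap-to-union ratio, for example (illustrative numbers, not model output):
print(cal_prob([np.array([[10.24, 10.41, 10.87]]), np.array([[9.11, 9.19, 9.57]])]))  # 1, as above
print(cal_prob([np.array([[10.0, 10.5, 11.0]]), np.array([[10.8, 11.2, 11.6]])]))     # 1 - 0.2/1.6 = 0.875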