Delta Machine Learning in Quantum Chemistry (Paper Reproduction)
©️ Copyright 2023 @ Authors
Author:
Jin Yucheng 📨
Date: 2023-11-07
License: This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
Quick start: click the Connect button at the top and wait a moment before running. If you run into any problems, please contact bohrium@dp.tech.
Background
This notebook is a by-product of writing a framework, shared just for fun.
In quantum chemistry, different computational methods produce results of different accuracy, and the computational resources they consume differ enormously. The cost of high-accuracy quantum chemistry calculations grows steeply with system size, so obtaining high-accuracy results in bulk is prohibitively expensive. Lower-accuracy data, on the other hand, can be generated quickly with cheap methods. If we can learn the residual (delta) between the two kinds of data, we can reach near-high accuracy at the cost of the low-accuracy method, striking a balance between computational cost and precision.
Delta Machine Learning was introduced to solve exactly this problem, and its idea is refreshingly blunt: train a machine-learning model on the residual (delta) between dataset A and dataset B, so that A can be mapped to B.
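The idea above fits in a few lines of code. The sketch below uses hypothetical toy data standing in for low- and high-accuracy calculations; the model names and the synthetic residual are illustrative assumptions, not the paper's setup:

```python
# Minimal Δ-ML sketch: learn delta = y_high - y_low on training data,
# then predict y_high ≈ y_low + model(features) on new data.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))            # descriptors (toy)
y_low = X @ rng.normal(size=5)           # cheap, low-accuracy values (toy)
y_high = y_low + 0.3 * np.sin(X[:, 0])   # expensive, high-accuracy values (toy)

model = RandomForestRegressor(n_estimators=100, random_state=0)
model.fit(X[:150], (y_high - y_low)[:150])       # learn the residual (delta)

y_pred = y_low[150:] + model.predict(X[150:])    # low-level result + correction
mse_before = np.mean((y_high[150:] - y_low[150:]) ** 2)
mse_after = np.mean((y_high[150:] - y_pred) ** 2)
print(mse_before, mse_after)
```

At prediction time only the cheap calculation and the descriptors are needed; the model supplies the expensive-minus-cheap correction.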
Raghunathan Ramakrishnan et al. were the first to use this approach to assist quantum chemistry calculations: Big Data Meets Quantum Chemistry Approximations: The Δ-Machine Learning Approach
Paper Reproduction
Next, we walk through the concrete Δ-ML workflow by reproducing a paper that uses the Delta Machine Learning idea to correct gas-phase small-molecule NMR chemical shifts.
Paper: Computation of CCSD(T)-Quality NMR Chemical Shifts via Δ-Machine Learning from DFT
NMR spectroscopy plays an important role in determining the structures of chemical compounds. In the paper, the authors correct DFT-computed NMR chemical shifts using computed input features together with high-accuracy reference data at the CCSD(T)/pcSseg-2 level of theory. The model is trained on a dataset of 1000 optimized and geometrically distorted small organic molecules covering most elements of the first three periods, comprising 7090 1H and 4230 13C NMR chemical shifts.
As the figure above shows, the ML-corrected chemical shifts lie visibly closer to the high-accuracy target line (CCSD(T)-level results) than the uncorrected baseline (DFT-level results).
Next, we try to reproduce this result using the paper's method (M).
Data Preparation
Before Δ-ML training, we need to prepare paired data at different levels of accuracy, together with the descriptors used for training.
To save everyone's time, the data and code for this notebook have already been prepared; just run the code block below to download them.
Interested readers can also visit the paper's GitHub repository for the original code and data: original data here!
Cloning into 'dpexample'...
Receiving objects: 100% (146/146), 1.70 MiB | 1.85 MiB/s, done.
Resolving deltas: 100% (55/55), done.
/personal/dpexample
Requirement already satisfied: ase in /opt/conda/lib/python3.8/site-packages (3.22.1)
Before training, let's take a quick look at the data: the error distribution and mean squared error before any correction.
The MSE between DFT and CCSD(T): 58.915871391447126
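The uncorrected MSE above follows directly from the residual column. The snippet below sketches the computation on a few toy rows standing in for the repository CSV (ml_pbe0_pcSseg-2_c.csv); the values are copied from the table shown later, not the full dataset:

```python
import pandas as pd

# Toy stand-in for the repository CSV; in the real file the
# "shift_high-low" column stores the CCSD(T) - DFT residual per nucleus.
df = pd.DataFrame({
    "shift_low": [22.6165, 157.8295, 121.0125],            # DFT baseline shifts
    "shift_high-low": [-2.180973, -16.932256, -8.388336],  # residuals
})

# The uncorrected MSE between DFT and CCSD(T) is the mean squared residual.
mse = (df["shift_high-low"] ** 2).mean()
print(f"The MSE between DFT and CCSD(T): {mse}")
```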
Dataset:
target: shift_high-low
features: the remaining columns (shift_low plus descriptors crafted from chemical intuition)
 | Unnamed: 0 | filepath | compound | structure | atom | shift_high-low | shift_low | CN(X) | no_CH | no_CC | ... | orb_stdev_mull_p | orb_charge_loew_s | orb_charge_loew_p | orb_charge_loew_d | orb_stdev_loew_p | BO_loew_sum | BO_loew_av | BO_mayer_sum | BO_mayer_av | mayer_VA
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 0 | /personal/dpexample/structures/orca_xyz_format... | 1 | 0 | 1 | -2.180973 | 22.6165 | 4.014655 | 3 | 1 | ... | 0.043487 | 2.748962 | 3.104530 | 0.192308 | 0.039575 | 3.9580 | 0.989500 | 3.9565 | 0.989125 | 3.9383 |
1 | 1 | /personal/dpexample/structures/orca_xyz_format... | 1 | 0 | 2 | -16.932256 | 157.8295 | 3.063921 | 1 | 2 | ... | 0.032790 | 2.799471 | 2.942899 | 0.239653 | 0.108145 | 3.8911 | 1.297033 | 3.8048 | 1.268267 | 3.8265 |
2 | 2 | /personal/dpexample/structures/orca_xyz_format... | 1 | 0 | 6 | -8.388336 | 121.0125 | 3.058490 | 1 | 2 | ... | 0.065215 | 2.779212 | 3.021192 | 0.251104 | 0.058710 | 4.0176 | 1.339200 | 3.9258 | 1.308600 | 3.8465 |
3 | 3 | /personal/dpexample/structures/orca_xyz_format... | 1 | 0 | 8 | -6.655499 | 88.2485 | 2.062114 | 0 | 2 | ... | 0.044227 | 2.837303 | 2.981360 | 0.206032 | 0.114460 | 4.1349 | 2.067450 | 3.8184 | 1.909200 | 3.6900 |
4 | 4 | /personal/dpexample/structures/orca_xyz_format... | 1 | 0 | 10 | -12.745342 | 86.6555 | 2.018988 | 1 | 1 | ... | 0.039510 | 2.873229 | 3.027809 | 0.110225 | 0.093518 | 3.8797 | 1.939850 | 3.8058 | 1.902900 | 3.8354 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
4225 | 4225 | /personal/dpexample/structures/orca_xyz_format... | 100 | 8 | 10 | -8.188055 | 220.2295 | 2.071565 | 0 | 1 | ... | 0.074693 | 2.832574 | 2.660929 | 0.430527 | 0.093575 | 4.5231 | 2.261550 | 3.8096 | 1.904800 | 3.8790 |
4226 | 4226 | /personal/dpexample/structures/orca_xyz_format... | 100 | 9 | 1 | -3.200289 | 30.3455 | 3.951633 | 3 | 1 | ... | 0.097381 | 2.807235 | 3.062093 | 0.160188 | 0.073808 | 3.8270 | 0.956750 | 3.8920 | 0.973000 | 3.9355 |
4227 | 4227 | /personal/dpexample/structures/orca_xyz_format... | 100 | 9 | 2 | -2.161913 | 36.3955 | 3.097138 | 0 | 3 | ... | 0.143529 | 2.775720 | 3.119263 | 0.262177 | 0.039218 | 3.9773 | 1.325767 | 3.6360 | 1.212000 | 3.6587 |
4228 | 4228 | /personal/dpexample/structures/orca_xyz_format... | 100 | 9 | 6 | -6.585249 | 57.0565 | 4.057648 | 2 | 1 | ... | 0.131063 | 2.739286 | 3.004378 | 0.288687 | 0.111243 | 4.3560 | 1.089000 | 4.0161 | 1.004025 | 4.0613 |
4229 | 4229 | /personal/dpexample/structures/orca_xyz_format... | 100 | 9 | 10 | -6.377227 | 202.7575 | 2.123856 | 0 | 1 | ... | 0.062545 | 2.800336 | 2.673355 | 0.472591 | 0.071078 | 4.5793 | 2.289650 | 3.8207 | 1.910350 | 3.9001 |
4230 rows × 37 columns
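Splitting the table above into target and features might look like this (a sketch on two toy rows; the column names follow the repository CSV, but the `drop` call and file paths here are illustrative assumptions):

```python
import pandas as pd

# Two toy rows standing in for the real table (4230 rows × 37 columns).
df = pd.DataFrame({
    "Unnamed: 0": [0, 1],
    "filepath": ["structures/a.xyz", "structures/b.xyz"],   # hypothetical paths
    "shift_high-low": [-2.180973, -16.932256],  # target: the delta to learn
    "shift_low": [22.6165, 157.8295],           # DFT baseline shift
    "CN(X)": [4.014655, 3.063921],              # one hand-crafted descriptor
})

y = df["shift_high-low"]
# Everything except bookkeeping columns and the target is a feature.
X = df.drop(columns=["Unnamed: 0", "filepath", "shift_high-low"])
print(list(X.columns))
```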
Training Setup
Let's run a simple MLP and see how it does.
Init_data in ml_pbe0_pcSseg-2_c.csv
Processing Data ...
Loading Data ...
2024-04-28 14:55:31 | dpexample/Unimol_2_NMR_fix/task/trainer.py | 115 | INFO | Echem | Epoch [1/1000] train_loss: 37.9923, val_loss: 9.3364, val_mse: 9.5038, lr: 0.000001, 3.3s
2024-04-28 14:55:32 | dpexample/Unimol_2_NMR_fix/task/trainer.py | 115 | INFO | Echem | Epoch [2/1000] train_loss: 7.0009, val_loss: 5.5644, val_mse: 5.6167, lr: 0.000002, 1.2s
2024-04-28 14:55:33 | dpexample/Unimol_2_NMR_fix/task/trainer.py | 115 | INFO | Echem | Epoch [3/1000] train_loss: 5.4782, val_loss: 4.5383, val_mse: 4.5567, lr: 0.000003, 1.0s
2024-04-28 14:55:34 | dpexample/Unimol_2_NMR_fix/task/trainer.py | 115 | INFO | Echem | Epoch [4/1000] train_loss: 5.0454, val_loss: 4.2843, val_mse: 4.2922, lr: 0.000004, 0.9s
Train: 45%|████▌ | 52/115 [00:00<00:00, 115.73it/s, Epoch=Epoch 99/1000, loss=1.6881, lr=0.0001]
The MSE dropped from about 59 to single digits, which looks pretty good!
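The run above uses the repository's custom trainer; the same fit-the-delta step can be mimicked with scikit-learn's MLPRegressor. This is a minimal stand-in on synthetic data, not the notebook's actual code, and the toy residual is an assumption:

```python
import numpy as np
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 10))        # stands in for the descriptor table
delta = X[:, 0] - 0.5 * X[:, 1]       # toy residual (shift_high - shift_low)
X_tr, X_te, y_tr, y_te = train_test_split(X, delta, random_state=0)

# A small two-hidden-layer MLP trained to regress the residual.
mlp = MLPRegressor(hidden_layer_sizes=(64, 64), max_iter=2000, random_state=0)
mlp.fit(X_tr, y_tr)
val_mse = np.mean((mlp.predict(X_te) - y_te) ** 2)
print(val_mse)
```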
By now, most readers should see how a Delta Machine Learning project is done. In essence it is no different from any other ML work: build data pairs, then train.
Of course, if the notebook ended here it would be a bit too thin, so I have to pad it out with a few more sections; next, let's look at what can be done to improve results on this kind of task.
Improving the Results
As in other structure-property relationship problems, the performance of Delta Machine Learning also depends on the quality of the descriptors. When the descriptors are good enough, the model may not even need the baseline data and can predict the reference-line values on its own. Let's throw the mighty Unimol at this problem and see how it does.
In the code below, we drop all features other than the molecular structure and run Unimol in finetune mode to extract molecule- and atom-level descriptors, leaving the other settings unchanged. The run may take a while.
Unimol looks quite effective. Can we do even better on top of this?
Following the idea that richer input information makes for better training, we merge the hand-crafted descriptors into the finetuning run and see whether the correction improves.
The improvement is fairly noticeable (?)
Predictably, the richer the information carried by the features, the better the correction.
You can also try to improve training through hyperparameter tuning; for example, saving the inputs of the final MLP layer to a CSV file and throwing them into autogluon easily squeezes out some extra accuracy (lol).