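Active Learning for Finding the Optimal Input
This article uses an iterative, model-based active learning approach to find the input X at which a multi-extremum function reaches its minimum. (An analogous problem: use modeling to find the composite material with the lowest density and determine the ratios x0, x1 of its two components.) (Another analogous problem: use modeling to find the composite material with the lowest melting point and determine the settings x0, x1 of the two successive processing steps used to prepare it.)
Click any Basic image to get started.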
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Successfully built uni_active mgetool
Successfully installed contourpy-1.2.0 cycler-0.12.1 et-xmlfile-1.1.0 fonttools-4.45.1 joblib-1.3.2 kiwisolver-1.4.5 matplotlib-3.8.2 mgetool-0.0.65 mpmath-1.3.0 openpyxl-3.1.2 path-16.7.1 pillow-10.1.0 pyparsing-3.1.1 scikit-learn-1.3.2 seaborn-0.13.0 sympy-1.12 threadpoolctl-3.2.0 uni_active-0.0.2
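Problem Analysis
To understand the problem, we first construct a multi-extremum mapping function y = f(x1, x2) that serves as the generator of the real data, that is, the ground truth of the problem under study.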
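The function used in this article is not reproduced in the text, so the following is only a minimal sketch of what such a ground-truth generator could look like. The name `ground_truth`, its formula, and the grid range are all assumptions made for illustration; the sinusoidal term is there to create several local minima and maxima.

```python
import math


def ground_truth(x1: float, x2: float) -> float:
    """Hypothetical multi-extremum surface; NOT the function used in this article."""
    ripple = math.sin(x1) * math.cos(x2)                    # creates several local extrema
    bowl = 0.05 * ((x1 - 5.0) ** 2 + (x2 - 5.0) ** 2)       # gentle overall trend
    return ripple + bowl


# Evaluate on a coarse grid over an assumed domain [0, 10] x [0, 10]
# and locate the grid minimum and maximum.
points = [(i * 0.5, j * 0.5) for i in range(21) for j in range(21)]
values = [ground_truth(x1, x2) for x1, x2 in points]
x_min = points[values.index(min(values))]
x_max = points[values.index(max(values))]
print("grid minimum at", x_min, "grid maximum at", x_max)
```

In a real study this function is unknown; it only exists here so that the optimization loop has something to query.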
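Next, we plot the function. The plot shows that the mapping function has multiple extrema over the whole domain. The minimum and maximum points are marked with a five-pointed star and a dot, respectively. (Note that this plot is a God's-eye view. In a real problem we cannot see, and do not know, the true functional relationship; all we have are data points scattered through the space.) The plot shows light-blue areas in the upper-left region and several purple areas in the lower-right region, which indicates that the left side contains local minima. The minimum point lies near the lower-left edge.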
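To make the optimization harder, we now pose the following optimization problem:
- Find the minimum point and its corresponding input X.
- The measured data are noisy relative to the objective function.
- The stopping condition is that the Euclidean distance between the current input X_i and the minimizing input X_min is less than 0.2.
- The minimum point lies outside the parameter range of the initial data.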
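The stopping condition in the list above can be sketched as a small helper. The function name `should_stop` and the example coordinates are hypothetical; only the 0.2 threshold comes from the text.

```python
import math


def should_stop(x_current, x_min, tol=0.2):
    """Return True when x_current is within Euclidean distance tol of x_min."""
    return math.dist(x_current, x_min) < tol


# Hypothetical coordinates, for illustration only.
print(should_stop((1.0, 1.0), (1.1, 1.1)))  # distance is about 0.141
print(should_stop((1.0, 1.0), (1.5, 1.0)))  # distance is 0.5
```

Note that this check needs the true X_min, so it is only usable in a benchmark like this one where the ground truth is known.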
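Test Data
Now, back to the real-world view. In a real problem, we only have some initial data points scattered through the space. The initial dataset is generated randomly; the construction code and the spatial distribution of the data points are shown below.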
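The original construction code is not included in this text. As a rough sketch of random initial data, assuming 100 points drawn uniformly from a restricted sub-range and an 80/20 train/test split (the range `LOW`, `HIGH`, the split ratio, and the seed are all assumptions):

```python
import random

rng = random.Random(0)          # fixed seed so the sketch is reproducible
LOW, HIGH = 2.0, 8.0            # hypothetical sampling range for both inputs

# Draw 100 random (x1, x2) points inside the restricted range.
samples = [(rng.uniform(LOW, HIGH), rng.uniform(LOW, HIGH)) for _ in range(100)]

# Shuffle and split into training and test sets (80/20 is an assumption).
rng.shuffle(samples)
train, test = samples[:80], samples[80:]
print(len(train), len(test))
```

Restricting the sampling range is one way to reproduce the condition that the minimum lies outside the initial data.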
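uni-active Demonstration
So how do we use uni-active for active learning optimization? In short, we build a model, predict the mean and variance at sample points across the space, and pick the more promising samples for the next iteration. This improves the samples while also improving the model. The approach is known as active learning with expected improvement.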
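One common way to obtain a mean and a spread for each candidate is to refit the model on resampled copies of the training data and collect the predictions. The sketch below assumes bootstrap resampling and uses a tiny nearest-neighbour predictor as a stand-in model; neither is confirmed to be what uni-active does internally, and all names are hypothetical.

```python
import math
import random
import statistics


def knn_predict(train, x, k=3):
    """Average the targets of the k training points nearest to x."""
    nearest = sorted(train, key=lambda row: math.dist(row[0], x))[:k]
    return sum(y for _, y in nearest) / len(nearest)


def resampled_mean_std(train, candidates, n_resamples=20, seed=0):
    """Predict each candidate with models built on bootstrap resamples."""
    rng = random.Random(seed)
    per_candidate = [[] for _ in candidates]
    for _ in range(n_resamples):
        boot = [rng.choice(train) for _ in train]   # sample with replacement
        for preds, x in zip(per_candidate, candidates):
            preds.append(knn_predict(boot, x))
    return [(statistics.fmean(p), statistics.pstdev(p)) for p in per_candidate]


# Toy data: y = x1 + x2 observed on a 5 x 5 grid.
train = [((float(i), float(j)), float(i + j)) for i in range(5) for j in range(5)]
stats = resampled_mean_std(train, [(2.0, 2.0), (10.0, 10.0)])
for mean, std in stats:
    print(round(mean, 3), round(std, 3))
```

The mean and standard deviation per candidate are the two inputs that a selection criterion such as expected improvement needs.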
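First, we walk through the first loop.
- Do some simple data processing
- Select and tune the model
- Build the Dataset
- Register the sklearn model
- Define the search grid
- Pass the dataset and model into the loop function ActiveLoopLocalTable
- Run the model
- Save the results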
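For the "define the search grid" step, a plain-Python sketch of a two-dimensional grid is shown below. The helper `make_grid` is hypothetical and is not the uni-active API; the start value 0.0 and step 0.1 are assumptions, chosen so that 100 values per axis give 10,000 candidate points with two columns each.

```python
import itertools


def make_grid(start, step, n):
    """Cartesian product of one axis with itself: n * n candidate points."""
    axis = [round(start + i * step, 10) for i in range(n)]  # round away float drift
    return list(itertools.product(axis, axis))


grid = make_grid(0.0, 0.1, 100)
print(len(grid), len(grid[0]))  # 10000 2
```

Every point of this grid is a candidate input whose predicted mean and spread can then be scored.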
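The code below uses expected improvement (EI) as the criterion for selecting candidates. For a maximization problem, for example, the criterion is:

$$EI(x) = \sigma(x)\,\bigl[z\,\Phi(z) + \phi(z)\bigr], \qquad z = \frac{\mu(x) - \mu^*}{\sigma(x)}$$

where $\mu$ is the mean of the predicted target over the search grid, $\sigma$ is the spread of those predictions (their standard deviation, the "std" reported in the log), $\mu^*$ is the maximum target value in the training set, and $\Phi$ and $\phi$ are the standard normal cumulative distribution function and density function.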
(D.Z. Xue, P.V. Balachandran, J. Hogden, J. Theiler, D.Q. Xue, T. Lookman, Accelerated search for materials with targeted properties by adaptive design, Nat. Commun. 7 (2016) 11241)
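The EI criterion described above can be sketched in a few lines of standard-library Python. This is a generic implementation of the textbook formula, not the uni-active source code; the function names and the `maximize` switch are assumptions. For minimization the sign of the improvement is flipped and `best` is the smallest training target.

```python
import math


def normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def expected_improvement(mu, sigma, best, maximize=True):
    """EI = sigma * (z * Phi(z) + phi(z)), with z = improvement / sigma."""
    improvement = (mu - best) if maximize else (best - mu)
    if sigma <= 0.0:
        return max(improvement, 0.0)    # no uncertainty: plain improvement
    z = improvement / sigma
    return sigma * (z * normal_cdf(z) + normal_pdf(z))


def top_k_by_ei(means, stds, best, k=3, maximize=True):
    """Indices of the k candidates with the largest EI."""
    scores = [expected_improvement(m, s, best, maximize) for m, s in zip(means, stds)]
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]


# Three candidates with the same predicted mean but different spread:
# the most uncertain candidate ranks first.
print(top_k_by_ei([0.0, 0.0, 0.0], [0.1, 1.0, 0.5], best=0.0, k=2))  # [1, 2]
```

Selecting the top 3 indices, as the log reports, corresponds to calling `top_k_by_ei` with `k=3`.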
INFO -> >>>> START NEXT LOOP >>>>
INFO -> Make grid with shape (10000, 2).
INFO -> Start training model with resample method (20 times) and get predict value.
Train 0.8711990874526474 Test 0.7879031177427134
100%|██████████| 20/20 [00:00<00:00, 29.63it/s]
INFO -> Calculate mean and std (20 times) of predict value to evaluate robustness.
INFO -> Calculate Expectation Improvement (EI) and select top 3 index.
INFO -> Save new samples x to 'train-2'.
INFO -> Save model to pickle file: smp1/model_pkl/o19_cirr.pth.
INFO -> Save model info in json: smp1/model.json, with key: 'tmp_model-1'.
INFO -> Save samples to: smp1/samples.xlsx.
INFO -> Now, Please check the 'train-1' in smp1/samples.xlsx, and make experimental evaluation.
INFO -> After the fill missing value, start the next cycle.
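After the first loop finishes, three files/folders have been created in the target folder:
- An xlsx file, which stores the data.
- A model.json file, which stores the model information.
- A model_pkl folder, which stores the serialized models.
We load the xlsx file to inspect the data. After the first loop, the xlsx file contains four sheets.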
dict_keys(['train-0', 'train-1', 'test-0', 'targets'])
| | 0 | 1 | 2 |
|---|---|---|---|
| 0 | 3.5 | 9.5 | 0 |
| 1 | 3.5 | 9.6 | 0 |
| 2 | 3.5 | 9.4 | 0 |
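Here, 'targets' stores the index of the target column, 'train-0' is the initial training data, and 'test-0' is the initial test data. These three are straightforward. 'train-1' holds the newly added samples awaiting observation. When it is first generated, its target column is filled uniformly with 0 or inf, so this column normally has to be edited by hand to contain the observed values.
This is why the log above prompts you with "INFO -> After the fill missing value, start the next cycle."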
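The fill step can be pictured with plain Python lists standing in for the spreadsheet rows. This is only an illustration of the idea: the helper `fill_targets` and the stand-in `fake_measurement` are hypothetical, and in practice the values come from an experiment and are typed into the sheet.

```python
def fill_targets(rows, measure):
    """Return new rows whose last column holds the measured target value."""
    return [[*row[:-1], measure(*row[:-1])] for row in rows]


def fake_measurement(x1, x2):
    """Hypothetical stand-in for a real experimental measurement."""
    return x1 + x2


# The three pending rows shown in the table above: x1, x2, placeholder target.
pending = [[3.5, 9.5, 0], [3.5, 9.6, 0], [3.5, 9.4, 0]]
filled = fill_targets(pending, fake_measurement)
print(filled[0])  # [3.5, 9.5, 13.0]
```

The original rows are left unchanged, so the placeholder version remains available for comparison.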
Once the sheet has been updated, the second iteration can begin, as shown below:
INFO -> Load samples from: smp1/samples.xlsx.
INFO -> Load model from pickle file: smp1/model_pkl/o19_cirr.pth.
INFO -> Load model info from json: smp1/model.json, with key 'tmp_model-1'.
INFO -> >>>> START NEXT LOOP >>>>
INFO -> Make grid with shape (10000, 2).
INFO -> Start training model with resample method (20 times) and get predict value.
100%|██████████| 20/20 [00:00<00:00, 27.54it/s]
INFO -> Calculate mean and std (20 times) of predict value to evaluate robustness.
INFO -> Calculate Expectation Improvement (EI) and select top 3 index.
INFO -> Save new samples x to 'train-3'.
INFO -> Save model to pickle file: smp1/model_pkl/_l1ub2yb.pth.
INFO -> Save model info in json: smp1/model.json, with key: 'tmp_model-2'.
INFO -> Save samples to: smp1/samples.xlsx.
INFO -> Now, Please check the 'train-2' in smp1/samples.xlsx, and make experimental evaluation.
INFO -> After the fill missing value, start the next cycle.
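Doesn't the code feel noticeably simpler? In the first round, most of the coding effort goes into splitting the data and registering the model. The second round skips the Dataset and model registration steps, because every iteration after the first loads the results saved by the previous round directly from local storage. The corresponding difference in detail: the first round uses the from_code function, which builds from code, while the second round uses the from_path function, which builds from local files.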
dict_keys(['train-0', 'train-1', 'train-2', 'test-0', 'targets'])
| | 0 | 1 | 2 |
|---|---|---|---|
| 0 | 8.4 | 9.9 | -1067.454484 |
| 1 | 8.3 | 9.9 | -1093.503284 |
| 2 | 8.5 | 9.9 | -1035.216066 |
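In the newly generated file, 'train-2' holds the newly added samples awaiting observation, so its target column must be edited by hand to contain the observed values. Later loops follow the same pattern and are repeated until the requirement is met.
uni-active Active Learning Optimization
OK, we now understand how uni-active is used. Next, we run the complete uni-active optimization for the problem in this article.
Because giving feedback by editing the spreadsheet is inefficient, we call the func_noise function directly to produce the results, standing in for manual feedback.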
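The body of func_noise is not shown in this text. A minimal sketch of the idea, a ground-truth function plus Gaussian noise, might look like the following; `make_noisy_oracle`, `toy_truth`, and the noise level are hypothetical and chosen only for illustration.

```python
import random


def make_noisy_oracle(func, noise_std=0.1, seed=0):
    """Wrap func so each call returns its value plus Gaussian noise."""
    rng = random.Random(seed)

    def oracle(x1, x2):
        return func(x1, x2) + rng.gauss(0.0, noise_std)

    return oracle


def toy_truth(x1, x2):
    """Hypothetical noise-free objective, not the article's function."""
    return (x1 - 3.0) ** 2 + (x2 - 4.0) ** 2


measure = make_noisy_oracle(toy_truth, noise_std=0.1, seed=42)
print(measure(3.0, 4.0))  # close to 0, but perturbed by noise
```

Such an oracle plays the role of the experimenter: each loop passes the newly proposed inputs to it and writes the returned values into the target column.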
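This part is basic machine learning modeling, so we do not go into detail here. Typically, users can test models on their data beforehand in a separate .py file and settle on a model they find satisfactory.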
Train 0.8711984617884104 Test 0.7879056701931623
The main code for the whole optimization process is shown below.
INFO -> >>>> START NEXT LOOP >>>>
INFO -> Make grid with shape (10000, 2).
INFO -> Start training model with resample method (20 times) and get predict value.
100%|██████████| 20/20 [00:00<00:00, 29.94it/s]
INFO -> Calculate mean and std (20 times) of predict value to evaluate robustness.
INFO -> Calculate Expectation Improvement (EI) and select top 3 index.
INFO -> Save new samples x to 'train-2'.
INFO -> Save model to pickle file: smp1/model_pkl/b26d3wph.pth.
INFO -> Save model info in json: smp1/model.json, with key: 'tmp_model-1'.
INFO -> Save samples to: smp1/samples.xlsx.
INFO -> Now, Please check the 'train-1' in smp1/samples.xlsx, and make experimental evaluation.
INFO -> After the fill missing value, start the next cycle.
INFO -> Load samples from: smp1/samples.xlsx.
INFO -> Load model from pickle file: smp1/model_pkl/_l1ub2yb.pth.
INFO -> Load model info from json: smp1/model.json, with key 'tmp_model-2'.
INFO -> >>>> START NEXT LOOP >>>>
INFO -> Make grid with shape (10000, 2).
INFO -> Start training model with resample method (20 times) and get predict value.
100%|██████████| 20/20 [00:00<00:00, 28.74it/s]
INFO -> Calculate mean and std (20 times) of predict value to evaluate robustness.
INFO -> Calculate Expectation Improvement (EI) and select top 3 index.
INFO -> Save new samples x to 'train-3'.
INFO -> Save model to pickle file: smp1/model_pkl/a2vn2_rm.pth.
INFO -> Save model info in json: smp1/model.json, with key: 'tmp_model-3'.
INFO -> Save samples to: smp1/samples.xlsx.
INFO -> Now, Please check the 'train-2' in smp1/samples.xlsx, and make experimental evaluation.
INFO -> After the fill missing value, start the next cycle.
INFO -> Load samples from: smp1/samples.xlsx.
INFO -> Load model from pickle file: smp1/model_pkl/a2vn2_rm.pth.
INFO -> Load model info from json: smp1/model.json, with key 'tmp_model-3'.
INFO -> >>>> START NEXT LOOP >>>>
INFO -> Make grid with shape (10000, 2).
INFO -> Start training model with resample method (20 times) and get predict value.
100%|██████████| 20/20 [00:00<00:00, 27.94it/s]
INFO -> Calculate mean and std (20 times) of predict value to evaluate robustness.
INFO -> Calculate Expectation Improvement (EI) and select top 3 index.
INFO -> Save new samples x to 'train-4'.
INFO -> Save model to pickle file: smp1/model_pkl/iy1w0kit.pth.
INFO -> Save model info in json: smp1/model.json, with key: 'tmp_model-4'.
INFO -> Save samples to: smp1/samples.xlsx.
INFO -> Now, Please check the 'train-3' in smp1/samples.xlsx, and make experimental evaluation.
INFO -> After the fill missing value, start the next cycle.
End: loop=2
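After running the code above, the optimization stops after three loops, loop = (0, 1, 2), and each step produces a detailed input/output log. The optimization process is visualized below.
In the figure above, 1, 2, and 3 indicate the batch to which each sampled point belongs.
Consider: why does batch 1 in the figure above appear to the right of the extremum? Because of how the initial samples are distributed, the lower-right corner has no data coverage, so the model knows nothing about that part of the domain. Predictions in this triangular region therefore have a large variance (uncertainty), and the sampling algorithm picks this area because it has more exploration value. Once the points from batches 1 and 2 were added to the model, it learned that the data on that side do not decrease steadily in the horizontal direction but instead contain a trough, so from batch 3 onward sampling moves to the central region and finds the minimum point.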
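The preference for the uncovered region follows from how EI depends on the predictive spread. As a short worked derivation using the standard EI algebra (the symbols are defined here for self-containment: $\mu$ is the predicted mean, $\sigma$ the predicted standard deviation, $\mu^*$ the best training value, and $\Phi$, $\phi$ the standard normal CDF and PDF):

$$
\begin{aligned}
EI &= (\mu-\mu^*)\,\Phi(z) + \sigma\,\phi(z), \qquad z = \frac{\mu-\mu^*}{\sigma} \\
\frac{\partial EI}{\partial \sigma} &= (\mu-\mu^*)\,\phi(z)\,\frac{\partial z}{\partial \sigma} + \phi(z) + \sigma\,\phi'(z)\,\frac{\partial z}{\partial \sigma} \\
&= -z^2\phi(z) + \phi(z) + z^2\phi(z) = \phi(z) > 0
\end{aligned}
$$

using $\partial z/\partial\sigma = -z/\sigma$ and $\phi'(z) = -z\,\phi(z)$. The same result holds for minimization with $z = (\mu^*-\mu)/\sigma$. At a fixed predicted mean, EI therefore always grows with the uncertainty, which is why regions without data attract the first batches.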
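To summarize the problem and the result: starting from a model trained and tested on 100 data samples, three iterations that each add three new samples, 109 samples in total, locate the minimum point of the current domain.
Summary:
For composition optimization and process optimization problems, finding an input X better than the current one is enough to meet the requirement. The algorithm above therefore performs well in practical applications.
Admittedly, as the number of data points decreases, the current sampling algorithm can get stuck in a local optimum. It also tends to explore within known regions rather than venturing into unknown space. Adding algorithms such as simulated annealing might solve these problems, but would also slow convergence. We hope that adding more sampling algorithms in the future will address the tendency to explore known regions rather than new ones.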