AI+电芯 | 基于Enhanced Gaussian Process Dynamical Model的电池老化曲线预测
背景
预测电动汽车电池的使用寿命或剩余使用寿命是一个关键且具有挑战性的问题,近年来主要使用机器学习来预测重复循环期间健康状态(SOH)的演变。 为了提高预测估计的准确性,特别是在电池寿命的早期,许多算法结合了电池管理系统收集的数据中可用的特征。 除非使用多个电池数据集来直接预测寿命终止(这对于大致估计很有用),否则这种方法是不可行的,因为这些特征对于未来的循环是未知的。 在本文中,作者开发了一种高精度方法,通过使用 改进的高斯过程动态模型(GPDM) 来克服这一限制。 作者引入了 GPDM 的内核化版本,以在可观察坐标和潜在坐标之间提供更具表现力的协方差结构。 我们将该方法与迁移学习相结合,以跟踪未来的健康状况直至生命终结。 该方法可以将特征合并为不同的物理可观测值,而不要求它们的值超出数据可用的时间。 迁移学习用于使用来自类似电池的数据来改进超参数的学习。 该方法相对于当下的baseline算法(包括高斯过程模型以及深度卷积和循环网络)的准确性和优越性在三个数据集上得到了证明。
文章使用了NASA和OXford的公开数据集对模型进行训练和预测。这两个数据集都记录了锂离子电池的充电和放电性能,适合评估健康状态预测算法。
本Notebook搬运自github/PericlesHat/enhanced-GPDM,引自用文章Enhanced Gaussian Process Dynamical Models with Knowledge Transfer for Long-term Battery Degradation Forecasting
作为一个简单案例,本notebook只使用了NASA数据集中的三组数据做训练和预测作为一个展示。模型使用放电终点的电压、温度等数据,训练和预测当前循环下的SOH情况。绘图使用SOH-cycles(电池循环圈数)作为YX轴,反应电池健康状态随循环圈数的降低,对一定循环后的曲线进行预测。
加载必须的工具包:
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.8/site-packages (1.3.0) Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn) (3.1.0) Requirement already satisfied: joblib>=1.1.1 in /opt/conda/lib/python3.8/site-packages (from scikit-learn) (1.2.0) Requirement already satisfied: scipy>=1.5.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn) (1.7.3) Requirement already satisfied: numpy>=1.17.3 in /opt/conda/lib/python3.8/site-packages (from scikit-learn) (1.22.4) WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Requirement already satisfied: numpy==1.22.4 in /opt/conda/lib/python3.8/site-packages (1.22.4) WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Requirement already satisfied: pandas in /opt/conda/lib/python3.8/site-packages (1.5.3) Requirement already satisfied: torch in /opt/conda/lib/python3.8/site-packages (1.13.1+cu116) Requirement already satisfied: python-dateutil>=2.8.1 in /opt/conda/lib/python3.8/site-packages (from pandas) (2.8.2) Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.8/site-packages (from pandas) (2022.7) Requirement already satisfied: numpy>=1.20.3 in /opt/conda/lib/python3.8/site-packages (from pandas) (1.22.4) Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torch) (4.5.0) Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.8/site-packages (from python-dateutil>=2.8.1->pandas) (1.16.0) WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
设置超参,因为是展示案例,这里的超参设置也很简单:
加载NASA数据集,其中的5、6两组和7组的一部分作为训练数据,7组剩下的部分作为测试数据。请注意,这里对数据做了归一化处理,因此在后面的作图过程中,SOH的值会出现接近0的情况。
这里我们从model文件中引入了GPDM。这部分的代码位于[GPDM_model](https://bohrium.dp.tech/notebook/f7534f1beba3447185eb334926e27f95 #EGPDM_model.ipynb),请移步查询。
这里我们对GPDM模型做个简单介绍:
高斯过程动力学模型(GPDM)主要用于分析潜在变量(或低维嵌入)的动态,它包含从潜在空间到观察空间的非线性概率映射,以及潜在空间中的动态模型 。 可以表示为如下图所示的模型:
A和B代表basis函数的权重。
初始化EGPDM模型。
Num. of sequences = 1 [Data points = 168] Num. of sequences = 2 [Data points = 336] Num. of sequences = 3 [Data points = 504]
[array([[-2.89608849e-01, 2.02888185e-01, 2.77061292e-01], [-3.30741253e-01, 2.49343582e-01, 2.07089152e-01], [-3.77686797e-01, 2.25663076e-01, 1.55672353e-01], [-4.02081003e-01, 1.79565028e-01, 1.57284940e-01], [-4.13059198e-01, 1.31063198e-01, 1.67740050e-01], [-3.95991532e-01, 1.54441535e-01, 1.73788824e-01], [-3.95418046e-01, 1.61584036e-01, 1.71390526e-01], [-4.95798370e-01, 1.51874691e-01, 7.84441647e-02], [-5.02091869e-01, 1.22509236e-01, 8.44765631e-02], [-5.00002820e-01, 9.64273046e-02, 9.66809144e-02], [-5.09939147e-01, 4.94973002e-02, 1.09188236e-01], [-5.47780118e-01, 1.32990006e-01, 2.31722767e-02], [-5.50737475e-01, 1.06939030e-01, 3.16879636e-02], [-5.48188886e-01, 5.60147483e-02, 5.27281537e-02], [-5.94958279e-01, 2.39048040e-02, 3.34205153e-03], [-6.04021224e-01, -1.74464119e-02, 1.26783303e-02], [-6.01757079e-01, -1.70372909e-02, 1.64416843e-02], [-5.96644553e-01, -1.14247824e-02, 2.03547145e-02], [-5.86405948e-01, -9.75758058e-03, 2.71287110e-02], [-5.07359145e-01, 1.61520982e-01, 1.16624417e-01], [-4.98380488e-01, 2.44104764e-01, 9.53115349e-02], [-5.53386094e-01, 1.26450785e-01, 7.23006934e-02], [-6.09537784e-01, 2.15090248e-02, 4.63350657e-02], [-5.83465569e-01, -1.60979817e-02, 7.91124203e-02], [-6.32150453e-01, -5.43945712e-02, 6.21100302e-02], [-6.55145282e-01, -7.81973817e-02, 2.46607774e-02], [-6.42223196e-01, -1.16917070e-01, 5.28234598e-02], [-6.26680126e-01, -1.34108853e-01, 6.96966020e-02], [-6.42680064e-01, -1.21241356e-01, 2.44918795e-02], [-6.26549417e-01, -1.26816982e-01, 4.30324663e-02], [-6.70815383e-01, -4.77006555e-02, 8.56789203e-02], [-6.66311631e-01, 2.33864844e-02, 1.46430853e-02], [-6.80350366e-01, -6.25581574e-02, 1.22349473e-02], [-6.90366542e-01, -1.96756260e-01, 3.16293620e-02], [-6.83770000e-01, -2.27566364e-01, 3.80999927e-02], [-6.32412110e-01, -6.06895495e-02, -6.51172615e-04], [-6.44837078e-01, -1.64466593e-01, 3.53543129e-03], [-6.38365062e-01, -2.12169122e-01, 1.50130583e-02], [-6.33150502e-01, -2.58144542e-01, 1.26859185e-02], [-5.99282459e-01, -2.18979676e-01, 2.39746670e-02], [-5.50192448e-01, -2.93075468e-02, -2.57258967e-02], [-5.25301849e-01, 2.14738978e-02, -3.99092407e-02], [-5.48852074e-01, -1.19490721e-01, 7.82525888e-03], [-5.60998375e-01, -1.51682911e-01, 1.03909527e-03], [-5.64689056e-01, -2.71891451e-01, 1.91485399e-02], [-5.67273481e-01, -3.18251684e-01, 1.20108089e-02], [-5.36711337e-01, -3.04044670e-01, 1.58708198e-02], [-3.79558870e-01, 6.61948809e-01, -1.06794456e-01], [-4.80740097e-01, 3.78815354e-01, -9.70430451e-02], [-5.14840881e-01, 1.18492688e-01, -5.86866724e-02], [-5.12180988e-01, -1.21175649e-02, -2.96641756e-02], [-5.07756160e-01, -9.48950111e-02, -1.86506057e-02], [-4.80223035e-01, -9.44076547e-02, -9.73915711e-03], [-4.27248825e-01, 2.72015752e-02, -3.01085756e-02], [-4.23221740e-01, -5.61512456e-02, -1.91355202e-02], [-4.18498348e-01, -1.15334501e-01, -1.78119912e-02], [-4.08750622e-01, -1.54877306e-01, -7.36040783e-03], [-3.64022216e-01, -6.53445749e-02, -1.98540887e-02], [-3.27413242e-01, 3.67808253e-02, -4.55969476e-02], [-3.18185840e-01, -4.12003137e-02, -2.17967845e-02], [-3.10286727e-01, -5.34958613e-02, -3.40916550e-02], [-3.09776142e-01, -1.34690541e-01, -2.84627873e-02], [-2.61580699e-01, -5.31977886e-02, -2.25928817e-02], [-2.65489249e-01, -1.03661492e-01, -3.27600344e-02], [-2.31422672e-01, -5.11684361e-04, -5.81384252e-02], [-1.95769912e-01, -4.50374086e-02, -2.62979733e-02], [-1.91414319e-01, -8.57434579e-02, -3.47798684e-02], [-1.62782368e-01, -8.98190926e-02, -2.28020759e-02], [-1.43560734e-01, -6.90514593e-02, -2.94116581e-02], [-1.17270849e-01, -6.00366169e-02, -2.46304850e-02], [-9.32957868e-02, -7.00167415e-02, -1.60847636e-02], [-8.81543724e-02, -7.86036191e-02, -3.64482756e-02], [-5.71703849e-02, -7.81801738e-02, -2.44150109e-02], [-2.99381716e-02, -8.16398493e-02, -1.42857098e-02], [-3.09544431e-02, -9.88343295e-02, -3.61302356e-02], [ 2.41649088e-03, -1.07913999e-01, -1.89393675e-02], [ 1.27828822e-02, -1.69170061e-01, 9.71022573e-03], [ 4.72471616e-03, 1.45782520e-02, -4.04924055e-02], [ 2.04310937e-02, -1.62518000e-01, -1.13444385e-02], [ 3.29270344e-02, -2.51131527e-01, 8.41096772e-03], [ 6.75168892e-02, -2.10346628e-01, 5.77721892e-03], [ 1.27736776e-01, -5.93322985e-02, -8.11339511e-03], [ 1.55818302e-01, 2.27666467e-02, -3.06413544e-02], [ 1.80462849e-01, -1.86948395e-02, -1.04993762e-02], [ 1.72418095e-01, -9.66209783e-02, -1.19053429e-02], [ 1.77188574e-01, -1.85735247e-01, 4.32930460e-04], [ 2.28688563e-01, -8.80069797e-02, 2.88581537e-03], [ 2.44917740e-01, 1.74055306e-02, -4.10795082e-02], [ 2.62529312e-01, 2.79952717e-02, -4.43130450e-02], [ 3.98987220e-02, 1.20370758e-02, 2.14036953e-03], [ 1.54287338e-01, 1.48586346e-01, -5.63820505e-02], [ 2.06155518e-01, 1.37470656e-01, -5.21229495e-02], [ 2.09039064e-01, 7.37464514e-02, -6.53184770e-02], [ 2.46865070e-01, 1.09318400e-02, -2.49361903e-02], [ 2.53161307e-01, -1.58187728e-02, -3.49421140e-02], [ 2.83643002e-01, -2.04706961e-03, -2.94251223e-02], [ 3.04746782e-01, 8.95384726e-03, -3.16623945e-02], [ 3.33693768e-01, 1.72310119e-02, -2.50681098e-02], [ 3.36169762e-01, 1.90959755e-02, -5.11527175e-02], [ 3.59690442e-01, 3.25163243e-02, -5.13811824e-02], [ 3.81876102e-01, 3.10748404e-02, -4.74427786e-02], [ 3.96937248e-01, 2.49705767e-02, -4.64115787e-02], [ 4.19464858e-01, 4.00885928e-02, -9.28366077e-03], [ 3.98364343e-01, 2.49806974e-01, -8.00239424e-02], [ 4.48950249e-01, 1.92708054e-01, -5.77276688e-02], [ 4.58554002e-01, 1.49281665e-01, -6.14196867e-02], [ 4.92295407e-01, 6.64929266e-02, -4.15909593e-02], [ 4.91575198e-01, 8.41770552e-02, -5.14045603e-02], [ 5.36129256e-01, 1.69947679e-01, -4.74123949e-02], [ 5.64563468e-01, 2.13039024e-01, -5.82394772e-02], [ 5.50488098e-01, 1.31495950e-01, -6.40499829e-02], [ 5.61187940e-01, 8.61766011e-02, -5.14048300e-02], [ 6.08821162e-01, 1.70709302e-01, -4.79597086e-02], [ 6.30394056e-01, 2.51891012e-01, -7.72659510e-02], [ 6.33797841e-01, 2.07034959e-01, -6.99081726e-02], [ 6.47357321e-01, 1.43057809e-01, -4.85528763e-02], [ 6.55411418e-01, 8.73885241e-02, -3.33382056e-02], [ 6.91159355e-01, 1.64980567e-01, -3.56854476e-02], [ 7.10546988e-01, 2.86172251e-01, -8.14511370e-02], [ 6.30698051e-01, 2.16736143e-01, -4.87294927e-02], [ 6.12785409e-01, 2.92489084e-01, -7.89523867e-02], [ 6.41165192e-01, 1.90134788e-01, -7.16669215e-02], [ 6.52938294e-01, 9.58862653e-02, -5.27130919e-02], [ 6.76814885e-01, 6.38790752e-02, -3.66129265e-02], [ 7.02818350e-01, 6.59524066e-02, -3.01008980e-02], [ 7.12625625e-01, 4.61841121e-02, -2.73676243e-02], [ 7.39347709e-01, 4.10810897e-02, -1.88469124e-02], [ 7.49815914e-01, 1.04108633e-02, -1.45705267e-02], [ 7.51234285e-01, -8.62183367e-03, -1.90905593e-02], [ 7.87837395e-01, -1.85226466e-03, -5.59724940e-03], [ 8.07138692e-01, 1.08127201e-02, 3.70564491e-03], [ 8.16379763e-01, -6.50704861e-03, 1.87611201e-03], [ 8.12424371e-01, 1.59367013e-01, -3.74403695e-02], [ 7.60744208e-01, 2.19264902e-01, -7.20977598e-02], [ 8.06000785e-01, 1.67487995e-01, -6.02716565e-02], [ 8.51807848e-01, 1.01832030e-01, -1.35221733e-02], [ 8.34228805e-01, 2.53061522e-02, -2.42959634e-02], [ 8.75603131e-01, 1.39169298e-01, -3.60485800e-02], [ 9.10748686e-01, 1.94537672e-01, -3.07759617e-02], [ 9.09173650e-01, 1.16266109e-01, -1.54418038e-02], [ 9.16141244e-01, 2.72387453e-02, 1.10302157e-02], [ 9.29888732e-01, -5.87852606e-03, 1.98723078e-02], [ 9.55017584e-01, 9.38281153e-02, 2.30305830e-04], [ 9.52356364e-01, 1.01553464e-01, -1.71469357e-02], [ 9.68398400e-01, 2.57844134e-02, 9.61058537e-03], [ 9.66039158e-01, -5.06319122e-02, 2.28966904e-02], [ 9.68575743e-01, -5.81944887e-02, 1.39227772e-02], [ 9.95083819e-01, 5.49210229e-02, -8.80564069e-03], [ 1.01073959e+00, 3.16752135e-02, 1.27985075e-02], [ 9.80031975e-01, -9.41965190e-02, 5.21026895e-02], [ 8.64335584e-01, 6.19931924e-02, 1.84764686e-05], [ 9.21304889e-01, 6.06248342e-02, -1.00634795e-02], [ 9.35237670e-01, 1.37539342e-02, -9.20385522e-03], [ 9.55556446e-01, -8.32322462e-03, -2.87159757e-04], [ 9.78487214e-01, -2.63303013e-02, 1.13149734e-02], [ 9.89318292e-01, -2.80887667e-02, 5.69012083e-03], [ 1.00951056e+00, -3.45166777e-02, 2.37694713e-02], [ 1.02135097e+00, -4.10491711e-02, 2.02977273e-02], [ 1.02473753e+00, -4.57283153e-02, 1.25066154e-02], [ 1.04952330e+00, -3.02834065e-02, 2.61011876e-02], [ 1.07176699e+00, -2.29238959e-02, 4.13667122e-02], [ 1.07405368e+00, 1.26606290e-04, 1.89390321e-02], [ 1.09651902e+00, -2.76985915e-02, 4.65603092e-02], [ 1.08905149e+00, -2.67328682e-02, 2.89087467e-02], [ 1.09440254e+00, -8.43712064e-03, 1.10685548e-02], [ 1.11409562e+00, 5.59589810e-03, 1.88488122e-02], [ 1.03205286e+00, 9.91225825e-02, -1.73972970e-02], [ 9.95836205e-01, 1.84666763e-01, -3.79355135e-02]]), array([[-5.21072131e-01, 2.98161001e-01, 6.99970021e-02], [-5.29762098e-01, 3.32845377e-01, 3.72091896e-02], [-5.48457172e-01, 3.16201342e-01, 1.26745976e-02], [-5.67027224e-01, 2.90024556e-01, 9.37740060e-03], [-5.55415221e-01, 2.51205058e-01, 1.44821345e-02], [-5.96383698e-01, 2.69686609e-01, -2.55658319e-03], [-5.84247200e-01, 2.83760610e-01, 1.00139450e-04], [-5.46109782e-01, 2.64306408e-01, -3.00195455e-02], [-5.56048383e-01, 2.41434541e-01, -2.97402700e-02], [-5.36492479e-01, 2.19565107e-01, -2.25131520e-02], [-5.51090112e-01, 1.80572198e-01, -3.59249466e-02], [-5.22614536e-01, 2.58530632e-01, -6.12334536e-02], [-5.20452069e-01, 2.42386160e-01, -7.05596577e-02], [-5.14603330e-01, 1.99364632e-01, -6.64865781e-02], [-5.05447002e-01, 1.57518634e-01, -5.93374853e-02], [-4.96609891e-01, 1.20167458e-01, -5.60888294e-02], [-5.00508376e-01, 1.25453921e-01, -7.78897001e-02], [-4.79094598e-01, 1.28010834e-01, -7.88340901e-02], [-4.81122461e-01, 1.35585960e-01, -8.40240321e-02], [-6.07256021e-01, 8.74763024e-02, 9.15504101e-03], [-5.74823291e-01, 1.82842053e-01, -3.66077178e-02], [-5.88112087e-01, 7.44636168e-02, -2.21539047e-02], [-6.04505649e-01, -2.87894734e-02, -2.67065795e-02], [-6.40778566e-01, -4.58437825e-02, -1.49714467e-02], [-6.02375926e-01, -9.08258590e-02, -3.59820622e-02], [-5.65535214e-01, -1.27575651e-01, -2.78530003e-02], [-5.73750896e-01, -1.55594389e-01, -3.92171373e-02], [-5.59461888e-01, -1.69041659e-01, -4.19920471e-02], [-5.49938680e-01, -1.70608708e-01, -5.09600517e-02], [-5.54165711e-01, -1.50011039e-01, -4.38750237e-02], [-5.90840191e-01, -1.14574078e-01, 1.38505125e-02], [-5.48632526e-01, -6.12884755e-02, -3.80676359e-02], [-5.42847266e-01, -1.60196654e-01, -3.56838011e-02], [-5.43455540e-01, -2.50144562e-01, -2.56980471e-02], [-5.23858982e-01, -3.02920592e-01, -2.18463508e-02], [-4.76157141e-01, -1.61747958e-01, -4.88024650e-02], [-4.68544082e-01, -2.25133061e-01, -4.36957251e-02], [-4.56621017e-01, -2.82484721e-01, -3.70293239e-02], [-4.60737115e-01, -3.60852937e-01, -2.51593925e-02], [-4.25120625e-01, -2.93301371e-01, -4.09944621e-02], [-3.56072395e-01, -1.05578124e-01, -7.79124728e-02], [-3.18434855e-01, -6.23284177e-02, -8.33427704e-02], [-3.33452107e-01, -1.49065477e-01, -3.97762364e-02], [-3.89240377e-01, -2.01555087e-01, -4.29804572e-02], [-3.61408828e-01, -2.94846897e-01, -3.53835133e-02], [-3.46274624e-01, -3.50533337e-01, -3.37131097e-02], [-3.05932140e-01, -3.07180053e-01, -3.78625363e-02], [-2.28054765e-01, 5.41366335e-01, -1.24987221e-01], [-3.02016508e-01, 2.78779082e-01, -1.01133829e-01], [-3.25859828e-01, 3.42691602e-02, -7.33659095e-02], [-2.96794688e-01, -4.98802317e-02, -6.64242353e-02], [-2.97422338e-01, -1.28974734e-01, -6.05116210e-02], [-2.55359928e-01, -1.12198351e-01, -6.80238161e-02], [-2.04818316e-01, -2.00338963e-02, -8.23815685e-02], [-1.94362565e-01, -8.92544523e-02, -7.20077674e-02], [-1.84713574e-01, -1.46055440e-01, -5.96439303e-02], [-1.55148930e-01, -1.57536299e-01, -5.01324158e-02], [-1.12960635e-01, -1.03361164e-01, -5.49113170e-02], [-7.56107733e-02, -2.76171410e-02, -7.35186936e-02], [-4.99700835e-02, -4.09685725e-02, -6.66113455e-02], [-3.08504713e-02, -8.42219964e-02, -6.68706336e-02], [-9.00875997e-03, -1.30515388e-01, -4.19472641e-02], [ 2.64127882e-02, -5.21288744e-02, -5.37133017e-02], [ 1.47042411e-02, -1.76514253e-01, -3.19702314e-02], [ 3.42722632e-02, -9.97886944e-02, -5.40442729e-02], [ 7.21405325e-02, -1.46950924e-01, -3.30284269e-02], [ 8.28602127e-02, -1.80366369e-01, -2.17901591e-02], [ 1.04346370e-01, -1.79256800e-01, -2.24649363e-02], [ 1.38784948e-01, -1.56387259e-01, -2.21709029e-02], [ 1.70855318e-01, -1.58378113e-01, -2.12061416e-02], [ 1.72157393e-01, -1.53380778e-01, -1.61313644e-02], [ 1.94006750e-01, -1.56105992e-01, -8.00414972e-03], [ 2.09214116e-01, -1.53886973e-01, -1.35974225e-02], [ 2.30429152e-01, -1.54889190e-01, -6.07189892e-03], [ 2.48322333e-01, -1.69394891e-01, 3.90711262e-03], [ 2.63431921e-01, -1.83778385e-01, 4.15771415e-03], [ 3.02741820e-01, -1.70307083e-01, 1.79360611e-02], [ 2.64336693e-01, 3.85676953e-02, -2.42846544e-02], [ 2.78982251e-01, -1.33670331e-01, 1.13666977e-02], [ 3.03807018e-01, -2.11049539e-01, 3.50659524e-02], [ 3.29717163e-01, -1.76702994e-01, 2.41182693e-02], [ 3.78694189e-01, -2.73765593e-02, 2.56258916e-03], [ 4.07761598e-01, 4.60504869e-02, -1.24639248e-02], [ 4.27068085e-01, 1.90701598e-02, 2.76020338e-03], [ 4.39737602e-01, -5.93438347e-02, 1.77503837e-02], [ 4.38406216e-01, -1.27312384e-01, 4.17575006e-02], [ 4.61870955e-01, -4.46942065e-02, 1.97791227e-02], [ 4.94679371e-01, 4.88527248e-02, 7.41079441e-03], [ 5.07328730e-01, 6.21479630e-02, 3.29507808e-03], [ 2.42009557e-01, 7.38653774e-02, 3.42526639e-02], [ 2.77817036e-01, 1.24054182e-01, -2.55423562e-02], [ 3.39986595e-01, 1.00485741e-01, -1.95076406e-02], [ 3.66360679e-01, 3.74415555e-02, -1.00354751e-03], [ 3.83288996e-01, -1.79765333e-02, 7.75743609e-03], [ 4.10667226e-01, -4.52804083e-02, 2.20106080e-02], [ 4.34648029e-01, -2.78188011e-02, 1.68448408e-02], [ 4.65589446e-01, -1.67374752e-02, 1.84137792e-02], [ 4.82616670e-01, -1.65280090e-02, 2.27933931e-02], [ 4.99387663e-01, -1.27377479e-02, 2.45661147e-02], [ 5.24642761e-01, -1.30166769e-03, 2.30314465e-02], [ 5.40221476e-01, -6.00337166e-04, 2.61211882e-02], [ 5.49103368e-01, -4.96291638e-04, 2.37256451e-02], [ 5.87049539e-01, 7.96971377e-02, 2.06159818e-02], [ 5.31989651e-01, 2.39800286e-01, -3.10483134e-02], [ 5.74870771e-01, 1.89471128e-01, -1.31327731e-02], [ 6.06478597e-01, 1.41402515e-01, 3.04310461e-03], [ 6.03367173e-01, 7.93537526e-02, 1.55154714e-02], [ 6.21527383e-01, 8.46706611e-02, 1.86219399e-02], [ 6.51415120e-01, 1.60783354e-01, -3.22146234e-03], [ 6.72695221e-01, 2.03727880e-01, -5.53256923e-03], [ 6.81517395e-01, 1.38517131e-01, 1.72258442e-02], [ 6.84850623e-01, 9.34604590e-02, 2.80765006e-02], [ 7.21662702e-01, 1.58585499e-01, 1.51104960e-02], [ 7.34957659e-01, 2.39929244e-01, 6.93612455e-04], [ 7.37097001e-01, 1.98325648e-01, 1.02932755e-02], [ 7.39201442e-01, 1.36303648e-01, 2.65698102e-02], [ 7.40215858e-01, 8.32152201e-02, 3.95935136e-02], [ 7.57171783e-01, 1.54914121e-01, 1.69213362e-02], [ 7.97445826e-01, 2.63540027e-01, -2.00343186e-03], [ 7.15814556e-01, 2.13632429e-01, 5.77257723e-03], [ 6.78106931e-01, 2.99950473e-01, -2.21349436e-02], [ 7.19180647e-01, 1.93342527e-01, 7.33639960e-03], [ 7.19984243e-01, 9.74496906e-02, 2.76734322e-02], [ 7.46157567e-01, 7.48817243e-02, 3.82798385e-02], [ 7.49255413e-01, 5.96287968e-02, 3.86598844e-02], [ 7.64212452e-01, 5.58303554e-02, 4.28974632e-02], [ 7.80995896e-01, 4.35557467e-02, 4.39225437e-02], [ 7.86182414e-01, 6.57724535e-03, 5.27105221e-02], [ 7.90001732e-01, -1.37420886e-02, 5.49467192e-02], [ 8.08439727e-01, -4.90807743e-03, 5.82180507e-02], [ 8.09538699e-01, 3.19963108e-04, 5.59687767e-02], [ 8.13081708e-01, -2.04104875e-02, 5.77501294e-02], [ 8.26469108e-01, 8.76963270e-02, 3.53829221e-02], [ 7.44715018e-01, 8.81362733e-02, 1.75221438e-02], [ 7.64976016e-01, 1.02803764e-02, 2.86565667e-02], [ 7.95534188e-01, -5.29099186e-02, 5.02899423e-02], [ 7.83922977e-01, -1.40885248e-01, 6.72124632e-02], [ 8.26975502e-01, -3.43304610e-02, 5.11589164e-02], [ 8.39210852e-01, 2.86451617e-02, 3.55754685e-02], [ 8.46397073e-01, -4.71824397e-02, 5.33939061e-02], [ 8.39239808e-01, -1.43969665e-01, 6.90316618e-02], [ 8.34315555e-01, -1.83284561e-01, 7.94639183e-02], [ 8.61528375e-01, -9.49573699e-02, 5.83225065e-02], [ 8.67137660e-01, -8.07760155e-02, 5.69806520e-02], [ 8.65764438e-01, -1.70462375e-01, 7.42047424e-02], [ 8.57350763e-01, -2.47213734e-01, 8.80167692e-02], [ 8.77582454e-01, -2.59139743e-01, 9.98749861e-02], [ 8.96855364e-01, -1.54686843e-01, 7.46566107e-02], [ 8.92773267e-01, -1.75773225e-01, 7.17798199e-02], [ 8.82583679e-01, -2.15877059e-01, 8.75468201e-02], [ 8.31741855e-01, 1.82420743e-02, 1.90888903e-02], [ 8.97905555e-01, 4.41679460e-03, 3.16476223e-02], [ 9.10050814e-01, -5.77472228e-02, 4.01233683e-02], [ 9.37678733e-01, -8.69890508e-02, 4.71955324e-02], [ 9.55249724e-01, -1.03545452e-01, 5.84772164e-02], [ 9.54501790e-01, -1.21661836e-01, 5.61035979e-02], [ 9.76706544e-01, -1.21106733e-01, 5.62482829e-02], [ 9.88353915e-01, -1.33569969e-01, 6.08805199e-02], [ 9.98977740e-01, -1.39531705e-01, 6.34524052e-02], [ 1.00824342e+00, -1.35945101e-01, 5.34441912e-02], [ 1.01938269e+00, -1.23350740e-01, 4.85228425e-02], [ 1.03119576e+00, -1.15396496e-01, 4.56923307e-02], [ 1.04160132e+00, -1.45770827e-01, 5.69389803e-02], [ 1.06303130e+00, -1.69517094e-01, 5.06653011e-02], [ 1.05731280e+00, -1.32402373e-01, 4.75605606e-02], [ 1.08356286e+00, -1.18506472e-01, 5.31800008e-02], [ 1.01725688e+00, -6.35738728e-02, 1.24126766e-02], [ 1.00314213e+00, 3.21343585e-02, -1.68221221e-02]]), array([[-0.37397884, 0.24810242, 0.19301591], [-0.40920718, 0.27623097, 0.14429174], [-0.44443552, 0.30435951, 0.09556757], [-0.45308185, 0.30093416, 0.09110453], [-0.46140401, 0.29712401, 0.08718127], [-0.46593849, 0.27868263, 0.09169958], [-0.47037945, 0.25987998, 0.09642632], [-0.47449475, 0.2422494 , 0.09854635], [-0.47859783, 0.22466277, 0.10056864], [-0.50652135, 0.23453898, 0.07704477], [-0.53565097, 0.24580571, 0.05222743], [-0.5400111 , 0.25485697, 0.04533758], [-0.54278344, 0.26376622, 0.03959693], [-0.5457769 , 0.26353971, 0.03993749], [-0.54878758, 0.26260132, 0.04075191], [-0.5414516 , 0.25191942, 0.03657912], [-0.53316263, 0.24034009, 0.03194698], [-0.53126485, 0.22538052, 0.03830483], [-0.5300488 , 0.21006039, 0.04583494], [-0.5279155 , 0.19137718, 0.05471327], [-0.52567065, 0.17228496, 0.06375557], [-0.52209053, 0.20658641, 0.04127928], [-0.51832751, 0.24820206, 0.01448537], [-0.52453597, 0.24623348, 0.00813902], [-0.53226785, 0.23760619, 0.00491659], [-0.52748778, 0.21404437, 0.01739313], [-0.52059301, 0.18795841, 0.032523 ], [-0.54639911, 0.16998589, 0.02008784], [-0.57827823, 0.15352014, 0.00253346], [-0.58032467, 0.13757918, 0.00650651], [-0.5763181 , 0.1217447 , 0.01484745], [-0.56209352, 0.11419281, 0.01676941], [-0.54561497, 0.10846795, 0.01727543], [-0.55715151, 0.11443684, 0.00379124], [-0.57537821, 0.12319826, -0.01303388], [-0.57277095, 0.12273011, -0.01088334], [-0.56479736, 0.11988465, -0.00384511], [-0.56139216, 0.16039353, 0.01729237], [-0.55925205, 0.21290822, 0.04233426], [-0.55323617, 0.263885 , 0.03834022], [-0.54606968, 0.31440522, 0.02572613], [-0.57073713, 0.2800327 , 0.01962721], [-0.60551059, 0.21871011, 0.01559658], [-0.60829962, 0.15620067, 0.02762379], [-0.60025521, 0.09328922, 0.04508994], [-0.59271721, 0.05915699, 0.05960281], [-0.58536185, 0.03540416, 0.07305055], [-0.60038113, 0.02356641, 0.06589925], [-0.62397736, 0.01629612, 0.05085164], [-0.61840608, -0.00184396, 0.05184987], [-0.60097005, -0.02440564, 0.05937522], [-0.61020159, -0.04048794, 0.05408203], [-0.63092778, -0.0537774 , 0.04326361], [-0.63114618, -0.06363332, 0.04611143], [-0.62201017, -0.07192307, 0.05519296], [-0.60702377, -0.07510871, 0.05774637], [-0.58921666, -0.07583346, 0.05715228], [-0.59186115, -0.07379942, 0.05007473], [-0.60491738, -0.0703609 , 0.03969651], [-0.62592499, -0.05628077, 0.04783834], [-0.65120279, -0.0364857 , 0.06592613], [-0.64950931, -0.00518522, 0.05459069], [-0.63254908, 0.03262775, 0.0266006 ], [-0.62764434, 0.01719263, 0.01718833], [-0.62992652, -0.02998654, 0.0188513 ], [-0.63254137, -0.09685804, 0.03254606], [-0.63536494, -0.1760855 , 0.05379018], [-0.63642493, -0.21687818, 0.05750138], [-0.63632094, -0.23230389, 0.04964086], [-0.61801563, -0.17619098, 0.03658485], [-0.58708082, -0.07043901, 0.0199238 ], [-0.57444559, -0.06043899, 0.02145703], [-0.57515378, -0.12025809, 0.0362569 ], [-0.58083405, -0.16473299, 0.03759803], [-0.59032271, -0.19745489, 0.02863033], [-0.59002697, -0.22523778, 0.02952582], [-0.58186116, -0.24904799, 0.03835474], [-0.57298168, -0.24207431, 0.03554289], [-0.56349954, -0.20910533, 0.02290106], [-0.54510535, -0.14869676, 0.01039257], [-0.51881184, -0.06396675, -0.00199772], [-0.49924996, -0.00553603, -0.01348407], [-0.48595009, 0.02843021, -0.02412952], [-0.48323534, 0.01143135, -0.0152471 ], [-0.4908537 , -0.05531917, 0.01269825], [-0.50337586, -0.11385465, 0.02597833], [-0.52092142, -0.16397471, 0.02423544], [-0.52635326, -0.21749415, 0.02963776], [-0.51876285, -0.2746679 , 0.04272117], [-0.51562239, -0.31180228, 0.0452314 ], [-0.5175024 , -0.32632811, 0.03581292], [-0.51161918, -0.33240336, 0.03165203], [-0.49654267, -0.32847135, 0.03371723], [-0.45544491, -0.11003445, -0.00270852], [-0.38199639, 0.3750842 , -0.08698787], [-0.36302006, 0.596191 , -0.13025878], [-0.41516019, 0.47261567, -0.11999089], [-0.45049941, 0.34302362, -0.10456615], [-0.46279737, 0.20518004, -0.08206917], [-0.46864691, 0.09947616, -0.06391165], [-0.46520312, 0.04009125, -0.05200807], [-0.460615 , -0.01225981, -0.04225018], [-0.45429305, -0.05395351, -0.03574334], [-0.44443623, -0.07591484, -0.03124667], [-0.42894573, -0.06642768, -0.0299537 ], [-0.41211876, -0.04122558, -0.03370856], [-0.39305 , 0.01033708, -0.04593072], [-0.38069031, 0.02644698, -0.05237964], [-0.38018333, -0.02007632, -0.04862918], [-0.37249504, -0.05797799, -0.04261707], [-0.35143464, -0.07982561, -0.03239361], [-0.33747573, -0.09897912, -0.02398667], [-0.33746619, -0.11284062, -0.01914792], [-0.33214629, -0.11198456, -0.01793287], [-0.31581232, -0.08060319, -0.02423366], [-0.29955612, -0.04131342, -0.03387797], [-0.28347043, 0.01531395, -0.05085229], [-0.26839043, 0.04366466, -0.06002632], [-0.25564365, 0.00641349, -0.05110364], [-0.24879911, -0.02665055, -0.04584407], [-0.25646424, -0.04942122, -0.04958966], [-0.25727595, -0.07738437, -0.04811243], [-0.2402092 , -0.11889314, -0.03301047], [-0.22594302, -0.13528751, -0.02610473], [-0.21944207, -0.0820465 , -0.04192485], [-0.21027462, -0.0489695 , -0.05150235], [-0.19323461, -0.07542437, -0.04264925], [-0.17971056, -0.08423829, -0.03966161], [-0.17726177, -0.03748317, -0.05515015], [-0.17397476, -0.00774884, -0.06698107], [-0.16786429, -0.03534767, -0.06649159], [-0.16139665, -0.06250794, -0.06455043], [-0.15363923, -0.08808446, -0.05736708], [-0.14567885, -0.10822067, -0.05171746], [-0.13693049, -0.10723561, -0.05202229], [-0.12784285, -0.10453427, -0.05288006], [-0.11733447, -0.09464595, -0.05605324], [-0.10728915, -0.08629363, -0.05880804], [-0.09934308, -0.08490447, -0.05966624], [-0.09079933, -0.08495015, -0.05989352], [-0.07930982, -0.09206754, -0.05701126], [-0.06868611, -0.09941597, -0.05448401], [-0.06272439, -0.10800848, -0.05386833], [-0.05606917, -0.11554108, -0.05304687], [-0.04531067, -0.11680259, -0.05100799], [-0.03460983, -0.11842824, -0.04893086], [-0.02428635, -0.12243742, -0.04660341], [-0.01449227, -0.12716083, -0.04447589], [-0.00856279, -0.1370981 , -0.04380789], [-0.00180824, -0.14741766, -0.04313179], [ 0.01173026, -0.16088041, -0.04238914], [ 0.02384222, -0.17622171, -0.04067632], [ 0.02258039, -0.20917449, -0.02986814], [ 0.02220582, -0.23116651, -0.02190369], [ 0.03146432, -0.13415602, -0.04481407], [ 0.0399267 , -0.0516812 , -0.06430926], [ 0.03817211, -0.15574743, -0.03997603], [ 0.03744312, -0.25550914, -0.01651304], [ 0.0527134 , -0.28812039, -0.00662576], [ 0.06754951, -0.31841642, 0.00240434], [ 0.07381066, -0.30298682, -0.00549495], [ 0.08070719, -0.28538923, -0.01354502], [ 0.10454707, -0.20997852, -0.02561577], [ 0.12795551, -0.13536507, -0.03766027], [ 0.13389015, -0.09304024, -0.04864133], [ 0.13978912, -0.051452 , -0.05946912], [ 0.14276372, -0.07026409, -0.0577289 ], [ 0.14573832, -0.08907617, -0.05598867]])]
训练模型。
*********** TRAIN EGPDM *********** : - latent dimension: 3 - optimization steps: 20 - learning rate: 0.01 - optimizer: LBFGS - device: cuda Epoch:1/20 Running loss: -3.1820e+04 Used time: 2.2178149223327637 Epoch:2/20 Running loss: -1.7786e+05 Used time: 1.0954194068908691 Epoch:3/20 Running loss: -1.8785e+05 Used time: 1.1810853481292725 Epoch:4/20 Running loss: -2.1151e+05 Used time: 1.1811010837554932 Epoch:5/20 Running loss: -2.2016e+05 Used time: 1.0957603454589844 Epoch:6/20 Running loss: -2.3843e+05 Used time: 1.095186710357666 Epoch:7/20 Running loss: -2.5038e+05 Used time: 1.180229663848877 Epoch:8/20 Running loss: -2.7441e+05 Used time: 1.1381020545959473 Epoch:9/20 Running loss: -2.9195e+05 Used time: 1.1817424297332764 Epoch:10/20 Running loss: -3.0991e+05 Used time: 1.1424086093902588 Epoch:11/20 Running loss: -3.2911e+05 Used time: 1.1472959518432617 Epoch:12/20 Running loss: -3.4171e+05 Used time: 1.104860544204712 Epoch:13/20 Running loss: -3.5512e+05 Used time: 1.103034257888794 Epoch:14/20 Running loss: -3.6384e+05 Used time: 1.1456851959228516 Epoch:15/20 Running loss: -3.7283e+05 Used time: 1.1429014205932617 Epoch:16/20 Running loss: -3.7821e+05 Used time: 1.1042749881744385 Epoch:17/20 Running loss: -3.8277e+05 Used time: 1.1435697078704834 Epoch:18/20 Running loss: -3.8652e+05 Used time: 1.1888341903686523 Epoch:19/20 Running loss: -3.9009e+05 Used time: 1.1844356060028076 Epoch:20/20 Running loss: -3.9284e+05 Used time: 1.1018421649932861 Total Training Time: 23.88436007499695 s
处理绘图结果数据。
### START SAMPLING & PREDICTING... ###
绘制预测结果图,从图中我们可以看出,预测的衰退轨迹,即图中橙色线条的部分和grund truth有很好的相似度。
计算数据归一化后的预测精度RMSE。这里的RMSE相较于文章中稍高,这应该是数据量低和超参设置导致的。
### RESULT ### normalized rmse: 0.051306754005528164 Note that normalized rmse is greater than rmse of original data.
chenjm