ANN for EIS prediction
WeChat user DiH0
Updated 2024-12-25
Recommended image: Third-party software:ai4s-cup-0.1
Recommended machine type: c2_m4_cpu
Dataset
AI4S-Cup-PulseEIS(v5)
[90]
# Load the required packages
import torch
import math
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import random
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
[91]
# Fix the random seeds so that training and prediction results are reproducible
def setup_seed(seed):
    np.random.seed(seed)                  # NumPy RNG
    random.seed(seed)                     # Python built-in RNG
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
[92]
# Initialize the random seed
setup_seed(0)
[93]
# Define the multilayer perceptron model
class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.fc1 = nn.Linear(99, 128)     # 99 pulse voltage points -> 128 hidden units
        self.fc2 = nn.Linear(128, 102)    # 128 hidden units -> 51 x 2 EIS values
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 99)                # flatten each pulse into a 99-dim vector
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        x = x.view(-1, 51, 2)             # reshape to (batch, 51 frequencies, Re/Im)
        return x
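A quick shape check confirms the mapping from 99 voltage points to 51 complex impedance points. This is a minimal sketch with a randomly generated dummy batch; the batch size of 4 is an arbitrary choice for illustration.
[ ]
# Sanity-check the input/output shapes with a dummy batch (illustrative only)
dummy = torch.randn(4, 99, 1)             # 4 samples, 99 voltage points each
out = MLPModel()(dummy)
print(out.shape)                          # expected: torch.Size([4, 51, 2])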
[94]
# Instantiate the model
model = MLPModel()
[95]
# Define the input-loading function
def input_mlp(datasets, mode):
    input_data = []
    soc_lst = [f'{i*2}%SOC' for i in range(49)]   # 0%SOC, 2%SOC, ..., 96%SOC
    for i in datasets:
        with open(f'/bohr/ai4spulseeis-lr97/v5/{mode}_datasets/{mode}_pulse_{i}.pkl', 'rb') as fp:
            pulse_data = pickle.load(fp, encoding='bytes')
        for soc in soc_lst:
            Vol = pulse_data[soc]['Voltage']
            tensor_vol = torch.tensor(Vol).view(1, 99, 1)   # 99 voltage points per pulse
            input_data.append(tensor_vol)
    return input_data
[96]
# Prepare the input data (battery indices in the train/test datasets)
train_baty = [1, 2, 3, 4, 5, 6]
test_baty = [1, 2]
input_train = input_mlp(datasets = train_baty, mode = 'train')
input_test = input_mlp(datasets = test_baty, mode = 'test')
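If the data paths resolve, a brief check on the loaded lists catches shape mistakes early. Each battery contributes 49 SOC points, so with 6 training and 2 test batteries the expected list lengths are 294 and 98:
[ ]
# Sanity-check the number and shape of the loaded input tensors
print(len(input_train), len(input_test))   # expected: 294 98
print(input_train[0].shape)                # expected: torch.Size([1, 99, 1])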
[97]
# Define the target-loading function
def target_mlp(datasets, mode):
    soc_lst = [f'{i*2}%SOC' for i in range(49)]
    EIS_list = []
    for k in datasets:
        with open(f'/bohr/ai4spulseeis-lr97/v5/{mode}_datasets/{mode}_eis_{k}.pkl', 'rb') as fp:
            eis_data = pickle.load(fp, encoding='bytes')
        for soc in soc_lst:
            EIS_tot = [[], []]
            EIS_tot[0] = eis_data[soc]['Real']
            EIS_tot[1] = eis_data[soc]['Imaginary']
            EIS_list.append(EIS_tot)
    EIS_list = [np.array(t).squeeze().T for t in EIS_list]                    # -> (51, 2) arrays
    target_data = [torch.tensor(t).float().view(1, 51, 2) for t in EIS_list]
    return target_data
[98]
# Prepare the target data
target_train = target_mlp(datasets = train_baty, mode = 'train')
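The targets are paired with the inputs by position, so their lengths must match. A minimal consistency check:
[ ]
# Inputs and targets must align one-to-one
assert len(target_train) == len(input_train)
print(target_train[0].shape)               # expected: torch.Size([1, 51, 2])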
[99]
# Set the basic training hyperparameters
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=40, verbose=True, min_lr=5e-6)
num_epochs = 5000
batch_size = 8
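As a side note, the manual batch slicing used in the training loop below could equivalently be written with torch.utils.data. A minimal sketch for reference only; it is not used in this notebook, and shuffle=True is an extra choice not present in the original loop:
[ ]
# Optional alternative: wrap the tensors in a DataLoader (sketch, not used below)
from torch.utils.data import TensorDataset, DataLoader
train_ds = TensorDataset(torch.cat(input_train, dim=0), torch.cat(target_train, dim=0))
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)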
[100]
# Examine how the number of training epochs affects the RMSE
train_losses = []   # record the mean training loss of each epoch
for epoch in range(num_epochs):
    epoch_loss = 0.0
    train_size = len(input_train)
    for batch_idx in range(0, train_size, batch_size):
        batch_input = torch.cat(input_train[batch_idx:batch_idx+batch_size], dim=0)
        batch_target = torch.cat(target_train[batch_idx:batch_idx+batch_size], dim=0)
        optimizer.zero_grad()
        outputs = model(batch_input)
        loss = criterion(outputs, batch_target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * batch_input.shape[0]   # accumulate per-sample MSE
    mean_mse = epoch_loss / train_size                     # mean MSE over all samples
    scheduler.step(mean_mse)
    train_losses.append(mean_mse)
    if (epoch + 1) % 50 == 0:
        rmse = mean_mse ** 0.5
        print(f"Epoch [{epoch + 1}/{num_epochs}] | Loss: {mean_mse:.4f} | RMSE: {rmse:.4f}")
Epoch [50/5000] | Loss: 42.0119 | RMSE: 6.4817
Epoch 00062: reducing learning rate of group 0 to 8.0000e-05.
Epoch [100/5000] | Loss: 41.9762 | RMSE: 6.4789
Epoch 00103: reducing learning rate of group 0 to 6.4000e-05.
Epoch 00146: reducing learning rate of group 0 to 5.1200e-05.
Epoch [150/5000] | Loss: 41.7950 | RMSE: 6.4649
...
(loss decreases slowly and monotonically; intermediate epochs omitted)
...
Epoch [4950/5000] | Loss: 40.2593 | RMSE: 6.3450
Epoch [5000/5000] | Loss: 40.2401 | RMSE: 6.3435
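After a run this long it is worth persisting the trained weights so that inference can be rerun without retraining. A minimal sketch; the file name 'mlp_eis.pt' is an arbitrary choice, not part of the original notebook:
[ ]
# Save the trained weights (file name is an example)
torch.save(model.state_dict(), 'mlp_eis.pt')
# To restore later: model.load_state_dict(torch.load('mlp_eis.pt'))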
[101]
# Plot the training loss curve (log scale)
losses_log = [math.log(t) for t in train_losses]
plt.plot(losses_log, label='Training loss')
plt.xlabel('Epochs')
plt.ylabel('log(MSE loss)')
plt.legend()
plt.show()
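Equivalently, matplotlib can apply the log scale directly, which keeps the y-axis in the original MSE units rather than transformed values:
[ ]
# Alternative: log-scaled y-axis without transforming the data
plt.semilogy(train_losses, label='Training loss')
plt.xlabel('Epochs')
plt.ylabel('MSE loss (log scale)')
plt.legend()
plt.show()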
[102]
# Run prediction on the test data
def evaluation():
    model.eval()   # switch to evaluation mode
    outputs_eval = []
    for batch_idx in range(0, len(input_test), batch_size):
        batch_input = torch.cat(input_test[batch_idx:batch_idx+batch_size], dim=0)
        with torch.no_grad():
            output = model(batch_input)
        outputs_eval.append(output)
    return outputs_eval
[103]
outputs_eval = evaluation()
# Concatenate the batch outputs and transpose into two rows: real and imaginary parts
predict_outputs = torch.cat(outputs_eval, dim=0).view(1, len(input_test)*51, 2).tolist()
predict_data = np.array(predict_outputs).squeeze().T   # shape (2, 98*51)
[104]
# Assemble the results for submission
results_data = {}
results_data['test_data_number'] = [1 for i in range(49*51)] + [2 for i in range(49*51)]
results_data['SOC(%)'] = []
for k in range(2):                       # two test batteries
    for i in range(49):                  # 49 SOC levels, 51 frequency points each
        results_data['SOC(%)'] += [i*2 for j in range(51)]
results_data['EIS_real'] = predict_data[0].tolist()
results_data['EIS_imaginary'] = predict_data[1].tolist()
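The dictionary maps directly onto a pandas DataFrame, which makes it easy to write out as CSV. A minimal sketch; the file name 'submission.csv' is an assumption, so check the competition's required output format:
[ ]
# Export the results (file name and exact format are assumptions)
submission = pd.DataFrame(results_data)
submission.to_csv('submission.csv', index=False)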
[105]
# Plot the predicted EIS curves (the 49 SOC levels of the first test battery)
for i in range(49):
    plt.plot(predict_data[0][51*i:51*(i+1)], predict_data[1][51*i:51*(i+1)], 'o-')
plt.title('Predicted EIS (final training RMSE: 6.3435)')
plt.xlabel('Z_re')
plt.ylabel('Z_im')
plt.show()