
Predicting the Aging Trajectory of Lithium Batteries with a Transformer Network

©️ Copyright 2023 @ Authors
Author: 陈乐天 (Letian Chen) 📨
Date: 2024-05-21
License: This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
Quick start: click the Connect button at the top, select the bohrium-notebook:2023-04-07 image and a c16_m16_cpu node configuration, and wait a moment before running. The data and code in this notebook come from the literature; the reference is given at the end of the article.


Main content

Accurately predicting the remaining useful life (RUL) of lithium batteries plays an important role in managing battery health and estimating battery state. With the rapid growth of electric vehicles, the demand for RUL prediction techniques keeps increasing. To predict RUL, this case study designs a Transformer-based neural network.

  • First, battery capacity data usually contain considerable noise, especially around capacity regeneration during charge-discharge cycling. To mitigate this, the raw data are processed with a denoising autoencoder (DAE).
  • Then, to capture temporal information and learn useful features, the reconstructed sequence is fed into a Transformer network.
  • Finally, to unify the denoising and prediction tasks, this case study combines them in a single framework.

Extensive experiments on the NASA dataset and comparisons with several existing methods show that the approach presented here performs better at predicting RUL.

The Transformer is a neural-network architecture mainly used for natural language processing (NLP) tasks such as machine translation and text summarization. Proposed by Vaswani et al. in 2017, it is a deep-learning model built on the self-attention mechanism. The Transformer abandons the traditional recurrent neural network (RNN) and long short-term memory (LSTM) designs and instead uses self-attention to handle long-range dependencies in the input sequence, which gives it higher parallelism and computational efficiency on sequence data. In addition, the Transformer introduces positional encoding to capture the order of tokens in a sequence.
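To make self-attention concrete, here is a minimal, illustrative sketch of scaled dot-product attention in PyTorch. It is not the implementation used later in this notebook, and the dimensions (a length-5 sequence with 8-dimensional features) are arbitrary:

import torch
import torch.nn.functional as F

# Toy input: one sequence of 5 time steps with 8-dimensional features
x = torch.randn(1, 5, 8)                        # (batch, seq_len, d_model)

# Queries, keys, and values are learned linear projections of the input
proj_q = torch.nn.Linear(8, 8)
proj_k = torch.nn.Linear(8, 8)
proj_v = torch.nn.Linear(8, 8)
q, k, v = proj_q(x), proj_k(x), proj_v(x)

# Scaled dot-product attention: softmax(q k^T / sqrt(d_model)) v
scores = q @ k.transpose(-2, -1) / (8 ** 0.5)   # (1, 5, 5) pairwise similarities
weights = F.softmax(scores, dim=-1)             # each row sums to 1
context = weights @ v                           # (1, 5, 8) attention-mixed features
print(context.shape)                            # torch.Size([1, 5, 8])

Because every time step attends to every other step in a single matrix product, long-range dependencies are handled without the step-by-step recursion of an RNN.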

Import the required libraries

[1]
import numpy as np
import random
import math
import os
import scipy.io
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision
import transformers
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm
from math import sqrt
from datetime import datetime
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
  • Show the number of threads on the current machine
[2]
print("Number of threads before setting: ", torch.get_num_threads())
Number of threads before setting:  8
  • Increase the thread count to improve CPU utilization
[3]
torch.set_num_threads(12) # set this to your CPU core count
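If you are unsure how many cores the current node provides, you can query the operating system instead of hard-coding a number. A small sketch (note that os.cpu_count() reports logical cores, which may exceed physical cores):

import os
import torch

# Match the thread count to the logical core count reported by the OS
torch.set_num_threads(os.cpu_count() or 1)
print("Number of threads after setting: ", torch.get_num_threads())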

1. Inspect and load the data

1.1 The NASA dataset

The NASA dataset is available from the NASA Ames Research Center website. It contains records for four different lithium-ion batteries, and each battery's data covers three operation types: charge, discharge, and impedance measurements.

A set of four lithium-ion batteries (#5, 6, 7, and 18) was run through 3 different operational profiles (charge, discharge, and impedance) at room temperature.

  • Charging was carried out in constant-current (CC) mode at 1.5 A until the battery voltage reached 4.2 V, then continued in constant-voltage (CV) mode until the charge current dropped to 20 mA.
  • Discharging was carried out in constant-current (CC) mode at 2 A until the battery voltage fell to 2.7 V, 2.5 V, 2.2 V, and 2.5 V for batteries 5, 6, 7, and 18 respectively.
  • Impedance measurements were taken via electrochemical impedance spectroscopy (EIS) frequency sweeps from 0.1 Hz to 5 kHz. Repeated charge-discharge cycles accelerate battery aging, while the impedance measurements capture how the battery's internal parameters change as aging progresses. The experiments were stopped when a battery reached the end-of-life (EOL) criterion: a 30% fade in rated capacity (from 2 Ahr to 1.4 Ahr). This dataset can be used to predict the remaining charge (for a given discharge cycle) and the remaining useful life (RUL).

The full dataset covers more than thirty batteries; this tutorial uses the data in 1. BatteryAgingARC-FY08Q4/, which contains the following files:

B0005.mat: data for battery #5
B0006.mat: data for battery #6
B0007.mat: data for battery #7
B0018.mat: data for battery #18
(Hidden cell)

1.2 Extract the raw data

[4]
import random
import datetime
import os
from sklearn.metrics import mean_absolute_error, mean_squared_error
from math import sqrt
import matplotlib.pyplot as plt

# Convert a MATLAB date vector to a datetime object
def convert_to_time(hmm):
    year, month, day, hour, minute, second = int(hmm[0]), int(hmm[1]), int(hmm[2]), int(hmm[3]), int(hmm[4]), int(hmm[5])
    return datetime.datetime(year=year, month=month, day=day, hour=hour, minute=minute, second=second)

# Load a .mat data file
def loadMat(matfile):
    data = scipy.io.loadmat(matfile)
    filename = matfile.split("/")[-1].split(".")[0]
    col = data[filename]
    col = col[0][0][0][0]
    size = col.shape[0]

    data = []
    for i in range(size):
        k = list(col[i][3][0].dtype.fields.keys())
        d1, d2 = {}, {}
        if str(col[i][0][0]) != 'impedance':
            for j in range(len(k)):
                t = col[i][3][0][0][j][0]
                l = [t[m] for m in range(len(t))]
                d2[k[j]] = l
        d1['type'], d1['temp'], d1['time'], d1['data'] = str(col[i][0][0]), int(col[i][1][0]), str(convert_to_time(col[i][2][0])), d2
        data.append(d1)

    return data

# Read the battery capacity data ("capacity")
def getBatteryCapacity(Battery):
    cycle, capacity = [], []
    i = 1
    for Bat in Battery:
        if Bat['type'] == 'discharge':
            capacity.append(Bat['data']['Capacity'][0])
            cycle.append(i)
            i += 1
    return [cycle, capacity]

# Read the charge ("charge") or discharge ("discharge") records
def getBatteryValues(Battery, Type='charge'):
    values_data = []
    for Bat in Battery:
        if Bat['type'] == Type:
            values_data.append(Bat['data'])
    return values_data
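As a quick sanity check of what loadMat returns, one can inspect the first parsed record. This is purely illustrative and assumes the dataset path used in the next cell:

# Each record is a dict with 'type' (charge/discharge/impedance), 'temp',
# 'time', and 'data' (a dict of measurement lists for non-impedance records)
records = loadMat('/bohr/data-1a57/v1/NASA/1. BatteryAgingARC-FY08Q4/B0005.mat')
first = records[0]
print(first['type'], first['temp'], first['time'])
print(list(first['data'].keys()))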
[5]
# Main program
Battery_list = ['B0005', 'B0006', 'B0007', 'B0018']
dir_path = '/bohr/data-1a57/v1/NASA/1. BatteryAgingARC-FY08Q4/'

Battery = {}
for name in Battery_list:
    print('Load Dataset ' + name + '.mat ...')
    path = dir_path + name + '.mat'
    data = loadMat(path)
    Battery[name] = {
        'capacity': getBatteryCapacity(data),
        'charge': getBatteryValues(data, 'charge'),
        'discharge': getBatteryValues(data, 'discharge')
    }
Load Dataset B0005.mat ...
Load Dataset B0006.mat ...
Load Dataset B0007.mat ...
Load Dataset B0018.mat ...
[6]
# Plotting helper: only plot the data for cycles 0, 50, and 100
def plotBatteryCycles(Battery, battery_name, cycles=[0, 50, 100]):
    charge_data = Battery[battery_name]['charge']
    discharge_data = Battery[battery_name]['discharge']
    # Plot the charge data: measured current
    plt.figure(figsize=(8, 4))
    for cycle in cycles:
        if cycle < len(charge_data):
            plt.plot(charge_data[cycle]['Time'], charge_data[cycle]['Current_measured'], label=f'Cycle {cycle} Current')
    plt.xlabel('Time (s)')
    plt.ylabel('Current (A)')
    plt.title(f'{battery_name} - Charge Cycles {cycles} - Current')
    plt.legend()
    plt.show()
    # Plot the discharge data: measured voltage
    plt.figure(figsize=(8, 4))
    for cycle in cycles:
        if cycle < len(discharge_data):
            plt.plot(discharge_data[cycle]['Time'], discharge_data[cycle]['Voltage_measured'], label=f'Cycle {cycle} Voltage')
    plt.xlabel('Time (s)')
    plt.ylabel('Voltage (V)')
    plt.title(f'{battery_name} - Discharge Cycles {cycles} - Voltage')
    plt.legend()
    plt.show()

# Choose the battery to plot
battery_name = 'B0005'
plotBatteryCycles(Battery, battery_name)
1.3 Plot the capacity degradation curves

Let us plot the capacity fade curves contained in the raw data.

[7]
fig, ax = plt.subplots(1, figsize=(12, 8))
color_list = ['b:', 'g--', 'r-.', 'c.']
for name, color in zip(Battery_list, color_list):
    df_result = Battery[name]['capacity']
    ax.plot(df_result[0], df_result[1], color, label=name)
ax.set(xlabel='Discharge cycles', ylabel='Capacity (Ah)', title='Capacity degradation at ambient temperature of 24°C')
plt.legend()
<matplotlib.legend.Legend at 0x7f6a27a02e80>
2. Data splitting and processing

2.1 Instance-building function

  • Instance-building function: turns a time series into the feature/target pairs needed for model training.
[8]
def build_instances(sequence, window_size):
    """
    Build training instances.

    Args:
        sequence (list or np.array): time-series data.
        window_size (int): sliding-window size.

    Returns:
        tuple: feature array and target array.
    """
    x, y = [], []
    for i in range(len(sequence) - window_size):
        features = sequence[i:i+window_size]
        target = sequence[i+window_size]
        x.append(features)
        y.append(target)
    return np.array(x).astype(np.float32), np.array(y).astype(np.float32)
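A tiny worked example of how the sliding window turns a sequence into (features, target) pairs (the numbers are illustrative):

# With window_size=3, the sequence [1, 2, 3, 4, 5] yields two instances:
#   features [1, 2, 3] -> target 4
#   features [2, 3, 4] -> target 5
x_demo, y_demo = build_instances([1, 2, 3, 4, 5], window_size=3)
print(x_demo)   # [[1. 2. 3.] [2. 3. 4.]]
print(y_demo)   # [4. 5.]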

2.2 Dataset-splitting function

  • Dataset-splitting function: splits a time series into training and test sets.
[9]
def split_dataset(data_sequence, train_ratio=0.0, capacity_threshold=0.0):
    """
    Split the dataset.

    Args:
        data_sequence (list or np.array): time-series data.
        train_ratio (float): training-set ratio, between 0 and 1.
        capacity_threshold (float): capacity threshold used to locate the split point.

    Returns:
        tuple: training data and test data.
    """
    if capacity_threshold > 0:
        max_capacity = max(data_sequence)
        capacity = max_capacity * capacity_threshold
        point = [i for i in range(len(data_sequence)) if data_sequence[i] < capacity]
    else:
        point = int(train_ratio + 1)
        if 0 < train_ratio <= 1:
            point = int(len(data_sequence) * train_ratio)
    train_data, test_data = data_sequence[:point], data_sequence[point:]
    return train_data, test_data
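For example, a ratio-based split of a 10-point sequence (illustrative numbers):

demo = [2.0, 1.9, 1.8, 1.7, 1.6, 1.5, 1.4, 1.3, 1.2, 1.1]
# train_ratio=0.6 puts int(10 * 0.6) = 6 points in training, 4 in testing
train_part, test_part = split_dataset(demo, train_ratio=0.6)
print(train_part)   # [2.0, 1.9, 1.8, 1.7, 1.6, 1.5]
print(test_part)    # [1.4, 1.3, 1.2, 1.1]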

2.3 Data processing and evaluation functions

  • Data processing and evaluation functions: used for leave-one-out evaluation, computing the relative error, and computing evaluation metrics.

Leave-one-out training

In the get_train_test function:

  1. data_sequence = data_dict[name]['capacity'][1] fetches the capacity data of the current battery (name).
  2. train_data, test_data = data_sequence[:window_size+1], data_sequence[window_size+1:] splits the current battery's data by window size: the first window_size+1 points are used for training and the rest for testing.
  3. for k, v in data_dict.items(): loops over the other batteries' data and uses all of it for training. In other words, each call to get_train_test treats the battery given by name as the test set and all other batteries as the training set.

Concretely:

Battery_list contains the four batteries ['B0005', 'B0006', 'B0007', 'B0018'], so leave-one-out evaluation runs four times, each time holding out a different battery as the test set:

  • Round 1: 'B0005' is the test set; 'B0006', 'B0007', 'B0018' form the training set.
  • Round 2: 'B0006' is the test set; 'B0005', 'B0007', 'B0018' form the training set.
  • Round 3: 'B0007' is the test set; 'B0005', 'B0006', 'B0018' form the training set.
  • Round 4: 'B0018' is the test set; 'B0005', 'B0006', 'B0007' form the training set.
[10]
def get_train_test(data_dict, name, window_size=8):
    """
    Leave-one-out evaluation: hold out one battery and train on the rest.

    Args:
        data_dict (dict): data dictionary.
        name (str): name of the battery held out for testing.
        window_size (int): sliding-window size.

    Returns:
        tuple: training features, training targets, training data, and test data.
    """
    data_sequence = data_dict[name]['capacity'][1]
    train_data, test_data = data_sequence[:window_size+1], data_sequence[window_size+1:]
    train_x, train_y = build_instances(train_data, window_size)
    for k, v in data_dict.items():
        if k != name:
            data_x, data_y = build_instances(v['capacity'][1], window_size)
            train_x, train_y = np.r_[train_x, data_x], np.r_[train_y, data_y]
    return train_x, train_y, list(train_data), list(test_data)

def relative_error(y_test, y_predict, threshold):
    """
    Compute the relative error.

    Args:
        y_test (list or np.array): ground-truth values of the test set.
        y_predict (list or np.array): predicted values.
        threshold (float): capacity threshold.

    Returns:
        float: relative-error score.
    """
    true_re, pred_re = len(y_test), 0
    for i in range(len(y_test)-1):
        if y_test[i] <= threshold >= y_test[i+1]:
            true_re = i - 1
            break
    for i in range(len(y_predict)-1):
        if y_predict[i] <= threshold:
            pred_re = i - 1
            break
    score = abs(true_re - pred_re) / true_re
    if score > 1: score = 1
    return score

def evaluation(y_test, y_predict):
    """
    Compute the evaluation metrics.

    Args:
        y_test (list or np.array): ground-truth values of the test set.
        y_predict (list or np.array): predicted values.

    Returns:
        tuple: root mean squared error (RMSE) and mean absolute error (MAE).
    """
    mse = mean_squared_error(y_test, y_predict)
    rmse = sqrt(mse)
    mae = mean_absolute_error(y_test, y_predict)
    return rmse, mae
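A toy check of the two metrics (illustrative numbers). With a 1.40 threshold, the true curve below crosses the threshold one cycle before the prediction:

y_true = [1.60, 1.55, 1.50, 1.45, 1.40, 1.35, 1.30]
y_pred = [1.62, 1.56, 1.49, 1.44, 1.41, 1.34, 1.29]
rmse_demo, mae_demo = evaluation(y_true, y_pred)
re_demo = relative_error(y_true, y_pred, threshold=1.40)
# The true sequence reaches 1.40 at index 4 (true_re = 3) and the prediction
# first drops below it at index 5 (pred_re = 4), so RE = |3 - 4| / 3 ≈ 0.33
print(f'RMSE={rmse_demo:.4f}, MAE={mae_demo:.4f}, RE={re_demo:.4f}')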

2.4 Helper function

  • Helper function: sets the random seeds so that results are reproducible.
[11]
def setup_seed(seed):
    """
    Set the random seeds.

    Args:
        seed (int): random-seed value.
    """
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

3. Building a GRU model and its performance

3.1 Build the GRU model

[12]
import datetime
# Define the GRU model
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, noise_level=0.01):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.noise_level = noise_level
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Add input noise
        if self.noise_level > 0:
            noise = self.noise_level * torch.randn_like(x)
            x = x + noise
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.gru(x, h0)
        out = self.fc(out[:, -1, :])
        return out


# Re-initialize the model parameters
def reset_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()
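A quick, illustrative shape check of the GRU model on dummy data (a batch of 4 windows, each 16 steps long with one feature per step):

demo_model = GRUModel(input_size=1, hidden_size=32, num_layers=1)
demo_x = torch.randn(4, 16, 1)      # (batch, window_size, input_size)
print(demo_model(demo_x).shape)     # torch.Size([4, 1]): one capacity value per window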
3.2 Define the training function

[13]
# Train and evaluate the GRU model
def train_and_evaluate(data_dict, window_size=8, num_epochs=100, learning_rate=0.001, hidden_size=32, num_layers=1, noise_level=0.01):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    setup_seed(0)

    for name in data_dict.keys():
        train_x, train_y, train_data, test_data = get_train_test(data_dict, name, window_size)
        train_x = torch.from_numpy(train_x).unsqueeze(-1).to(device)
        train_y = torch.from_numpy(train_y).unsqueeze(-1).to(device)

        model = GRUModel(input_size=1, hidden_size=hidden_size, num_layers=num_layers, noise_level=noise_level).to(device)
        model.apply(reset_weights)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Train the model
        model.train()
        for epoch in range(num_epochs):
            outputs = model(train_x)
            optimizer.zero_grad()
            loss = criterion(outputs, train_y)
            loss.backward()
            optimizer.step()

        # Evaluate the model
        model.eval()
        test_x, test_y = build_instances(test_data, window_size)
        test_x = torch.from_numpy(test_x).unsqueeze(-1).to(device)
        test_y = np.array(test_y)
        with torch.no_grad():
            predictions = model(test_x).cpu().numpy()

        rmse, mae = evaluation(test_y, predictions)
        print(f'{name} - RMSE: {rmse:.4f}, MAE: {mae:.4f}')

        # Plot the predicted curve against the measured data
        plt.figure(figsize=(8, 4))
        plt.plot(range(len(test_y)), test_y, label='Actual')
        plt.plot(range(len(predictions)), predictions, label='Predicted')
        plt.xlabel('Time')
        plt.ylabel('Capacity (Ah)')
        plt.title(f'{name} - Actual vs Predicted')
        plt.legend()
        plt.show()
3.3 Train the model and evaluate its performance

[14]
# Train and evaluate the GRU model
train_and_evaluate(Battery, window_size=16, num_epochs=1000, learning_rate=0.01, hidden_size=64, num_layers=2, noise_level=0.01)
B0005 - RMSE: 0.0191, MAE: 0.0133
B0006 - RMSE: 0.0221, MAE: 0.0126
B0007 - RMSE: 0.0243, MAE: 0.0207
B0018 - RMSE: 0.0264, MAE: 0.0141
4. Transformer architecture and performance

4.1 Build the Transformer network model

[15]
class Autoencoder(nn.Module):
    def __init__(self, input_size=16, hidden_dim=8, noise_level=0.01):
        '''
        Args:
            input_size: the feature size of input data (required).
            hidden_dim: the hidden size of AutoEncoder (required).
            noise_level: the noise level added in Autoencoder (optional).
        '''
        super(Autoencoder, self).__init__()
        self.noise_level = noise_level
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_size)

    def encoder(self, x):
        x = self.fc1(x)
        h1 = F.relu(x)
        return h1

    def mask(self, x):
        corrupted_x = x + self.noise_level * torch.randn_like(x)
        return corrupted_x

    def decoder(self, x):
        h2 = self.fc2(x)
        return h2

    def forward(self, x):
        out = self.mask(x)
        encode = self.encoder(out)
        decode = self.decoder(encode)
        return encode, decode

class PositionalEncoding(nn.Module):
    def __init__(self, feature_len, feature_size, dropout=0.0):
        '''
        Args:
            feature_len: the feature length of input data (required).
            feature_size: the feature size of input data (required).
            dropout: the dropout rate (optional).
        '''
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(feature_len, feature_size)
        position = torch.arange(0, feature_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, feature_size, 2).float() * (-math.log(10000.0) / feature_size))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe
        return x

class Net(nn.Module):
    def __init__(self, feature_size=16, hidden_dim=32, feature_num=1, num_layers=1, nhead=1, dropout=0.0, noise_level=0.01):
        '''
        Args:
            feature_size: the feature size of input data (required).
            hidden_dim: the hidden size of Transformer block (required).
            feature_num: the number of features, such as capacity, voltage, and current; set 1 for only a single feature (optional).
            num_layers: the number of layers of Transformer block (optional).
            nhead: the number of heads of multi-attention in Transformer block (optional).
            dropout: the dropout rate of Transformer block (optional).
            noise_level: the noise level added in Autoencoder (optional).
        '''
        super(Net, self).__init__()
        self.auto_hidden = int(feature_size / 2)
        input_size = self.auto_hidden
        if feature_num == 1:
            # When modeling a single feature (e.g. capacity only), the Transformer acts as an encoder
            self.pos = PositionalEncoding(feature_len=feature_num, feature_size=input_size)
            encoder_layers = nn.TransformerEncoderLayer(d_model=input_size, nhead=nhead, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True)
        elif feature_num > 1:
            # When modeling multiple features (capacity, voltage, current, ...), the Transformer acts as a sequence model
            self.pos = PositionalEncoding(feature_len=input_size, feature_size=feature_num)
            encoder_layers = nn.TransformerEncoderLayer(d_model=feature_num, nhead=nhead, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True)
        self.cell = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.linear = nn.Linear(feature_num*self.auto_hidden, 1)
        self.autoencoder = Autoencoder(input_size=feature_size, hidden_dim=self.auto_hidden, noise_level=noise_level)

    def forward(self, x):
        batch_size, feature_num, feature_size = x.shape
        out, decode = self.autoencoder(x)
        if feature_num > 1:
            out = out.reshape(batch_size, -1, feature_num)
        out = self.pos(out)
        out = self.cell(out)               # single feature: (batch_size, feature_num, auto_hidden); multiple features: (batch_size, auto_hidden, feature_num)
        out = out.reshape(batch_size, -1)  # (batch_size, feature_num*auto_hidden)
        out = self.linear(out)             # out shape: (batch_size, 1)
        return out, decode
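A hypothetical shape check of Net on dummy data. With feature_size=16 the autoencoder compresses each window to auto_hidden=8, and for feature_num=1 the Transformer encoder runs over that 8-dimensional encoding:

demo_net = Net(feature_size=16, hidden_dim=32, feature_num=1, num_layers=1, nhead=1)
demo_x = torch.randn(4, 1, 16)              # (batch, feature_num, feature_size)
demo_out, demo_decode = demo_net(demo_x)
print(demo_out.shape, demo_decode.shape)    # torch.Size([4, 1]) torch.Size([4, 1, 16])

Note that the training function below actually instantiates Net with feature_num=K and repeats the input window K times along the feature axis.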
4.2 Define the training function

When tackling time-series forecasting, especially with models such as the Transformer, two strategies often come up: fixed prediction and moving prediction. They differ in how the input data is handled and in the scenarios they target. Both are explained below, followed by a small sketch that contrasts them.

  • Fixed prediction

In the "fixed prediction" strategy, the model predicts one or more future points from a history of fixed length. A window of fixed size, covering one stretch of the series, is defined when the model is trained, and its size never changes. Every prediction is made from the same amount of history.

For example, with a window length of 30 days, the model always takes the previous 30 days as input, whichever day it is predicting. After a prediction, the window does not slide forward over observed data; it always takes the 30 days immediately preceding the next point to be predicted, which in a multi-step rollout include the model's own earlier predictions.

This strategy suits scenarios where the available history has a fixed length and contains enough information to predict future points.

  • Moving prediction

The "moving prediction" strategy uses a window that moves as time goes on. After each prediction, the window slides forward by one or more time steps, so every prediction incorporates the most recent data.

For example, with the same 30-day window, the model first predicts day 31 from days 1 to 30. To predict day 32, the window slides forward and the input becomes days 2 to 31.

This lets the model keep exploiting the latest data, which is well suited to series that change quickly or to models that must be refreshed frequently to track recent trends.

In short:

  • For fixed prediction, a fixed-length slice of history is fed straight into the model.
  • For moving prediction, the input window must be updated after every prediction to include the latest observations.

Which strategy to choose depends on the requirements and on the data. Fixed prediction is simpler and computationally cheaper, whereas moving prediction adapts more flexibly to changing data and may suit fast-changing environments better.
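The loop below sketches the difference between the two strategies with a stub predictor. It is purely illustrative: predict_next stands in for a trained model, and the capacities are made up:

def predict_next(window):
    # Stand-in for a trained model: simply repeat the last value
    return window[-1]

history = [2.00, 1.98, 1.97, 1.95, 1.94]   # known capacities
window_size = 3

# Fixed prediction: each prediction is fed back into the input window
fixed_seq = history.copy()
for _ in range(3):
    fixed_seq.append(predict_next(fixed_seq[-window_size:]))

# Moving prediction: the window slides over the known sequence itself
moving_preds = [predict_next(history[t:t + window_size])
                for t in range(len(history) - window_size)]

print(fixed_seq[len(history):])   # multi-step rollout built from its own outputs
print(moving_preds)               # one-step-ahead predictions over known data

The same pattern appears in the training function below: the fixed branch appends next_point to test_x and keeps predicting from it, while the moving branch reads windows from test_sequence.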

[16]
def train(lr=0.01, feature_size=8, feature_num=1, hidden_dim=32, num_layers=1, nhead=1, dropout=0.0, epochs=1000,
          weight_decay=0.0, seed=0, alpha=0.0, noise_level=0.0, metric='re', device='cpu'):
    '''
    Args:
        lr: learning rate for training (required).
        feature_size: the feature size of input data (required).
        feature_num: the number of features, such as capacity, voltage, and current; set 1 for only a single feature (optional).
        hidden_dim: the hidden size of Transformer block (required).
        num_layers: the number of layers of Transformer block (optional).
        nhead: the number of heads of multi-attention in Transformer block (optional).
        dropout: the dropout rate of Transformer block (optional).
        epochs: the number of training epochs (optional).
        weight_decay: the weight decay of the Adam optimizer (optional).
        seed: the random seed (optional).
        alpha: the weight of the autoencoder reconstruction loss (optional).
        noise_level: the noise level added in Autoencoder (optional).
        metric: the metric used for model selection, 're' or 'rmse'; any other value keeps both (optional).
        device: the device for training (optional).
    '''
    score_list, fixed_result_list, moving_result_list = [], [], []
    setup_seed(seed)
    for i in range(4):
        name = Battery_list[i]
        train_x, train_y, train_data, test_data = get_train_test(Battery, name, feature_size)
        test_sequence = train_data + test_data
        # print('sample size: {}'.format(len(train_x)))

        model = Net(feature_size=feature_size, hidden_dim=hidden_dim, feature_num=K, num_layers=num_layers,
                    nhead=nhead, dropout=dropout, noise_level=noise_level)
        model = model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.MSELoss()

        test_x = train_data.copy()
        loss_list, y_fixed_slice, y_moving_slice = [0], [], []
        rmse, re = 1, 1
        score_, score = [1], [1]
        for epoch in range(epochs):
            x, y = np.reshape(train_x/Rated_Capacity, (-1, feature_num, feature_size)), np.reshape(train_y/Rated_Capacity, (-1, 1))
            x, y = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)
            x = x.repeat(1, K, 1)
            output, decode = model(x)
            output = output.reshape(-1, 1)
            loss = criterion(output, y) + alpha * criterion(decode, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 10 == 0:
                test_x = train_data.copy()
                fixed_point_list, moving_point_list = [], []
                t = 0
                while (len(test_x) - len(train_data)) < len(test_data):
                    x = np.reshape(np.array(test_x[-feature_size:])/Rated_Capacity, (-1, feature_num, feature_size)).astype(np.float32)
                    x = torch.from_numpy(x).to(device)
                    x = x.repeat(1, K, 1)
                    pred, _ = model(x)
                    next_point = pred.data.cpu().numpy()[0, 0] * Rated_Capacity
                    test_x.append(next_point)            # feed the prediction back into the sequence so the next point can be predicted
                    fixed_point_list.append(next_point)  # save the prediction for the last point of the output sequence
                    x = np.reshape(np.array(test_sequence[t:t+feature_size])/Rated_Capacity, (-1, 1, feature_size)).astype(np.float32)
                    x = torch.from_numpy(x).to(device)
                    x = x.repeat(1, K, 1)
                    pred, _ = model(x)
                    next_point = pred.data.cpu().numpy()[0, 0] * Rated_Capacity
                    moving_point_list.append(next_point) # save the prediction for the last point of the output sequence
                    t += 1
                y_fixed_slice.append(fixed_point_list)   # save the predictions
                y_moving_slice.append(moving_point_list)

                loss_list.append(loss)
                rmse = evaluation(y_test=test_data, y_predict=y_fixed_slice[-1])
                re = relative_error(y_test=test_data, y_predict=y_fixed_slice[-1], threshold=Rated_Capacity*0.7)
                # print('epoch:{:<2d} | loss:{:<6.4f} | RMSE:{:<6.4f} | RE:{:<6.4f}'.format(epoch, loss, rmse, re))
                if metric == 're':
                    score = [re]
                elif metric == 'rmse':
                    score = [rmse]
                else:
                    score = [re, rmse]
                if (loss < 1e-3) and (score_[0] < score[0]):
                    break
                score_ = score.copy()
        score_list.append(score_)
        fixed_result_list.append(train_data.copy() + y_fixed_slice[-1])
        moving_result_list.append(train_data.copy() + y_moving_slice[-1])
    return score_list, fixed_result_list, moving_result_list
4.3 Train and inspect the model's performance

Set the training hyperparameters and monitor the model's performance during training.

[17]
# ---------- Time warning: ~7mins ----------
Rated_Capacity = 2.0
feature_size = 16 * 3
feature_num = 1
dropout = 0.0
epochs = 2000
nhead = 1
hidden_dim = 256
num_layers = 1
lr = 0.005
weight_decay = 0.0
noise_level = 0.01
alpha = 0.01
metric = 'rmse'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

K = 16

SCORE = []
for seed in tqdm(range(5)):
    print('seed:{}'.format(seed))
    score_list, _, _ = train(lr=lr, feature_size=feature_size, feature_num=feature_num, hidden_dim=hidden_dim, num_layers=num_layers,
                             nhead=nhead, weight_decay=weight_decay, epochs=epochs, seed=seed, dropout=dropout, alpha=alpha,
                             noise_level=noise_level, metric=metric, device=device)
    print(np.array(score_list))
    print(metric + ' for this seed: {:<6.4f}'.format(np.mean(np.array(score_list))))
    for s in score_list:
        SCORE.append(s)
    print('------------------------------------------------------------------')
print(metric + ' mean: {:<6.4f}'.format(np.mean(np.array(SCORE))))
seed:0
[[[0.02581599 0.01961029]]

 [[0.11350003 0.10178691]]

 [[0.09362362 0.08794732]]

 [[0.03510362 0.02962641]]]
rmse for this seed: 0.0634
------------------------------------------------------------------
seed:1
[[[0.02323237 0.01700326]]

 [[0.11040415 0.09982523]]

 [[0.09569618 0.09030358]]

 [[0.03432236 0.02757088]]]
rmse for this seed: 0.0623
------------------------------------------------------------------
seed:2
[[[0.02547026 0.01917356]]

 [[0.11211098 0.10056742]]

 [[0.08777632 0.08186543]]

 [[0.03331153 0.02759026]]]
rmse for this seed: 0.0610
------------------------------------------------------------------
seed:3
[[[0.02827207 0.01933416]]

 [[0.11172008 0.09856118]]

 [[0.08325723 0.07706927]]

 [[0.03393557 0.0278173 ]]]
rmse for this seed: 0.0600
------------------------------------------------------------------
seed:4
[[[0.02583933 0.01974082]]

 [[0.11073539 0.09894526]]

 [[0.0933731  0.08755475]]

 [[0.03368701 0.02845385]]]
rmse for this seed: 0.0623
------------------------------------------------------------------
rmse mean: 0.0618
Code walkthrough

  • Initialize hyperparameters: set the training hyperparameters such as learning rate, hidden dimension, number of layers, and number of epochs.
  • Repeated runs: train and evaluate several times with different random seeds to assess the robustness of the model.
  • Call the train function: train and evaluate with each random seed and collect the scores.
  • Collect and report results: gather the scores of every run, then compute and print the mean score per seed and across all seeds.

4.4 Model prediction

Run predictions on the test data to show how the model extrapolates to unseen data.

[18]
Rated_Capacity = 2.0
feature_size = 16 * 3
feature_num = 1
dropout = 0.0
epochs = 2000
nhead = 1
hidden_dim = 256
num_layers = 1
lr = 0.005
weight_decay = 0.0
noise_level = 0.01
alpha = 0.01
metric = 'rmse'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
K = 16
seed = 0

SCORE = []
print('seed:{}'.format(seed))
score_list, fixed_result_list, moving_result_list = train(lr=lr, feature_size=feature_size, feature_num=feature_num,
                                                          hidden_dim=hidden_dim, num_layers=num_layers, nhead=nhead,
                                                          weight_decay=weight_decay, epochs=epochs, seed=seed, dropout=dropout,
                                                          alpha=alpha, noise_level=noise_level, metric=metric, device=device)
print(np.array(score_list))
print(metric + ' for this seed: {:<6.4f}'.format(np.mean(np.array(score_list))))
seed:0
[[[0.02581599 0.01961029]]

 [[0.11350003 0.10178691]]

 [[0.09362362 0.08794732]]

 [[0.03510362 0.02962641]]]
rmse for this seed: 0.0634
Code walkthrough

  • Initialize hyperparameters: use the same hyperparameters as in the previous section.
  • Fix the seed: train and evaluate the model with the fixed random seed seed = 0.
  • Call the train function: train and evaluate with this fixed seed and collect the results.
  • Print the results: report the evaluation scores obtained with the fixed seed.

  • Plot the RUL curves
[19]
fig, ax = plt.subplots(2, 2, figsize=(24, 18))

for i in range(2):
    for j in range(2):
        index = i * 2 + j  # index of the battery
        battery_name = Battery_list[index]
        test_data = Battery[battery_name]['capacity'][1]
        fixed_predict_data = fixed_result_list[index]
        moving_predict_data = moving_result_list[index]
        x = [t for t in range(len(test_data))]
        threshold = [Rated_Capacity*0.7] * len(test_data)
        ax[i][j].plot(x, test_data, 'c', label='test data')
        ax[i][j].plot(x, fixed_predict_data, 'b', label='fixed predicted data')
        ax[i][j].plot(x, moving_predict_data, 'r', label='moving predicted data')
        ax[i][j].plot(x, threshold, 'black', ls=':', label='stop line')
        ax[i][j].legend()
        ax[i][j].set_xlabel('Discharge cycles', fontsize=20)
        ax[i][j].set_ylabel('Capacity (Ah)', fontsize=20)
        ax[i][j].set_title('test v.s. prediction of battery ' + battery_name, fontsize=20)
plt.show()

4.5 Hyperparameter search (optional)

Next we run a hyperparameter search to find the combination of hyperparameters that optimizes model performance. Concretely, the code trains the model with different hyperparameter combinations, records the score of each run, and finally picks the best-scoring combination.

The search is fairly time-consuming, so run it at your own discretion.
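As a side note, the five nested loops in the next cell can equivalently be written with itertools.product, which makes the grid explicit and easier to log or resume. A sketch, not a replacement for the cell below:

from itertools import product

search_space = {
    'K': [16, 32, 64, 128],
    'lr': [0.001, 0.01],
    'num_layers': [1, 2],
    'hidden_dim': [16, 32, 64, 128, 256],
    'alpha': [0.0, 1e-4, 1e-3, 1e-2],
}
for K, lr, num_layers, hidden_dim, alpha in product(*search_space.values()):
    # ... call train(...) for each combination, exactly as in the cell below ...
    print(f'K={K}, lr={lr}, num_layers={num_layers}, hidden_dim={hidden_dim}, alpha={alpha}')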

[20]
# Fixed hyperparameters
Rated_Capacity = 2.0
feature_size = 16 * 3
feature_num = 1
dropout = 0.0
epochs = 2000
nhead = 1
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
weight_decay = 0.0
noise_level = 0.05
alpha = 0.0
metric = 'rmse'

# The nested for loops below define the hyperparameter search space
states = {}
for K in tqdm([16, 32, 64, 128], position=0):
    for lr in tqdm([0.001, 0.01], position=1):
        for num_layers in tqdm([1, 2], position=2):
            for hidden_dim in [16, 32, 64, 128, 256]:
                for alpha in [0.0, 1e-4, 1e-3, 1e-2]:
                    show_str = 'K={}, lr={}, num_layers={}, hidden_dim={}, alpha={}'.format(K, lr, num_layers, hidden_dim, alpha)
                    print(show_str)
                    SCORE = []
                    # Train several times and average the score
                    for seed in range(3):
                        print('seed:{}'.format(seed))
                        score_list, _, _ = train(lr=lr, feature_size=feature_size, feature_num=feature_num, hidden_dim=hidden_dim,
                                                 num_layers=num_layers, nhead=nhead, weight_decay=weight_decay, epochs=epochs,
                                                 seed=seed, dropout=dropout, alpha=alpha, noise_level=noise_level, metric=metric, device=device)
                        print(np.array(score_list))
                        print(metric + ': {:<6.4f}'.format(np.mean(np.array(score_list))))
                        print('------------------------------------------------------------------')
                        for s in score_list:
                            SCORE.append(s)
                    # Record the mean score for this hyperparameter combination
                    print(metric + ' mean: {:<6.4f}'.format(np.mean(np.array(SCORE))))
                    states[show_str] = np.mean(np.array(SCORE))
                    print('===================================================================')
# Find the best hyperparameter combination
min_key = min(states, key=states.get)
print('optimal parameters: {}, result: {}'.format(min_key, states[min_key]))
K=16, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.0
seed:0
[[[0.05579332 0.04315475]]

 [[0.05537865 0.04317957]]

 [[0.07546735 0.06422637]]

 [[0.05290992 0.04336866]]]
rmse: 0.0542
------------------------------------------------------------------
seed:1
[[[0.06108889 0.04962513]]

 [[0.14630799 0.12177661]]

 [[0.07382508 0.06543761]]

 [[0.05291493 0.0433477 ]]]
rmse: 0.0768
------------------------------------------------------------------
seed:2
[[[0.03834202 0.02976019]]

 [[0.11673085 0.09145127]]

 [[0.07471715 0.06662953]]

 [[0.05680966 0.04726888]]]
rmse: 0.0652
------------------------------------------------------------------
rmse mean: 0.0654
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.0001
seed:0
[[[0.05574509 0.04311358]]

 [[0.05537706 0.04318089]]

 [[0.07545332 0.06421305]]

 [[0.05289318 0.0434147 ]]]
rmse: 0.0542
------------------------------------------------------------------
seed:1
[[[0.06086272 0.04942671]]

 [[0.14630212 0.12177026]]

 [[0.07393592 0.06554996]]

 [[0.05290475 0.04334061]]]
rmse: 0.0768
------------------------------------------------------------------
seed:2
[[[0.03830367 0.02980537]]

 [[0.11666724 0.09140835]]

 [[0.07469732 0.06660853]]

 [[0.0568037  0.04726652]]]
rmse: 0.0652
------------------------------------------------------------------
rmse mean: 0.0654
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.001
seed:0
[[[0.05529232 0.04273052]]

 [[0.05535384 0.04317414]]

 [[0.09775404 0.0923337 ]]

 [[0.0486982  0.03950913]]]
rmse: 0.0594
------------------------------------------------------------------
seed:1
[[[0.05887469 0.04767337]]

 [[0.08935789 0.07665426]]

 [[0.11695443 0.10975647]]

 [[0.04486671 0.0340712 ]]]
rmse: 0.0723
------------------------------------------------------------------
seed:2
[[[0.03845274 0.02990091]]

 [[0.11607002 0.09101668]]

 [[0.07438442 0.06626677]]

 [[0.0493777  0.04054882]]]
rmse: 0.0633
------------------------------------------------------------------
rmse mean: 0.0650
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.01
seed:0
[[[0.05243552 0.04029038]]

 [[0.13129685 0.12207231]]

 [[0.08075634 0.07377641]]

 [[0.04194189 0.03464847]]]
rmse: 0.0722
------------------------------------------------------------------
seed:1
[[[0.04639063 0.03666479]]

 [[0.09595033 0.08775964]]

 [[0.10594106 0.10167632]]

 [[0.04167628 0.03287641]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.03973843 0.03132249]]

 [[0.11183965 0.08828685]]

 [[0.07570422 0.06684503]]

 [[0.04013189 0.03212339]]]
rmse: 0.0607
------------------------------------------------------------------
rmse mean: 0.0672
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.0
seed:0
[[[0.04570398 0.03386932]]

 [[0.12039536 0.09798004]]

 [[0.10341691 0.09502885]]

 [[0.04746935 0.03816094]]]
rmse: 0.0728
------------------------------------------------------------------
seed:1
[[[0.06050094 0.05062599]]

 [[0.0972739  0.08322365]]

 [[0.10408007 0.09364621]]

 [[0.04814552 0.03988793]]]
rmse: 0.0722
------------------------------------------------------------------
seed:2
[[[0.04146468 0.0314298 ]]

 [[0.123384   0.10860474]]

 [[0.11228388 0.10771195]]

 [[0.04837707 0.03925201]]]
rmse: 0.0766
------------------------------------------------------------------
rmse mean: 0.0738
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.0001
seed:0
[[[0.06054141 0.05068043]]

 [[0.11534359 0.10038401]]

 [[0.11345214 0.10886452]]

 [[0.04597465 0.03469426]]]
rmse: 0.0787
------------------------------------------------------------------
seed:2
[[[0.04140137 0.03138902]]

 [[0.12360062 0.10883281]]

 [[0.11341541 0.10700375]]

 [[0.04678257 0.03692625]]]
rmse: 0.0762
------------------------------------------------------------------
rmse mean: 0.0769
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.01
seed:0
[[[0.03511723 0.02673222]]

 [[0.1274418  0.10370036]]

 [[0.08737518 0.07994064]]

 [[0.04188141 0.03540233]]]
rmse: 0.0672
------------------------------------------------------------------
seed:1
[[[0.05844361 0.04880038]]

 [[0.10531176 0.09531503]]

 [[0.10926066 0.10394125]]

 [[0.0417639  0.03410057]]]
rmse: 0.0746
------------------------------------------------------------------
seed:2
[[[0.03529982 0.02860416]]

 [[0.11015119 0.09272921]]

 [[0.10865732 0.1047193 ]]

 [[0.04211071 0.03545042]]]
rmse: 0.0697
------------------------------------------------------------------
rmse mean: 0.0705
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.0
seed:0
[[[0.04534212 0.03386404]]

 [[0.10791959 0.08572897]]

 [[0.07808951 0.06958386]]

 [[0.04171728 0.03176094]]]
rmse: 0.0618
------------------------------------------------------------------
seed:1
[[[0.04565693 0.03657514]]

 [[0.11563183 0.09775207]]

 [[0.08858664 0.0843206 ]]

 [[0.04798819 0.03947968]]]
rmse: 0.0695
------------------------------------------------------------------
seed:2
[[[0.04102911 0.03282395]]

 [[0.11723924 0.09648926]]

 [[0.07672825 0.06597762]]

 [[0.05263588 0.04189266]]]
rmse: 0.0656
------------------------------------------------------------------
rmse mean: 0.0656
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.0001
seed:0
[[[0.04533991 0.0338627 ]]

 [[0.10783408 0.08559674]]

 [[0.07806664 0.06955769]]

 [[0.04171351 0.03175744]]]
rmse: 0.0617
------------------------------------------------------------------
seed:1
[[[0.04564766 0.03656737]]

 [[0.11562354 0.09774777]]

 [[0.08856247 0.08429537]]

 [[0.04797971 0.03944457]]]
rmse: 0.0695
------------------------------------------------------------------
seed:2
[[[0.04101948 0.03282229]]

 [[0.11722409 0.09648552]]

 [[0.07673157 0.06598853]]

 [[0.05262419 0.04188973]]]
rmse: 0.0656
------------------------------------------------------------------
rmse mean: 0.0656
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.001
seed:0
[[[0.04532497 0.03386072]]

 [[0.1097264  0.0870697 ]]

 [[0.07800225 0.06948435]]

 [[0.04169096 0.0317376 ]]]
rmse: 0.0621
------------------------------------------------------------------
seed:1
[[[0.04554575 0.03646945]]

 [[0.115531   0.0976982 ]]

 [[0.08832138 0.08403683]]

 [[0.04642754 0.03794871]]]
rmse: 0.0690
------------------------------------------------------------------
seed:2
[[[0.04095001 0.03271809]]

 [[0.11981773 0.09929855]]

 [[0.10264766 0.09566698]]

 [[0.03916255 0.03098665]]]
rmse: 0.0702
------------------------------------------------------------------
rmse mean: 0.0671
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.01
seed:0
[[[0.04500372 0.03364014]]

 [[0.08893182 0.07582743]]

 [[0.11320525 0.10759558]]

 [[0.05430474 0.04017501]]]
rmse: 0.0698
------------------------------------------------------------------
seed:1
[[[0.04466283 0.03534823]]

 [[0.09493656 0.08530461]]

 [[0.10234367 0.09556175]]

 [[0.03685407 0.03011271]]]
rmse: 0.0656
------------------------------------------------------------------
seed:2
[[[0.03931845 0.03217196]]

 [[0.12173005 0.09852856]]

 [[0.12051482 0.1163262 ]]

 [[0.04263701 0.03504451]]]
rmse: 0.0758
------------------------------------------------------------------
rmse mean: 0.0704
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.0
seed:0
[[[0.05212979 0.04201493]]

 [[0.09530859 0.08056868]]

 [[0.07597009 0.0661597 ]]

 [[0.04236682 0.03371838]]]
rmse: 0.0610
------------------------------------------------------------------
seed:1
[[[0.04451324 0.03411513]]

 [[0.11550815 0.09959015]]

 [[0.09776229 0.09046725]]

 [[0.0444946  0.03639191]]]
rmse: 0.0704
------------------------------------------------------------------
seed:2
[[[0.04719653 0.03759071]]

 [[0.11632335 0.09532064]]

 [[0.11798803 0.11156472]]

 [[0.03966085 0.0308577 ]]]
rmse: 0.0746
------------------------------------------------------------------
rmse mean: 0.0686
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.0001
seed:0
[[[0.05212189 0.0420022 ]]

 [[0.09531499 0.08057523]]

 [[0.07595878 0.06614809]]

 [[0.04236246 0.03371331]]]
rmse: 0.0610
------------------------------------------------------------------
seed:1
[[[0.04451287 0.03412137]]

 [[0.11546416 0.09955514]]

 [[0.09772985 0.09043223]]

 [[0.04450534 0.03640591]]]
rmse: 0.0703
------------------------------------------------------------------
seed:2
[[[0.04678944 0.03725688]]

 [[0.11632198 0.09532239]]

 [[0.1179869  0.11156077]]

 [[0.0396561  0.03085956]]]
rmse: 0.0745
------------------------------------------------------------------
rmse mean: 0.0686
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.001
seed:0
[[[0.05200423 0.04184149]]

 [[0.09543564 0.08069439]]

 [[0.07589671 0.06607275]]

 [[0.04233232 0.03367831]]]
rmse: 0.0610
------------------------------------------------------------------
seed:1
[[[0.0444775  0.03413049]]

 [[0.11522758 0.09936584]]

 [[0.09747154 0.09017688]]

 [[0.04433844 0.03626341]]]
rmse: 0.0702
------------------------------------------------------------------
seed:2
[[[0.04467206 0.03574815]]

 [[0.11637435 0.09539056]]

 [[0.11986627 0.11499337]]

 [[0.04292482 0.03552765]]]
rmse: 0.0757
------------------------------------------------------------------
rmse mean: 0.0690
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.01
seed:0
[[[0.05065351 0.04122391]]

 [[0.11885741 0.09724139]]

 [[0.09917449 0.09217287]]

 [[0.04011896 0.03379625]]]
rmse: 0.0717
------------------------------------------------------------------
seed:1
[[[0.04464794 0.03435011]]

 [[0.11847323 0.1031361 ]]

 [[0.10504644 0.09926197]]

 [[0.04610906 0.03699228]]]
rmse: 0.0735
------------------------------------------------------------------
seed:2
[[[0.03909549 0.03051162]]

 [[0.12815029 0.11080443]]

 [[0.10116868 0.09206224]]

 [[0.04665912 0.03719984]]]
rmse: 0.0732
------------------------------------------------------------------
rmse mean: 0.0728
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.0
seed:0
[[[0.05391202 0.04345764]]

 [[0.10498295 0.09132787]]

 [[0.10684148 0.10087558]]

 [[0.04801265 0.03696902]]]
rmse: 0.0733
------------------------------------------------------------------
seed:1
[[[0.04192736 0.03322934]]

 [[0.11879614 0.09513601]]

 [[0.11830103 0.11304637]]

 [[0.04462205 0.03680698]]]
rmse: 0.0752
------------------------------------------------------------------
seed:2
[[[0.04856485 0.03899068]]

 [[0.10630403 0.08547351]]

 [[0.11887025 0.1133377 ]]

 [[0.04267112 0.03538458]]]
rmse: 0.0737
------------------------------------------------------------------
rmse mean: 0.0741
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.0001
seed:0
[[[0.05377548 0.04332595]]

 [[0.10499666 0.09134202]]

 [[0.10582882 0.10021984]]

 [[0.04800543 0.03696646]]]
rmse: 0.0731
------------------------------------------------------------------
seed:1
[[[0.04192207 0.03321832]]

 [[0.11878484 0.09513316]]

 [[0.11821773 0.11295857]]

 [[0.04461848 0.03680627]]]
rmse: 0.0752
------------------------------------------------------------------
seed:2
[[[0.04855426 0.03898024]]

 [[0.10519716 0.08448152]]

 [[0.10841588 0.10575112]]

 [[0.04518044 0.0353064 ]]]
rmse: 0.0715
------------------------------------------------------------------
rmse mean: 0.0732
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.001
seed:0
[[[0.05258222 0.04219308]]

 [[0.10505006 0.09141212]]

 [[0.10513123 0.09995746]]

 [[0.04798879 0.03695702]]]
rmse: 0.0727
------------------------------------------------------------------
seed:1
[[[0.04188319 0.03313094]]

 [[0.11867309 0.09508975]]

 [[0.11816546 0.11264889]]

 [[0.04454601 0.03677245]]]
rmse: 0.0751
------------------------------------------------------------------
seed:2
[[[0.04841836 0.03884982]]

 [[0.10619751 0.086536  ]]

 [[0.1192516  0.11037633]]

 [[0.04009299 0.03203733]]]
rmse: 0.0727
------------------------------------------------------------------
rmse mean: 0.0735
===================================================================
K=16, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.01
seed:0
[[[0.03763078 0.02837536]]

 [[0.106052   0.0926899 ]]

 [[0.10215394 0.09876973]]

 [[0.04144    0.03479739]]]
rmse: 0.0677
------------------------------------------------------------------
seed:1
[[[0.03977969 0.0309912 ]]

 [[0.11833074 0.09524053]]

 [[0.11739133 0.11261137]]

 [[0.04970058 0.03963456]]]
rmse: 0.0755
------------------------------------------------------------------
seed:2
[[[0.03421232 0.02682436]]

 [[0.12004388 0.10439169]]

 [[0.0889178  0.07972656]]

 [[0.03862722 0.03122593]]]
rmse: 0.0655
------------------------------------------------------------------
rmse mean: 0.0696
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.0
seed:0
[[[0.05673177 0.04157464]]

 [[0.12559007 0.10151312]]

 [[0.1107996  0.10172874]]

 [[0.04524496 0.03599722]]]
rmse: 0.0774
------------------------------------------------------------------
seed:1
[[[0.05871615 0.04644778]]

 [[0.11642618 0.10027111]]

 [[0.11440261 0.1080416 ]]

 [[0.04826621 0.03917692]]]
rmse: 0.0790
------------------------------------------------------------------
seed:2
[[[0.04191858 0.03396169]]

 [[0.10605103 0.09186653]]

 [[0.12087109 0.11530858]]

 [[0.04394851 0.03679544]]]
rmse: 0.0738
------------------------------------------------------------------
rmse mean: 0.0767
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.0001
seed:0
[[[0.05668564 0.04153814]]

 [[0.11692718 0.09663428]]

 [[0.12094239 0.11635993]]

 [[0.04301932 0.03508179]]]
rmse: 0.0784
------------------------------------------------------------------
seed:1
[[[0.0585751  0.04630955]]

 [[0.11666116 0.10042868]]

 [[0.11440283 0.10804101]]

 [[0.04829926 0.03919853]]]
rmse: 0.0790
------------------------------------------------------------------
seed:2
[[[0.04187298 0.0339219 ]]

 [[0.10606917 0.09189143]]

 [[0.12082612 0.11527656]]

 [[0.04393222 0.03678755]]]
rmse: 0.0738
------------------------------------------------------------------
rmse mean: 0.0771
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.001
seed:0
[[[0.05628807 0.04122052]]

 [[0.11618582 0.09628513]]

 [[0.12092468 0.11634516]]

 [[0.04300218 0.03510858]]]
rmse: 0.0782
------------------------------------------------------------------
seed:1
[[[0.05730631 0.04507395]]

 [[0.11665753 0.10072128]]

 [[0.11441309 0.10804288]]

 [[0.04859621 0.03936328]]]
rmse: 0.0788
------------------------------------------------------------------
seed:2
[[[0.04170427 0.03376439]]

 [[0.10608109 0.09195354]]

 [[0.12051295 0.1150346 ]]

 [[0.04376938 0.03668106]]]
rmse: 0.0737
------------------------------------------------------------------
rmse mean: 0.0769
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.01
seed:0
[[[0.04540047 0.03658669]]

 [[0.11958484 0.11140124]]

 [[0.10643895 0.10107296]]

 [[0.04753242 0.03864825]]]
rmse: 0.0758
------------------------------------------------------------------
seed:1
[[[0.04811769 0.03628626]]

 [[0.12147899 0.11190593]]

 [[0.07333762 0.06502914]]

 [[0.04800285 0.03833371]]]
rmse: 0.0678
------------------------------------------------------------------
seed:2
[[[0.03414814 0.02624882]]

 [[0.1116393  0.0978976 ]]

 [[0.11265252 0.10965097]]

 [[0.03837403 0.03130283]]]
rmse: 0.0702
------------------------------------------------------------------
rmse mean: 0.0713
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.0
seed:0
[[[0.04266422 0.03230684]]

 [[0.10865701 0.09406076]]

 [[0.11849782 0.1136968 ]]

 [[0.03871115 0.03044414]]]
rmse: 0.0724
------------------------------------------------------------------
seed:1
[[[0.0534995  0.04389478]]

 [[0.07072577 0.05687882]]

 [[0.08669062 0.0790183 ]]

 [[0.04879506 0.03903839]]]
rmse: 0.0598
------------------------------------------------------------------
seed:2
[[[0.031701   0.02485635]]

 [[0.06732967 0.0509512 ]]

 [[0.12039325 0.11486324]]

 [[0.0472803  0.03894043]]]
rmse: 0.0620
------------------------------------------------------------------
rmse mean: 0.0647
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.0001
seed:0
[[[0.04263564 0.03230625]]

 [[0.1086197  0.09407699]]

 [[0.11833453 0.11340539]]

 [[0.03868783 0.03042714]]]
rmse: 0.0723
------------------------------------------------------------------
seed:1
[[[0.05346437 0.04386232]]

 [[0.07071158 0.05687685]]

 [[0.08673077 0.07906397]]

 [[0.04876282 0.03901758]]]
rmse: 0.0598
------------------------------------------------------------------
seed:2
[[[0.0316945  0.02484877]]

 [[0.067326   0.05095165]]

 [[0.12021853 0.11475956]]

 [[0.04726709 0.0389287 ]]]
rmse: 0.0620
------------------------------------------------------------------
rmse mean: 0.0647
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.001
seed:0
[[[0.04252445 0.03231103]]

 [[0.12022469 0.09713373]]

 [[0.11891971 0.11477692]]

 [[0.04613985 0.03579704]]]
rmse: 0.0760
------------------------------------------------------------------
seed:1
[[[0.05318751 0.04359475]]

 [[0.10782958 0.09806873]]

 [[0.09483133 0.08944917]]

 [[0.04114511 0.03291902]]]
rmse: 0.0701
------------------------------------------------------------------
seed:2
[[[0.03162328 0.02477932]]

 [[0.06728733 0.05091358]]

 [[0.12089712 0.11576166]]

 [[0.03880954 0.03076825]]]
rmse: 0.0601
------------------------------------------------------------------
rmse mean: 0.0687
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.01
seed:0
[[[0.04319579 0.03581565]]

 [[0.12509669 0.10229312]]

 [[0.11877094 0.11424648]]

 [[0.04163065 0.03158149]]]
rmse: 0.0766
------------------------------------------------------------------
seed:1
[[[0.05015558 0.04066414]]

 [[0.13652842 0.12763038]]

 [[0.08797549 0.081234  ]]

 [[0.04015491 0.03272211]]]
rmse: 0.0746
------------------------------------------------------------------
seed:2
[[[0.03158881 0.02460055]]

 [[0.11508462 0.10492855]]

 [[0.06625308 0.05495791]]

 [[0.03722513 0.0306584 ]]]
rmse: 0.0582
------------------------------------------------------------------
rmse mean: 0.0698
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.0
seed:0
[[[0.03777994 0.02894404]]

 [[0.11353611 0.08952327]]

 [[0.11358654 0.10811452]]

 [[0.04488127 0.03657433]]]
rmse: 0.0716
------------------------------------------------------------------
seed:1
[[[0.04346425 0.03360963]]

 [[0.11156499 0.09492539]]

 [[0.11904327 0.11318134]]

 [[0.04516424 0.03647492]]]
rmse: 0.0747
------------------------------------------------------------------
seed:2
[[[0.05426267 0.04258715]]

 [[0.11472149 0.09422189]]

 [[0.10578323 0.10181898]]

 [[0.04405719 0.0338681 ]]]
rmse: 0.0739
------------------------------------------------------------------
rmse mean: 0.0734
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.0001
seed:0
[[[0.03777268 0.02893953]]

 [[0.1135103  0.0895268 ]]

 [[0.11350492 0.10803096]]

 [[0.0448593  0.03655527]]]
rmse: 0.0716
------------------------------------------------------------------
seed:1
[[[0.04345268 0.03360062]]

 [[0.11157017 0.09492737]]

 [[0.1189939  0.11314414]]

 [[0.04540739 0.03741525]]]
rmse: 0.0748
------------------------------------------------------------------
seed:2
[[[0.0542106  0.04254339]]

 [[0.11470901 0.09419824]]

 [[0.10572559 0.10176627]]

 [[0.04404142 0.03387433]]]
rmse: 0.0739
------------------------------------------------------------------
rmse mean: 0.0734
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.001
seed:0
[[[0.03759829 0.02883432]]

 [[0.11338668 0.08962227]]

 [[0.1125951  0.10708639]]

 [[0.04464774 0.0363523 ]]]
rmse: 0.0713
------------------------------------------------------------------
seed:1
[[[0.04341906 0.03359309]]

 [[0.1115941  0.09493443]]

 [[0.11863551 0.11289522]]

 [[0.04334923 0.0348499 ]]]
rmse: 0.0742
------------------------------------------------------------------
seed:2
[[[0.0540868  0.04240035]]

 [[0.11468903 0.09411378]]

 [[0.10537891 0.10140933]]

 [[0.04300747 0.03363154]]]
rmse: 0.0736
------------------------------------------------------------------
rmse mean: 0.0730
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.01
seed:0
[[[0.04935352 0.03849522]]

 [[0.1068779  0.09452731]]

 [[0.11947545 0.11595692]]

 [[0.04758375 0.03880273]]]
rmse: 0.0764
------------------------------------------------------------------
seed:1
[[[0.04341291 0.03360955]]

 [[0.11180669 0.09500829]]

 [[0.11518325 0.11023909]]

 [[0.04481645 0.03714941]]]
rmse: 0.0739
------------------------------------------------------------------
seed:2
[[[0.03821056 0.02992193]]

 [[0.11404354 0.10154661]]

 [[0.11278727 0.1091244 ]]

 [[0.04697413 0.03726721]]]
rmse: 0.0737
------------------------------------------------------------------
rmse mean: 0.0747
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.0
seed:0
[[[0.06086408 0.04848302]]

 [[0.11512577 0.10057141]]

 [[0.12326453 0.11893706]]

 [[0.0379096  0.03059444]]]
rmse: 0.0795
------------------------------------------------------------------
seed:1
[[[0.03641448 0.0274479 ]]

 [[0.1240222  0.099063  ]]

 [[0.10585587 0.09914909]]

 [[0.03936852 0.03158455]]]
rmse: 0.0704
------------------------------------------------------------------
seed:2
[[[0.02933787 0.02358929]]

 [[0.11234647 0.09411393]]

 [[0.08985053 0.08110247]]

 [[0.04043466 0.03317214]]]
rmse: 0.0630
------------------------------------------------------------------
rmse mean: 0.0709
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.0001
seed:0
[[[0.06086302 0.04848966]]

 [[0.11511701 0.1005835 ]]

 [[0.1232263  0.11891386]]

 [[0.03790631 0.03059521]]]
rmse: 0.0795
------------------------------------------------------------------
seed:1
[[[0.0364106  0.02744414]]

 [[0.10906755 0.08867598]]

 [[0.10574394 0.09647241]]

 [[0.04370443 0.03587168]]]
rmse: 0.0679
------------------------------------------------------------------
seed:2
[[[0.02931112 0.02357618]]

 [[0.11236691 0.09413142]]

 [[0.08988596 0.08115114]]

 [[0.0404324  0.03317087]]]
rmse: 0.0630
------------------------------------------------------------------
rmse mean: 0.0701
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.001
seed:0
[[[0.06084387 0.04852153]]

 [[0.11488356 0.10056297]]

 [[0.08748839 0.08004492]]

 [[0.04836499 0.04124934]]]
rmse: 0.0727
------------------------------------------------------------------
seed:1
[[[0.0363886  0.02741066]]

 [[0.10883737 0.0885521 ]]

 [[0.10712573 0.10093666]]

 [[0.04095378 0.03195845]]]
rmse: 0.0678
------------------------------------------------------------------
seed:2
[[[0.04516672 0.03657855]]

 [[0.1189353  0.09980283]]

 [[0.1230556  0.11848043]]

 [[0.04369242 0.03289908]]]
rmse: 0.0773
------------------------------------------------------------------
rmse mean: 0.0726
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.01
seed:0
[[[0.04487392 0.03694725]]

 [[0.11154256 0.08766794]]

 [[0.08983154 0.0795729 ]]

 [[0.03815397 0.02969262]]]
rmse: 0.0648
------------------------------------------------------------------
seed:1
[[[0.03641821 0.02757619]]

 [[0.10545847 0.09586354]]

 [[0.11699208 0.11389214]]

 [[0.04602126 0.03842123]]]
rmse: 0.0726
------------------------------------------------------------------
seed:2
[[[0.0364609  0.02872061]]

 [[0.1096242  0.09550244]]

 [[0.12024906 0.10906501]]

 [[0.03524365 0.02829293]]]
rmse: 0.0704
------------------------------------------------------------------
rmse mean: 0.0693
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.0
seed:0
[[[0.03339027 0.02734725]]

 [[0.12143007 0.09858313]]

 [[0.11952118 0.11376786]]

 [[0.04583852 0.03552888]]]
rmse: 0.0744
------------------------------------------------------------------
seed:1
[[[0.03039028 0.02357632]]

 [[0.11174028 0.08995344]]

 [[0.13658024 0.13126929]]

 [[0.03893294 0.0306989 ]]]
rmse: 0.0741
------------------------------------------------------------------
seed:2
[[[0.04784628 0.03576165]]

 [[0.11163317 0.09498105]]

 [[0.11037274 0.10614765]]

 [[0.03984564 0.03162836]]]
rmse: 0.0723
------------------------------------------------------------------
rmse mean: 0.0736
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.0001
seed:0
[[[0.03325576 0.02727397]]

 [[0.1214266  0.09858532]]

 [[0.11950421 0.11376219]]

 [[0.04584231 0.03553269]]]
rmse: 0.0744
------------------------------------------------------------------
seed:1
[[[0.03038324 0.02357343]]

 [[0.11173127 0.08995305]]

 [[0.1104977  0.10601776]]

 [[0.03896205 0.03055741]]]
rmse: 0.0677
------------------------------------------------------------------
seed:2
[[[0.04784518 0.03575777]]

 [[0.11159172 0.09495699]]

 [[0.11036341 0.10614294]]

 [[0.03984268 0.03162684]]]
rmse: 0.0723
------------------------------------------------------------------
rmse mean: 0.0715
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.001
seed:0
[[[0.05505658 0.0468531 ]]

 [[0.1028916  0.07974401]]

 [[0.08249829 0.07191263]]

 [[0.04137076 0.03374815]]]
rmse: 0.0643
------------------------------------------------------------------
seed:1
[[[0.03030452 0.02341137]]

 [[0.11163565 0.08992666]]

 [[0.1104682  0.10599113]]

 [[0.03902844 0.03059509]]]
rmse: 0.0677
------------------------------------------------------------------
seed:2
[[[0.04780339 0.03570366]]

 [[0.11090286 0.09452524]]

 [[0.11028171 0.1061031 ]]

 [[0.03998131 0.03173675]]]
rmse: 0.0721
------------------------------------------------------------------
rmse mean: 0.0680
===================================================================
K=16, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.01
seed:0
[[[0.04715497 0.04046337]]

 [[0.0984418  0.08598601]]

 [[0.09700634 0.09340842]]

 [[0.03938988 0.02917126]]]
rmse: 0.0664
------------------------------------------------------------------
seed:1
[[[0.0300261  0.02295509]]

 [[0.11086942 0.09995349]]

 [[0.10558508 0.10094836]]

 [[0.0423479  0.03307483]]]
rmse: 0.0682
------------------------------------------------------------------
seed:2
[[[0.02994895 0.02215727]]

 [[0.11331868 0.09561749]]

 [[0.08263468 0.07327432]]

 [[0.04244897 0.03275425]]]
rmse: 0.0615
------------------------------------------------------------------
rmse mean: 0.0654
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.0
seed:0
[[[0.03760317 0.02782361]]

 [[0.11411925 0.09617055]]

 [[0.08385795 0.08105066]]

 [[0.03626964 0.02902591]]]
rmse: 0.0632
------------------------------------------------------------------
seed:1
[[[0.0302898  0.02107317]]

 [[0.11271366 0.09795506]]

 [[0.11505239 0.11209881]]

 [[0.0419407  0.03418222]]]
rmse: 0.0707
------------------------------------------------------------------
seed:2
[[[0.02797515 0.01907183]]

 [[0.11032128 0.09033501]]

 [[0.11573333 0.11160446]]

 [[0.03488613 0.02901298]]]
rmse: 0.0674
------------------------------------------------------------------
rmse mean: 0.0671
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.0001
seed:0
[[[0.03760154 0.02782179]]

 [[0.11409795 0.09615541]]

 [[0.08386486 0.08105816]]

 [[0.03626987 0.02902615]]]
rmse: 0.0632
------------------------------------------------------------------
seed:1
[[[0.03028967 0.02107307]]

 [[0.1127132  0.09795496]]

 [[0.1150531  0.11209972]]

 [[0.04193885 0.03418072]]]
rmse: 0.0707
------------------------------------------------------------------
seed:2
[[[0.02797996 0.01907334]]

 [[0.11030883 0.09032312]]

 [[0.11573961 0.11160971]]

 [[0.0349393  0.02908129]]]
rmse: 0.0674
------------------------------------------------------------------
rmse mean: 0.0671
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.001
seed:0
[[[0.03758975 0.02781235]]

 [[0.11410746 0.09617261]]

 [[0.08387155 0.08106739]]

 [[0.03626916 0.02902523]]]
rmse: 0.0632
------------------------------------------------------------------
seed:1
[[[0.0302011  0.02102285]]

 [[0.11269206 0.09794704]]

 [[0.11505528 0.11210111]]

 [[0.04192345 0.03416953]]]
rmse: 0.0706
------------------------------------------------------------------
seed:2
[[[0.02801676 0.01909114]]

 [[0.11036577 0.09038193]]

 [[0.11581958 0.11168286]]

 [[0.03909818 0.030766  ]]]
rmse: 0.0682
------------------------------------------------------------------
rmse mean: 0.0673
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.01
seed:0
[[[0.03750971 0.02773803]]

 [[0.11424247 0.09636146]]

 [[0.08391559 0.08111217]]

 [[0.03626833 0.02902259]]]
rmse: 0.0633
------------------------------------------------------------------
seed:1
[[[0.03015422 0.02100027]]

 [[0.11246993 0.09782449]]

 [[0.11764614 0.11384965]]

 [[0.03533688 0.02914678]]]
rmse: 0.0697
------------------------------------------------------------------
seed:2
[[[0.02831783 0.01924977]]

 [[0.11068492 0.09072032]]

 [[0.11647748 0.11225031]]

 [[0.0388854  0.02954337]]]
rmse: 0.0683
------------------------------------------------------------------
rmse mean: 0.0671
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.0
seed:0
[[[0.03156553 0.02224434]]

 [[0.107492   0.09779408]]

 [[0.1024619  0.09973717]]

 [[0.04060489 0.03322953]]]
rmse: 0.0669
------------------------------------------------------------------
seed:1
[[[0.02768732 0.0212465 ]]

 [[0.11024914 0.1005482 ]]

 [[0.10685108 0.10514091]]

 [[0.04255708 0.03461595]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.03257075 0.02289294]]

 [[0.11146603 0.09639744]]

 [[0.1116352  0.11016572]]

 [[0.03444337 0.02715588]]]
rmse: 0.0683
------------------------------------------------------------------
rmse mean: 0.0679
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.0001
seed:0
[[[0.03156547 0.02224444]]

 [[0.1074934  0.09779484]]

 [[0.1024642  0.09973992]]

 [[0.04060414 0.03322898]]]
rmse: 0.0669
------------------------------------------------------------------
seed:1
[[[0.02768755 0.0212483 ]]

 [[0.1102611  0.1005572 ]]

 [[0.10685522 0.10514526]]

 [[0.04255763 0.03461665]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.03257128 0.02289319]]

 [[0.11143341 0.09638428]]

 [[0.11163657 0.11016727]]

 [[0.03444058 0.02715759]]]
rmse: 0.0683
------------------------------------------------------------------
rmse mean: 0.0679
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.001
seed:0
[[[0.03156487 0.02224398]]

 [[0.10750825 0.09779623]]

 [[0.10248071 0.09976144]]

 [[0.04058792 0.03321731]]]
rmse: 0.0669
------------------------------------------------------------------
seed:1
[[[0.02769258 0.02124629]]

 [[0.11197485 0.09687523]]

 [[0.10629113 0.10481536]]

 [[0.03626747 0.02878706]]]
rmse: 0.0667
------------------------------------------------------------------
seed:2
[[[0.03257756 0.02289554]]

 [[0.11129511 0.09635116]]

 [[0.11163575 0.11016659]]

 [[0.0344487  0.02715585]]]
rmse: 0.0683
------------------------------------------------------------------
rmse mean: 0.0673
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.01
seed:0
[[[0.03155507 0.02223344]]

 [[0.10765549 0.09784513]]

 [[0.10273527 0.10008633]]

 [[0.04042186 0.0330576 ]]]
rmse: 0.0669
------------------------------------------------------------------
seed:1
[[[0.02771695 0.02123464]]

 [[0.11337045 0.09767734]]

 [[0.10641563 0.104686  ]]

 [[0.0425072  0.03456387]]]
rmse: 0.0685
------------------------------------------------------------------
seed:2
[[[0.03263642 0.02292211]]

 [[0.11186673 0.09648448]]

 [[0.11172077 0.1102458 ]]

 [[0.0345374  0.02721479]]]
rmse: 0.0685
------------------------------------------------------------------
rmse mean: 0.0680
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.0
seed:0
[[[0.0307165  0.02359253]]

 [[0.12693212 0.10497252]]

 [[0.11952836 0.11476451]]

 [[0.0327869  0.02740154]]]
rmse: 0.0726
------------------------------------------------------------------
seed:1
[[[0.02661342 0.01803923]]

 [[0.10960256 0.09658329]]

 [[0.12079884 0.11537179]]

 [[0.0339941  0.02786568]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.04971522 0.04012228]]

 [[0.12484853 0.10268443]]

 [[0.11388183 0.1072357 ]]

 [[0.03955823 0.03164546]]]
rmse: 0.0762
------------------------------------------------------------------
rmse mean: 0.0725
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.0001
seed:0
[[[0.03071929 0.02359627]]

 [[0.12693204 0.10497257]]

 [[0.11953016 0.11476729]]

 [[0.03278694 0.02740162]]]
rmse: 0.0726
------------------------------------------------------------------
seed:1
[[[0.02661321 0.01803917]]

 [[0.10960146 0.09658022]]

 [[0.12080398 0.11537665]]

 [[0.0339943  0.0278658 ]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.03257872 0.02188176]]

 [[0.11233416 0.08894548]]

 [[0.1202601  0.11517538]]

 [[0.03512764 0.02886471]]]
rmse: 0.0694
------------------------------------------------------------------
rmse mean: 0.0702
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.001
seed:0
[[[0.03072284 0.02360002]]

 [[0.12692762 0.10496939]]

 [[0.11952791 0.11476688]]

 [[0.03278637 0.02740044]]]
rmse: 0.0726
------------------------------------------------------------------
seed:1
[[[0.02661388 0.0180386 ]]

 [[0.10959831 0.09657702]]

 [[0.12080148 0.11537312]]

 [[0.03399486 0.02786471]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.03207945 0.02551651]]

 [[0.11034543 0.09956747]]

 [[0.09169318 0.08302984]]

 [[0.03513245 0.02886939]]]
rmse: 0.0633
------------------------------------------------------------------
rmse mean: 0.0682
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.01
seed:0
[[[0.03072419 0.02356549]]

 [[0.12691986 0.10496533]]

 [[0.11909822 0.11455332]]

 [[0.032787   0.02739165]]]
rmse: 0.0725
------------------------------------------------------------------
seed:1
[[[0.0266477  0.01805785]]

 [[0.10958736 0.096557  ]]

 [[0.12076025 0.11536372]]

 [[0.03392477 0.02779174]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.02668835 0.02041247]]

 [[0.11871646 0.10136721]]

 [[0.11784939 0.1144969 ]]

 [[0.03559325 0.02838941]]]
rmse: 0.0704
------------------------------------------------------------------
rmse mean: 0.0705
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.0
seed:0
[[[0.04237631 0.03824248]]

 [[0.11149518 0.10219393]]

 [[0.0905429  0.083574  ]]

 [[0.03614227 0.02940698]]]
rmse: 0.0667
------------------------------------------------------------------
seed:1
[[[0.03072141 0.0230357 ]]

 [[0.11330777 0.09754582]]

 [[0.1188732  0.11236345]]

 [[0.03586232 0.02890239]]]
rmse: 0.0701
------------------------------------------------------------------
seed:2
[[[0.03285232 0.02402508]]

 [[0.10728824 0.10018903]]

 [[0.12023569 0.11440817]]

 [[0.034069   0.02793012]]]
rmse: 0.0701
------------------------------------------------------------------
rmse mean: 0.0690
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.0001
seed:0
[[[0.02944121 0.02303998]]

 [[0.12274526 0.10190746]]

 [[0.05053952 0.04192693]]

 [[0.03497413 0.02950716]]]
rmse: 0.0543
------------------------------------------------------------------
seed:1
[[[0.03071907 0.02303569]]

 [[0.11326011 0.09753785]]

 [[0.11886897 0.11236108]]

 [[0.03585753 0.02889817]]]
rmse: 0.0701
------------------------------------------------------------------
seed:2
[[[0.0328528  0.02402539]]

 [[0.10727188 0.10017722]]

 [[0.12029689 0.11443015]]

 [[0.03406927 0.02793038]]]
rmse: 0.0701
------------------------------------------------------------------
rmse mean: 0.0648
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.001
seed:0
[[[0.02935545 0.02287345]]

 [[0.12276921 0.10192306]]

 [[0.04394124 0.03829088]]

 [[0.03527651 0.02874562]]]
rmse: 0.0529
------------------------------------------------------------------
seed:1
[[[0.03069051 0.02301176]]

 [[0.11276994 0.09753651]]

 [[0.11883323 0.11234635]]

 [[0.03585246 0.02889172]]]
rmse: 0.0700
------------------------------------------------------------------
seed:2
[[[0.03283498 0.02400877]]

 [[0.10729603 0.10020541]]

 [[0.12022594 0.11440815]]

 [[0.0340697  0.02793006]]]
rmse: 0.0701
------------------------------------------------------------------
rmse mean: 0.0643
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.01
seed:0
[[[0.02963667 0.02049373]]

 [[0.11155169 0.10233493]]

 [[0.08990269 0.08311729]]

 [[0.03736606 0.03085235]]]
rmse: 0.0632
------------------------------------------------------------------
seed:1
[[[0.03047868 0.02297797]]

 [[0.11193528 0.09746751]]

 [[0.11852902 0.11228101]]

 [[0.03579278 0.02884481]]]
rmse: 0.0698
------------------------------------------------------------------
seed:2
[[[0.03275866 0.02394689]]

 [[0.10754027 0.10053406]]

 [[0.12009583 0.11436174]]

 [[0.03416341 0.02794623]]]
rmse: 0.0702
------------------------------------------------------------------
rmse mean: 0.0677
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.0
seed:0
[[[0.0308985  0.02157277]]

 [[0.1132746  0.09947547]]

 [[0.09628003 0.09294042]]

 [[0.03826308 0.030989  ]]]
rmse: 0.0655
------------------------------------------------------------------
seed:1
[[[0.027646   0.02038637]]

 [[0.1147467  0.10128794]]

 [[0.11660912 0.11397524]]

 [[0.03302176 0.02781271]]]
rmse: 0.0694
------------------------------------------------------------------
seed:2
[[[0.02671651 0.01999157]]

 [[0.11002655 0.09785576]]

 [[0.02713914 0.01699354]]

 [[0.03485296 0.028494  ]]]
rmse: 0.0453
------------------------------------------------------------------
rmse mean: 0.0601
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.0001
seed:0
[[[0.03091998 0.02159358]]

 [[0.11325451 0.09945647]]

 [[0.09629156 0.09295393]]

 [[0.0386367  0.03266149]]]
rmse: 0.0657
------------------------------------------------------------------
seed:1
[[[0.02816471 0.02139015]]

 [[0.11471967 0.10122575]]

 [[0.11660979 0.11397591]]

 [[0.03302157 0.0278111 ]]]
rmse: 0.0696
------------------------------------------------------------------
seed:2
[[[0.02671516 0.01999336]]

 [[0.11005631 0.09786466]]

 [[0.02984713 0.02022866]]

 [[0.03485837 0.0284948 ]]]
rmse: 0.0460
------------------------------------------------------------------
rmse mean: 0.0604
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.001
seed:0
[[[0.03092279 0.02159537]]

 [[0.12226051 0.0974913 ]]

 [[0.09510205 0.09096995]]

 [[0.0346556  0.02970811]]]
rmse: 0.0653
------------------------------------------------------------------
seed:1
[[[0.02865095 0.02264041]]

 [[0.11496955 0.10120192]]

 [[0.11663558 0.11399161]]

 [[0.03301966 0.02779778]]]
rmse: 0.0699
------------------------------------------------------------------
seed:2
[[[0.02842207 0.02213053]]

 [[0.04140481 0.03264281]]

 [[0.04201618 0.0370477 ]]

 [[0.03479755 0.02809854]]]
rmse: 0.0333
------------------------------------------------------------------
rmse mean: 0.0562
===================================================================
K=16, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.01
seed:0
[[[0.03093864 0.02159376]]

 [[0.07005971 0.06289832]]

 [[0.12512694 0.11817716]]

 [[0.03489581 0.02831566]]]
rmse: 0.0615
------------------------------------------------------------------
seed:1
[[[0.0294126  0.02321308]]

 [[0.11505331 0.10007051]]

 [[0.11679132 0.11408124]]

 [[0.03300158 0.02766426]]]
rmse: 0.0699
------------------------------------------------------------------
seed:2
[[[0.02673434 0.0202473 ]]

 [[0.11015667 0.09786172]]

 [[0.02950061 0.02204448]]

 [[0.03476212 0.02877822]]]
rmse: 0.0463
------------------------------------------------------------------
rmse mean: 0.0592
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.0
seed:0
[[[0.03233259 0.02605617]]

 [[0.11256649 0.10448702]]

 [[0.10116338 0.09911471]]

 [[0.0346998  0.02834553]]]
rmse: 0.0673
------------------------------------------------------------------
seed:1
[[[0.03041122 0.0204677 ]]

 [[0.04168246 0.03518302]]

 [[0.12040899 0.11425913]]

 [[0.03425222 0.02860738]]]
rmse: 0.0532
------------------------------------------------------------------
seed:2
[[[0.02858094 0.02008633]]

 [[0.10689571 0.096064  ]]

 [[0.09966879 0.09824012]]

 [[0.03421267 0.02870722]]]
rmse: 0.0641
------------------------------------------------------------------
rmse mean: 0.0615
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.0001
seed:0
[[[0.03237042 0.02611753]]

 [[0.11256404 0.10448098]]

 [[0.10116363 0.09911486]]

 [[0.03469816 0.02834654]]]
rmse: 0.0674
------------------------------------------------------------------
seed:1
[[[0.03041125 0.02046831]]

 [[0.03981308 0.03084388]]

 [[0.14075435 0.13693645]]

 [[0.03308028 0.02724564]]]
rmse: 0.0574
------------------------------------------------------------------
seed:2
[[[0.02857825 0.02007919]]

 [[0.10706327 0.09590399]]

 [[0.10255095 0.0910606 ]]

 [[0.0329725  0.02798342]]]
rmse: 0.0633
------------------------------------------------------------------
rmse mean: 0.0627
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.001
seed:0
[[[0.03217084 0.02586386]]

 [[0.11250699 0.10438783]]

 [[0.10117494 0.09912534]]

 [[0.03468387 0.02835803]]]
rmse: 0.0673
------------------------------------------------------------------
seed:1
[[[0.03041751 0.02046862]]

 [[0.06803352 0.06130475]]

 [[0.09692067 0.09413389]]

 [[0.03387538 0.02748181]]]
rmse: 0.0541
------------------------------------------------------------------
seed:2
[[[0.02856159 0.02004485]]

 [[0.10696351 0.0964317 ]]

 [[0.10279779 0.09766107]]

 [[0.03417162 0.02782006]]]
rmse: 0.0643
------------------------------------------------------------------
rmse mean: 0.0619
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.01
seed:0
[[[0.02854276 0.02152457]]

 [[0.11467818 0.09834173]]

 [[0.08295501 0.07753251]]

 [[0.03271432 0.02632617]]]
rmse: 0.0603
------------------------------------------------------------------
seed:1
[[[0.03046558 0.02048401]]

 [[0.11848281 0.10053115]]

 [[0.08821333 0.08264863]]

 [[0.03362754 0.02816299]]]
rmse: 0.0628
------------------------------------------------------------------
seed:2
[[[0.02843942 0.01979977]]

 [[0.1071619  0.09528702]]

 [[0.05321265 0.0431394 ]]

 [[0.03585873 0.03104501]]]
rmse: 0.0517
------------------------------------------------------------------
rmse mean: 0.0583
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.0
seed:0
[[[0.02997921 0.02102445]]

 [[0.10907377 0.09928606]]

 [[0.11906663 0.11435549]]

 [[0.03521054 0.03037072]]]
rmse: 0.0698
------------------------------------------------------------------
seed:1
[[[0.02495424 0.01861517]]

 [[0.11002245 0.09806319]]

 [[0.11834596 0.11696011]]

 [[0.03399686 0.02734416]]]
rmse: 0.0685
------------------------------------------------------------------
seed:2
[[[0.04406757 0.03662939]]

 [[0.05689577 0.04962217]]

 [[0.02955387 0.0234203 ]]

 [[0.03498678 0.02908416]]]
rmse: 0.0380
------------------------------------------------------------------
rmse mean: 0.0588
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.0001
seed:0
[[[0.02998145 0.02101452]]

 [[0.10905098 0.09926525]]

 [[0.11906986 0.11435709]]

 [[0.03521872 0.03038307]]]
rmse: 0.0698
------------------------------------------------------------------
seed:1
[[[0.02570422 0.02009136]]

 [[0.10999669 0.09801982]]

 [[0.10412309 0.1024235 ]]

 [[0.03399414 0.02734263]]]
rmse: 0.0652
------------------------------------------------------------------
seed:2
[[[0.11390963 0.09272588]]

 [[0.11320129 0.09986873]]

 [[0.08378011 0.07423742]]

 [[0.03650365 0.02726209]]]
rmse: 0.0802
------------------------------------------------------------------
rmse mean: 0.0717
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.001
seed:0
[[[0.03030939 0.02077913]]

 [[0.1091541  0.09933327]]

 [[0.11908089 0.11435812]]

 [[0.03517054 0.03015694]]]
rmse: 0.0698
------------------------------------------------------------------
seed:1
[[[0.07115097 0.06102061]]

 [[0.1159914  0.0990829 ]]

 [[0.1494737  0.14399405]]

 [[0.03502589 0.02891508]]]
rmse: 0.0881
------------------------------------------------------------------
seed:2
[[[0.03398584 0.02651199]]

 [[0.11091739 0.10155964]]

 [[0.0869046  0.08038223]]

 [[0.03476979 0.02888511]]]
rmse: 0.0630
------------------------------------------------------------------
rmse mean: 0.0736
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.01
seed:0
[[[0.02933996 0.02140497]]

 [[0.11083309 0.10073136]]

 [[0.0894921  0.08316307]]

 [[0.0370613  0.03201713]]]
rmse: 0.0630
------------------------------------------------------------------
seed:1
[[[0.0380994  0.03111693]]

 [[0.11083908 0.09862193]]

 [[0.14677916 0.13948689]]

 [[0.03397909 0.02887403]]]
rmse: 0.0785
------------------------------------------------------------------
seed:2
[[[0.02986365 0.02473096]]

 [[0.098428   0.08852948]]

 [[0.02516203 0.01669813]]

 [[0.03988242 0.03458134]]]
rmse: 0.0447
------------------------------------------------------------------
rmse mean: 0.0621
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.0
seed:0
[[[0.11885835 0.10154086]]

 [[0.10481663 0.09559224]]

 [[0.0719928  0.0667422 ]]

 [[0.03401911 0.0289243 ]]]
rmse: 0.0778
------------------------------------------------------------------
seed:1
[[[0.04694378 0.04219887]]

 [[0.1190169  0.09945997]]

 [[0.07833241 0.072525  ]]

 [[0.10045981 0.0907306 ]]]
rmse: 0.0812
------------------------------------------------------------------
seed:2
[[[0.08626024 0.07329038]]

 [[0.11458098 0.10583259]]

 [[0.035966   0.02708433]]

 [[0.06926777 0.05240023]]]
rmse: 0.0706
------------------------------------------------------------------
rmse mean: 0.0765
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.0001
seed:0
[[[0.11805014 0.10082217]]

 [[0.10488122 0.09564308]]

 [[0.08260697 0.07709882]]

 [[0.03401959 0.02892559]]]
rmse: 0.0803
------------------------------------------------------------------
seed:1
[[[0.04783729 0.04174824]]

 [[0.0416918  0.03327018]]

 [[0.12271358 0.11317806]]

 [[0.04344642 0.03387861]]]
rmse: 0.0597
------------------------------------------------------------------
seed:2
[[[0.08593896 0.07302676]]

 [[0.05120482 0.0445109 ]]

 [[0.10651217 0.10200846]]

 [[0.03605661 0.03151548]]]
rmse: 0.0663
------------------------------------------------------------------
rmse mean: 0.0688
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.001
seed:0
[[[0.1236811  0.1056067 ]]

 [[0.10563552 0.09626365]]

 [[0.02419818 0.01551554]]

 [[0.03460754 0.02936475]]]
rmse: 0.0669
------------------------------------------------------------------
seed:1
[[[0.07475463 0.06411395]]

 [[0.1131553  0.09971681]]

 [[0.08488559 0.08071127]]

 [[0.03394166 0.02826297]]]
rmse: 0.0724
------------------------------------------------------------------
seed:2
[[[0.08626884 0.07329634]]

 [[0.04867039 0.03457914]]

 [[0.12374019 0.11534961]]

 [[0.03409633 0.02732557]]]
rmse: 0.0679
------------------------------------------------------------------
rmse mean: 0.0691
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.01
seed:0
[[[0.02599613 0.01702121]]

 [[0.11496283 0.09902683]]

 [[0.07069416 0.0647052 ]]

 [[0.03366667 0.02706097]]]
rmse: 0.0566
------------------------------------------------------------------
seed:1
[[[0.03415614 0.03038041]]

 [[0.11449531 0.1033513 ]]

 [[0.10232942 0.10076857]]

 [[0.08552069 0.07780097]]]
rmse: 0.0811
------------------------------------------------------------------
seed:2
[[[0.0896808  0.0759861 ]]

 [[0.11576266 0.08957003]]

 [[0.08877594 0.08593864]]

 [[0.03479631 0.02848008]]]
rmse: 0.0761
------------------------------------------------------------------
rmse mean: 0.0713
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.0
seed:0
[[[0.03156517 0.02492164]]

 [[0.12120867 0.10294164]]

 [[0.1186484  0.10699645]]

 [[0.06660107 0.06034576]]]
rmse: 0.0792
------------------------------------------------------------------
seed:1
[[[0.13017768 0.10803466]]

 [[0.12520232 0.10687816]]

 [[0.06999115 0.06713333]]

 [[0.03813085 0.0310965 ]]]
rmse: 0.0846
------------------------------------------------------------------
seed:2
[[[0.02970129 0.0188762 ]]

 [[0.12993679 0.10724889]]

 [[0.1019906  0.09737821]]

 [[0.03741052 0.03312665]]]
rmse: 0.0695
------------------------------------------------------------------
rmse mean: 0.0777
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.0001
seed:0
[[[0.11916724 0.11756178]]

 [[0.17215363 0.14283625]]

 [[0.10332605 0.09744344]]

 [[0.10016661 0.09429329]]]
rmse: 0.1184
------------------------------------------------------------------
seed:1
[[[0.13125563 0.10942685]]

 [[0.11897932 0.09962778]]

 [[0.07167568 0.05908975]]

 [[0.03437252 0.02952957]]]
rmse: 0.0817
------------------------------------------------------------------
seed:2
[[[0.03395577 0.03008401]]

 [[0.11439719 0.1026843 ]]

 [[0.09097755 0.08745205]]

 [[0.06472556 0.05749544]]]
rmse: 0.0727
------------------------------------------------------------------
rmse mean: 0.0909
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.001
seed:0
[[[0.07267587 0.06227633]]

 [[0.07222401 0.06625302]]

 [[0.03393975 0.02856358]]

 [[0.07989508 0.07236335]]]
rmse: 0.0610
------------------------------------------------------------------
seed:1
[[[0.03026421 0.02583864]]

 [[0.11621688 0.10626747]]

 [[0.09503528 0.08550904]]

 [[0.06452435 0.05616109]]]
rmse: 0.0725
------------------------------------------------------------------
seed:2
[[[0.05015717 0.04408518]]

 [[0.11367438 0.10148312]]

 [[0.08963288 0.08606521]]

 [[0.04026681 0.03173471]]]
rmse: 0.0696
------------------------------------------------------------------
rmse mean: 0.0677
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.01
seed:0
[[[0.04007761 0.02713956]]

 [[0.11200056 0.09890663]]

 [[0.08541722 0.06966205]]

 [[0.05257328 0.03981595]]]
rmse: 0.0657
------------------------------------------------------------------
seed:1
[[[0.108926   0.09264493]]

 [[0.11082068 0.10244453]]

 [[0.05320968 0.04568659]]

 [[0.0475739  0.04038382]]]
rmse: 0.0752
------------------------------------------------------------------
seed:2
[[[0.10210102 0.08814974]]

 [[0.11624474 0.09968924]]

 [[0.0904324  0.08341346]]

 [[0.04726727 0.03565926]]]
rmse: 0.0829
------------------------------------------------------------------
rmse mean: 0.0746
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.0
seed:0
[[[0.02776393 0.02214674]]

 [[0.16244339 0.1340213 ]]

 [[0.08957023 0.08363101]]

 [[0.05469874 0.04538459]]]
rmse: 0.0775
------------------------------------------------------------------
seed:1
[[[0.13531604 0.11680063]]

 [[0.17669651 0.14709626]]

 [[0.14194237 0.12816987]]

 [[0.13263396 0.12883987]]]
rmse: 0.1384
------------------------------------------------------------------
seed:2
[[[0.02664352 0.01869047]]

 [[0.17388192 0.14446234]]

 [[0.15229622 0.11795508]]

 [[0.03947488 0.03350089]]]
rmse: 0.0884
------------------------------------------------------------------
rmse mean: 0.1014
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.0001
seed:0
[[[0.0315103  0.02616035]]

 [[0.10638789 0.09323395]]

 [[0.03434804 0.0252349 ]]

 [[0.03866622 0.03309332]]]
rmse: 0.0486
------------------------------------------------------------------
seed:1
[[[0.13540811 0.11688099]]

 [[0.16722741 0.13815644]]

 [[0.15239187 0.11795286]]

 [[0.03598273 0.0314372 ]]]
rmse: 0.1119
------------------------------------------------------------------
seed:2
[[[0.05704053 0.04777524]]

 [[0.17375865 0.14434513]]

 [[0.09174843 0.08905709]]

 [[0.03362916 0.02900568]]]
rmse: 0.0833
------------------------------------------------------------------
rmse mean: 0.0813
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.001
seed:0
[[[0.12505764 0.11148768]]

 [[0.04352708 0.03254633]]

 [[0.11716293 0.11548292]]

 [[0.08167102 0.07411641]]]
rmse: 0.0876
------------------------------------------------------------------
seed:1
[[[0.13541908 0.11688858]]

 [[0.17424355 0.14482984]]

 [[0.15216994 0.11788622]]

 [[0.04014152 0.03332559]]]
rmse: 0.1144
------------------------------------------------------------------
seed:2
[[[0.02371155 0.01737667]]

 [[0.0704607  0.05213472]]

 [[0.62441568 0.61660305]]

 [[0.03748577 0.03258741]]]
rmse: 0.1843
------------------------------------------------------------------
rmse mean: 0.1288
===================================================================
K=16, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.01
seed:0
[[[0.05060483 0.04315425]]

 [[0.1647557  0.13511914]]

 [[0.27738479 0.25652828]]

 [[0.03987895 0.03302791]]]
rmse: 0.1251
------------------------------------------------------------------
seed:1
[[[0.13540299 0.11687613]]

 [[0.17418653 0.1447705 ]]

 [[0.15204398 0.1178743 ]]

 [[0.03908773 0.03227769]]]
rmse: 0.1141
------------------------------------------------------------------
seed:2
[[[0.05358292 0.04896742]]

 [[0.08906069 0.07316073]]

 [[0.06574866 0.05272185]]

 [[0.09157192 0.07982532]]]
rmse: 0.0693
------------------------------------------------------------------
rmse mean: 0.1028
===================================================================
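The blocks above complete the K=16 sweep; the same grid is repeated below for K=32. As a hypothetical continuation of the parse_grid_log sketch above (assuming pandas is available in the image), the records can be pivoted to compare two hyperparameters at a glance. The four records below are hand-copied from K=16 blocks above purely for illustration; in practice one would feed in parse_grid_log's full output:

import pandas as pd

# Hypothetical continuation of the parse_grid_log sketch: tabulate "rmse mean"
# against two hyperparameters at once. These records are hand-copied from the
# K=16, hidden_dim=256 blocks above; note the pivot averages over the remaining
# settings (here num_layers differs between the two lr rows).
records = [
    ({"K": 16, "lr": 0.001, "num_layers": 2, "hidden_dim": 256, "alpha": 0.001}, 0.0680),
    ({"K": 16, "lr": 0.001, "num_layers": 2, "hidden_dim": 256, "alpha": 0.01},  0.0654),
    ({"K": 16, "lr": 0.01,  "num_layers": 1, "hidden_dim": 256, "alpha": 0.001}, 0.0562),
    ({"K": 16, "lr": 0.01,  "num_layers": 1, "hidden_dim": 256, "alpha": 0.01},  0.0592),
]
df = pd.DataFrame([{**params, "rmse_mean": mean} for params, mean in records])
print(df.pivot_table(index="lr", columns="alpha", values="rmse_mean").round(4))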
K=32, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.0
seed:0
[[[0.0419174  0.03152137]]

 [[0.12674137 0.10100168]]

 [[0.07085089 0.06074013]]

 [[0.04859356 0.03960291]]]
rmse: 0.0651
------------------------------------------------------------------
seed:1
[[[0.04316346 0.0336876 ]]

 [[0.13244767 0.10597091]]

 [[0.06939991 0.06014254]]

 [[0.0745369  0.06222527]]]
rmse: 0.0727
------------------------------------------------------------------
seed:2
[[[0.0373737  0.0303129 ]]

 [[0.10998818 0.09637215]]

 [[0.10969973 0.10426526]]

 [[0.05551989 0.0461785 ]]]
rmse: 0.0737
------------------------------------------------------------------
rmse mean: 0.0705
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.0001
seed:0
[[[0.04191855 0.03152363]]

 [[0.12670706 0.10097626]]

 [[0.07098577 0.06084908]]

 [[0.0485904  0.0396009 ]]]
rmse: 0.0651
------------------------------------------------------------------
seed:1
[[[0.04316517 0.03368856]]

 [[0.13239998 0.10594037]]

 [[0.06924568 0.05997539]]

 [[0.07479275 0.06250392]]]
rmse: 0.0727
------------------------------------------------------------------
seed:2
[[[0.03731558 0.03024599]]

 [[0.10999594 0.09638503]]

 [[0.10964637 0.10421803]]

 [[0.05553402 0.04618964]]]
rmse: 0.0737
------------------------------------------------------------------
rmse mean: 0.0705
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.001
seed:0
[[[0.04191184 0.03152574]]

 [[0.10344858 0.08890575]]

 [[0.08434869 0.07707513]]

 [[0.03911995 0.03106881]]]
rmse: 0.0622
------------------------------------------------------------------
seed:1
[[[0.04312171 0.03367853]]

 [[0.13209533 0.10580845]]

 [[0.06928952 0.06004134]]

 [[0.06947207 0.057459  ]]]
rmse: 0.0714
------------------------------------------------------------------
seed:2
[[[0.03685611 0.02978476]]

 [[0.10996461 0.09631731]]

 [[0.10919365 0.10382385]]

 [[0.05546216 0.04619599]]]
rmse: 0.0734
------------------------------------------------------------------
rmse mean: 0.0690
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.01
seed:0
[[[0.04156496 0.03129091]]

 [[0.11031534 0.10386775]]

 [[0.09747177 0.09156431]]

 [[0.04633938 0.03901625]]]
rmse: 0.0702
------------------------------------------------------------------
seed:1
[[[0.03631134 0.02836399]]

 [[0.10705537 0.09911123]]

 [[0.09216878 0.08352941]]

 [[0.04639604 0.03605447]]]
rmse: 0.0661
------------------------------------------------------------------
seed:2
[[[0.03757047 0.0292134 ]]

 [[0.11810039 0.10564102]]

 [[0.10627268 0.10041626]]

 [[0.04791714 0.0374561 ]]]
rmse: 0.0728
------------------------------------------------------------------
rmse mean: 0.0697
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.0
seed:0
[[[0.045328   0.03565908]]

 [[0.11537146 0.09701207]]

 [[0.09877598 0.08662214]]

 [[0.0552317  0.04737069]]]
rmse: 0.0727
------------------------------------------------------------------
seed:1
[[[0.04592084 0.03661861]]

 [[0.09030368 0.07952608]]

 [[0.09720892 0.08709383]]

 [[0.04877128 0.03973406]]]
rmse: 0.0656
------------------------------------------------------------------
seed:2
[[[0.03657518 0.02936733]]

 [[0.07577485 0.06166312]]

 [[0.07145902 0.05683952]]

 [[0.04675788 0.03670967]]]
rmse: 0.0519
------------------------------------------------------------------
rmse mean: 0.0634
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.0001
seed:0
[[[0.04532413 0.03566302]]

 [[0.11540666 0.09704799]]

 [[0.09874911 0.08659936]]

 [[0.05557122 0.04769998]]]
rmse: 0.0728
------------------------------------------------------------------
seed:1
[[[0.04592214 0.0366209 ]]

 [[0.0903218  0.07956121]]

 [[0.09719563 0.08708089]]

 [[0.04877728 0.03973029]]]
rmse: 0.0657
------------------------------------------------------------------
seed:2
[[[0.03655682 0.02935407]]

 [[0.07576049 0.06165705]]

 [[0.07145605 0.05683804]]

 [[0.04669654 0.0366535 ]]]
rmse: 0.0519
------------------------------------------------------------------
rmse mean: 0.0634
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.001
seed:0
[[[0.04554864 0.03593597]]

 [[0.12161386 0.10365847]]

 [[0.07011291 0.060173  ]]

 [[0.04688661 0.03701029]]]
rmse: 0.0651
------------------------------------------------------------------
seed:1
[[[0.04594024 0.03665667]]

 [[0.08997078 0.07923085]]

 [[0.09708263 0.08697563]]

 [[0.04887869 0.03974299]]]
rmse: 0.0656
------------------------------------------------------------------
seed:2
[[[0.03638399 0.02923445]]

 [[0.07564936 0.06162125]]

 [[0.07140976 0.05680817]]

 [[0.04621605 0.03621913]]]
rmse: 0.0517
------------------------------------------------------------------
rmse mean: 0.0608
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.01
seed:0
[[[0.04223907 0.03340813]]

 [[0.10738603 0.09893605]]

 [[0.08217443 0.07400801]]

 [[0.04651504 0.03676897]]]
rmse: 0.0652
------------------------------------------------------------------
seed:1
[[[0.03291497 0.02527845]]

 [[0.10383185 0.09004621]]

 [[0.11038155 0.10257677]]

 [[0.04115237 0.0337789 ]]]
rmse: 0.0675
------------------------------------------------------------------
seed:2
[[[0.03542274 0.02847139]]

 [[0.11433443 0.1065439 ]]

 [[0.09524695 0.08574066]]

 [[0.03876551 0.03173158]]]
rmse: 0.0670
------------------------------------------------------------------
rmse mean: 0.0666
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.0
seed:0
[[[0.03742823 0.02952666]]

 [[0.11936543 0.10398346]]

 [[0.10989166 0.10414545]]

 [[0.04396324 0.03461398]]]
rmse: 0.0729
------------------------------------------------------------------
seed:1
[[[0.05185715 0.04032686]]

 [[0.11698319 0.09776901]]

 [[0.10691801 0.10074995]]

 [[0.04546743 0.03759608]]]
rmse: 0.0747
------------------------------------------------------------------
seed:2
[[[0.040619   0.03130755]]

 [[0.1243353  0.10200462]]

 [[0.06994017 0.05922791]]

 [[0.05660631 0.04438192]]]
rmse: 0.0661
------------------------------------------------------------------
rmse mean: 0.0712
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.0001
seed:0
[[[0.03739334 0.02950251]]

 [[0.11936193 0.10398276]]

 [[0.10993995 0.10420186]]

 [[0.04395174 0.03460219]]]
rmse: 0.0729
------------------------------------------------------------------
seed:1
[[[0.05185325 0.04032176]]

 [[0.11698144 0.09776863]]

 [[0.10690875 0.10073973]]

 [[0.04553839 0.03766151]]]
rmse: 0.0747
------------------------------------------------------------------
seed:2
[[[0.04058901 0.03128051]]

 [[0.1243312  0.10201351]]

 [[0.06995175 0.05923846]]

 [[0.05659813 0.04437692]]]
rmse: 0.0660
------------------------------------------------------------------
rmse mean: 0.0712
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.001
seed:0
[[[0.03708471 0.02930752]]

 [[0.11933151 0.10397979]]

 [[0.10973379 0.10402041]]

 [[0.04423679 0.03518833]]]
rmse: 0.0729
------------------------------------------------------------------
seed:1
[[[0.0518193  0.04027219]]

 [[0.11320256 0.09741248]]

 [[0.08609381 0.07913892]]

 [[0.04642051 0.0350769 ]]]
rmse: 0.0687
------------------------------------------------------------------
seed:2
[[[0.04029825 0.03107592]]

 [[0.11371896 0.10048914]]

 [[0.09249091 0.08245132]]

 [[0.04562972 0.03495162]]]
rmse: 0.0676
------------------------------------------------------------------
rmse mean: 0.0697
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.01
seed:0
[[[0.03470407 0.02779053]]

 [[0.11327435 0.09780644]]

 [[0.07100567 0.0612421 ]]

 [[0.04637073 0.03960176]]]
rmse: 0.0615
------------------------------------------------------------------
seed:1
[[[0.03697352 0.03122633]]

 [[0.14585215 0.11217251]]

 [[0.1252756  0.11709207]]

 [[0.04192058 0.03403158]]]
rmse: 0.0806
------------------------------------------------------------------
seed:2
[[[0.04170508 0.03177374]]

 [[0.11774204 0.10696926]]

 [[0.09750139 0.09022693]]

 [[0.0445959  0.03549327]]]
rmse: 0.0708
------------------------------------------------------------------
rmse mean: 0.0709
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.0
seed:0
[[[0.04275174 0.03442493]]

 [[0.11912285 0.10336419]]

 [[0.13525258 0.12850554]]

 [[0.04716561 0.03712383]]]
rmse: 0.0810
------------------------------------------------------------------
seed:1
[[[0.04245992 0.03347361]]

 [[0.11540291 0.09524468]]

 [[0.04762076 0.03768326]]

 [[0.05297016 0.04306307]]]
rmse: 0.0585
------------------------------------------------------------------
seed:2
[[[0.0614791  0.04948514]]

 [[0.12357736 0.10253127]]

 [[0.06617802 0.05783783]]

 [[0.04537448 0.03705549]]]
rmse: 0.0679
------------------------------------------------------------------
rmse mean: 0.0691
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.0001
seed:0
[[[0.04274023 0.034417  ]]

 [[0.11912811 0.1033693 ]]

 [[0.13524625 0.12850202]]

 [[0.04715235 0.03711698]]]
rmse: 0.0810
------------------------------------------------------------------
seed:1
[[[0.04244307 0.03346123]]

 [[0.11538406 0.09523536]]

 [[0.04760838 0.03767179]]

 [[0.05295143 0.04305104]]]
rmse: 0.0585
------------------------------------------------------------------
seed:2
[[[0.06144411 0.04945928]]

 [[0.12358327 0.10253775]]

 [[0.06616867 0.05782782]]

 [[0.04535468 0.03703955]]]
rmse: 0.0679
------------------------------------------------------------------
rmse mean: 0.0691
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.001
seed:0
[[[0.04263991 0.03436871]]

 [[0.12277257 0.10810781]]

 [[0.11637028 0.11195551]]

 [[0.04404631 0.03316686]]]
rmse: 0.0767
------------------------------------------------------------------
seed:1
[[[0.04231049 0.03336347]]

 [[0.11513697 0.09508631]]

 [[0.0475284  0.03759564]]

 [[0.04587954 0.03871815]]]
rmse: 0.0570
------------------------------------------------------------------
seed:2
[[[0.04286062 0.03293722]]

 [[0.10647794 0.08585205]]

 [[0.11832847 0.11000455]]

 [[0.03994306 0.0314803 ]]]
rmse: 0.0710
------------------------------------------------------------------
rmse mean: 0.0682
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.01
seed:0
[[[0.04202543 0.03420687]]

 [[0.1204838  0.11093772]]

 [[0.08286342 0.07029896]]

 [[0.04527937 0.03824861]]]
rmse: 0.0680
------------------------------------------------------------------
seed:1
[[[0.04223601 0.03332886]]

 [[0.10395203 0.08838228]]

 [[0.09515492 0.08540869]]

 [[0.04413264 0.03420275]]]
rmse: 0.0658
------------------------------------------------------------------
seed:2
[[[0.039294   0.03206775]]

 [[0.11959298 0.11298439]]

 [[0.10930978 0.10549036]]

 [[0.04492705 0.03572013]]]
rmse: 0.0749
------------------------------------------------------------------
rmse mean: 0.0696
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.0
seed:0
[[[0.04392244 0.03388444]]

 [[0.11555776 0.09968722]]

 [[0.08121519 0.07171135]]

 [[0.0478491  0.03941749]]]
rmse: 0.0667
------------------------------------------------------------------
seed:1
[[[0.03632576 0.02952351]]

 [[0.09115371 0.07463958]]

 [[0.12061327 0.10779113]]

 [[0.04614587 0.03640785]]]
rmse: 0.0678
------------------------------------------------------------------
seed:2
[[[0.04439992 0.03419793]]

 [[0.09699361 0.08171725]]

 [[0.08379164 0.07628489]]

 [[0.04277334 0.03504519]]]
rmse: 0.0619
------------------------------------------------------------------
rmse mean: 0.0655
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.0001
seed:0
[[[0.04392986 0.03389252]]

 [[0.11555715 0.09969052]]

 [[0.08118385 0.0716885 ]]

 [[0.04784846 0.03941675]]]
rmse: 0.0667
------------------------------------------------------------------
seed:1
[[[0.03627323 0.02946598]]

 [[0.091169   0.07465001]]

 [[0.12062542 0.10780924]]

 [[0.0406229  0.03324772]]]
rmse: 0.0667
------------------------------------------------------------------
seed:2
[[[0.04434785 0.03415257]]

 [[0.09697892 0.0817059 ]]

 [[0.08378142 0.07627278]]

 [[0.04278048 0.03504808]]]
rmse: 0.0619
------------------------------------------------------------------
rmse mean: 0.0651
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.001
seed:0
[[[0.04393345 0.03390052]]

 [[0.11555139 0.09972215]]

 [[0.08091305 0.07148291]]

 [[0.0397319  0.03323595]]]
rmse: 0.0648
------------------------------------------------------------------
seed:1
[[[0.0362411  0.02943909]]

 [[0.09148902 0.07491788]]

 [[0.10773217 0.10200964]]

 [[0.04812781 0.0350708 ]]]
rmse: 0.0656
------------------------------------------------------------------
seed:2
[[[0.04385662 0.03375615]]

 [[0.11333948 0.10201059]]

 [[0.10521249 0.09780741]]

 [[0.04651443 0.03659941]]]
rmse: 0.0724
------------------------------------------------------------------
rmse mean: 0.0676
===================================================================
K=32, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.01
seed:0
[[[0.04316891 0.03328673]]

 [[0.11520323 0.09957213]]

 [[0.07452394 0.06570975]]

 [[0.04258815 0.03594487]]]
rmse: 0.0637
------------------------------------------------------------------
seed:1
[[[0.03772023 0.02993249]]

 [[0.12321904 0.11263288]]

 [[0.11254505 0.10815798]]

 [[0.03928256 0.03245259]]]
rmse: 0.0745
------------------------------------------------------------------
seed:2
[[[0.03839284 0.02926316]]

 [[0.10913429 0.09633193]]

 [[0.11179143 0.10815876]]

 [[0.04068726 0.033327  ]]]
rmse: 0.0709
------------------------------------------------------------------
rmse mean: 0.0697
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.0
seed:0
[[[0.0371999  0.02994731]]

 [[0.10996634 0.09337514]]

 [[0.06469451 0.05353472]]

 [[0.04139463 0.03535929]]]
rmse: 0.0582
------------------------------------------------------------------
seed:1
[[[0.03921009 0.03051472]]

 [[0.11495768 0.0950727 ]]

 [[0.06608422 0.05242411]]

 [[0.0356121  0.02906253]]]
rmse: 0.0579
------------------------------------------------------------------
seed:2
[[[0.04340106 0.03339075]]

 [[0.15675635 0.11551828]]

 [[0.08593269 0.07074295]]

 [[0.04329851 0.0354964 ]]]
rmse: 0.0731
------------------------------------------------------------------
rmse mean: 0.0630
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.0001
seed:0
[[[0.03718707 0.02993729]]

 [[0.10997087 0.0933774 ]]

 [[0.06468839 0.0535304 ]]

 [[0.04138777 0.03536831]]]
rmse: 0.0582
------------------------------------------------------------------
seed:1
[[[0.03921744 0.03052047]]

 [[0.11495794 0.09507721]]

 [[0.06607458 0.05241232]]

 [[0.03561134 0.02906403]]]
rmse: 0.0579
------------------------------------------------------------------
seed:2
[[[0.04331674 0.03335215]]

 [[0.15659971 0.11537206]]

 [[0.08592953 0.07074018]]

 [[0.04328149 0.03548341]]]
rmse: 0.0730
------------------------------------------------------------------
rmse mean: 0.0630
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.001
seed:0
[[[0.03711815 0.0299027 ]]

 [[0.10976107 0.09504428]]

 [[0.11606413 0.10971523]]

 [[0.03971326 0.03297168]]]
rmse: 0.0713
------------------------------------------------------------------
seed:1
[[[0.03922296 0.0305155 ]]

 [[0.11489882 0.09505593]]

 [[0.1244948  0.11958429]]

 [[0.05014966 0.04299458]]]
rmse: 0.0771
------------------------------------------------------------------
seed:2
[[[0.04299065 0.03311972]]

 [[0.07957361 0.0596558 ]]

 [[0.11334924 0.10872377]]

 [[0.05124565 0.04293375]]]
rmse: 0.0664
------------------------------------------------------------------
rmse mean: 0.0716
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.01
seed:0
[[[0.03717164 0.03015291]]

 [[0.11712299 0.10509955]]

 [[0.09554941 0.08881328]]

 [[0.03614488 0.03072916]]]
rmse: 0.0676
------------------------------------------------------------------
seed:1
[[[0.03562903 0.02876583]]

 [[0.1180276  0.09290833]]

 [[0.10259434 0.09819089]]

 [[0.04345561 0.03547155]]]
rmse: 0.0694
------------------------------------------------------------------
seed:2
[[[0.04126492 0.03341621]]

 [[0.11309984 0.09705169]]

 [[0.10771336 0.10361385]]

 [[0.04287911 0.03586538]]]
rmse: 0.0719
------------------------------------------------------------------
rmse mean: 0.0696
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.0
seed:0
[[[0.04503032 0.03471804]]

 [[0.10731659 0.09345832]]

 [[0.11006449 0.10093009]]

 [[0.0425152  0.03328644]]]
rmse: 0.0709
------------------------------------------------------------------
seed:1
[[[0.03954663 0.0305291 ]]

 [[0.11943604 0.1033133 ]]

 [[0.11493884 0.1104656 ]]

 [[0.04262972 0.0354547 ]]]
rmse: 0.0745
------------------------------------------------------------------
seed:2
[[[0.03515027 0.02923332]]

 [[0.09680419 0.07416518]]

 [[0.09925313 0.09580443]]

 [[0.03866795 0.03223416]]]
rmse: 0.0627
------------------------------------------------------------------
rmse mean: 0.0694
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.0001
seed:0
[[[0.04499995 0.03469739]]

 [[0.10731574 0.09346117]]

 [[0.10999211 0.10085715]]

 [[0.04253748 0.03332675]]]
rmse: 0.0709
------------------------------------------------------------------
seed:1
[[[0.03961835 0.03065077]]

 [[0.11941331 0.10330358]]

 [[0.11495253 0.11049298]]

 [[0.04263001 0.0354542 ]]]
rmse: 0.0746
------------------------------------------------------------------
seed:2
[[[0.03511863 0.02920457]]

 [[0.0967846  0.0741509 ]]

 [[0.09924886 0.09580048]]

 [[0.03864468 0.03221391]]]
rmse: 0.0626
------------------------------------------------------------------
rmse mean: 0.0694
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.001
seed:0
[[[0.04464555 0.03443012]]

 [[0.10727724 0.09344472]]

 [[0.10946231 0.10032719]]

 [[0.04263672 0.03348965]]]
rmse: 0.0707
------------------------------------------------------------------
seed:1
[[[0.03997515 0.03131263]]

 [[0.11931749 0.10376883]]

 [[0.12189241 0.11646887]]

 [[0.04368392 0.03434693]]]
rmse: 0.0763
------------------------------------------------------------------
seed:2
[[[0.03485283 0.02895107]]

 [[0.10900374 0.08768724]]

 [[0.11449056 0.1075903 ]]

 [[0.03948648 0.03165446]]]
rmse: 0.0692
------------------------------------------------------------------
rmse mean: 0.0721
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.01
seed:0
[[[0.04298679 0.03322122]]

 [[0.11907899 0.10587912]]

 [[0.07600555 0.06910985]]

 [[0.04036146 0.03352424]]]
rmse: 0.0650
------------------------------------------------------------------
seed:1
[[[0.03716187 0.0288596 ]]

 [[0.11036925 0.09227519]]

 [[0.0913759  0.08134566]]

 [[0.04027071 0.03046048]]]
rmse: 0.0640
------------------------------------------------------------------
seed:2
[[[0.0343909  0.02804114]]

 [[0.10965439 0.08597724]]

 [[0.08086464 0.07194454]]

 [[0.03850055 0.03154053]]]
rmse: 0.0601
------------------------------------------------------------------
rmse mean: 0.0630
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.0
seed:0
[[[0.03996964 0.03191566]]

 [[0.0994734  0.08852261]]

 [[0.11506889 0.11143796]]

 [[0.05103985 0.04091364]]]
rmse: 0.0723
------------------------------------------------------------------
seed:1
[[[0.03852193 0.02972532]]

 [[0.15183392 0.1195002 ]]

 [[0.11728622 0.11186556]]

 [[0.04849572 0.03993602]]]
rmse: 0.0821
------------------------------------------------------------------
seed:2
[[[0.03640538 0.0279816 ]]

 [[0.12187751 0.10513382]]

 [[0.11812587 0.11486174]]

 [[0.04898875 0.03817142]]]
rmse: 0.0764
------------------------------------------------------------------
rmse mean: 0.0770
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.0001
seed:0
[[[0.039926   0.0318841 ]]

 [[0.09949837 0.08855462]]

 [[0.11506509 0.11143462]]

 [[0.05103991 0.04091291]]]
rmse: 0.0723
------------------------------------------------------------------
seed:1
[[[0.03850745 0.02971342]]

 [[0.1518657  0.11952406]]

 [[0.1172862  0.11186523]]

 [[0.04849379 0.03993553]]]
rmse: 0.0821
------------------------------------------------------------------
seed:2
[[[0.0363536  0.02794321]]

 [[0.12186638 0.10524159]]

 [[0.11811774 0.11485481]]

 [[0.04899477 0.03816944]]]
rmse: 0.0764
------------------------------------------------------------------
rmse mean: 0.0770
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.001
seed:0
[[[0.03962301 0.03163372]]

 [[0.09966919 0.08882195]]

 [[0.1150509  0.111423  ]]

 [[0.04247889 0.03380708]]]
rmse: 0.0703
------------------------------------------------------------------
seed:1
[[[0.03841993 0.02963395]]

 [[0.11362642 0.08616252]]

 [[0.10632248 0.09942426]]

 [[0.0410761  0.0340254 ]]]
rmse: 0.0686
------------------------------------------------------------------
seed:2
[[[0.03624854 0.02791528]]

 [[0.122246   0.10657304]]

 [[0.11805494 0.11480208]]

 [[0.03932236 0.03076104]]]
rmse: 0.0745
------------------------------------------------------------------
rmse mean: 0.0711
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.01
seed:0
[[[0.03793939 0.03034168]]

 [[0.10524649 0.09534049]]

 [[0.09331254 0.08505204]]

 [[0.03980616 0.03266011]]]
rmse: 0.0650
------------------------------------------------------------------
seed:1
[[[0.03751336 0.02890879]]

 [[0.11531985 0.08779229]]

 [[0.08295466 0.07493306]]

 [[0.04241277 0.0338197 ]]]
rmse: 0.0630
------------------------------------------------------------------
seed:2
[[[0.03399299 0.02644241]]

 [[0.11392244 0.09217531]]

 [[0.08084483 0.06951397]]

 [[0.04197593 0.03508919]]]
rmse: 0.0617
------------------------------------------------------------------
rmse mean: 0.0632
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.0
seed:0
[[[0.04702132 0.03850416]]

 [[0.11185781 0.08947213]]

 [[0.09949038 0.09440301]]

 [[0.03990603 0.02976829]]]
rmse: 0.0688
------------------------------------------------------------------
seed:1
[[[0.03564493 0.02706585]]

 [[0.11407401 0.09444429]]

 [[0.08619559 0.08063003]]

 [[0.04483439 0.03571379]]]
rmse: 0.0648
------------------------------------------------------------------
seed:2
[[[0.0514234  0.04239019]]

 [[0.11814312 0.09842255]]

 [[0.11572962 0.10992375]]

 [[0.04124922 0.03353626]]]
rmse: 0.0764
------------------------------------------------------------------
rmse mean: 0.0700
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.0001
seed:0
[[[0.04699671 0.03848409]]

 [[0.11186237 0.08948018]]

 [[0.09935112 0.0942893 ]]

 [[0.03990346 0.02976473]]]
rmse: 0.0688
------------------------------------------------------------------
seed:1
[[[0.03564262 0.02706316]]

 [[0.11405522 0.09443746]]

 [[0.08616028 0.08058984]]

 [[0.04483255 0.03571346]]]
rmse: 0.0648
------------------------------------------------------------------
seed:2
[[[0.05133117 0.04232284]]

 [[0.11812058 0.09840973]]

 [[0.1157265  0.10992562]]

 [[0.04123921 0.03352935]]]
rmse: 0.0763
------------------------------------------------------------------
rmse mean: 0.0700
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.001
seed:0
[[[0.04677394 0.03831042]]

 [[0.11189096 0.08952264]]

 [[0.09797931 0.09311604]]

 [[0.0399075  0.0297553 ]]]
rmse: 0.0684
------------------------------------------------------------------
seed:1
[[[0.03563485 0.02705462]]

 [[0.11384532 0.09433679]]

 [[0.08606827 0.08047502]]

 [[0.04360872 0.03523694]]]
rmse: 0.0645
------------------------------------------------------------------
seed:2
[[[0.05071782 0.04173304]]

 [[0.11787005 0.09821466]]

 [[0.11563798 0.10984789]]

 [[0.04138678 0.03387195]]]
rmse: 0.0762
------------------------------------------------------------------
rmse mean: 0.0697
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.01
seed:0
[[[0.03516846 0.02705968]]

 [[0.10201129 0.0871442 ]]

 [[0.09823614 0.09178468]]

 [[0.04282759 0.03530907]]]
rmse: 0.0649
------------------------------------------------------------------
seed:1
[[[0.03550254 0.02695424]]

 [[0.10626835 0.09151899]]

 [[0.1015357  0.09572043]]

 [[0.04236686 0.03554617]]]
rmse: 0.0669
------------------------------------------------------------------
seed:2
[[[0.03882406 0.03020892]]

 [[0.11632997 0.09069524]]

 [[0.0829184  0.07766281]]

 [[0.04260313 0.03539446]]]
rmse: 0.0643
------------------------------------------------------------------
rmse mean: 0.0654
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.0
seed:0
[[[0.04694678 0.03797604]]

 [[0.12401611 0.10307461]]

 [[0.12750906 0.12147446]]

 [[0.03530465 0.02866163]]]
rmse: 0.0781
------------------------------------------------------------------
seed:1
[[[0.03353204 0.02455835]]

 [[0.09659823 0.07318685]]

 [[0.12258948 0.11701728]]

 [[0.03711632 0.03091765]]]
rmse: 0.0669
------------------------------------------------------------------
seed:2
[[[0.03350129 0.02610846]]

 [[0.11588303 0.09883727]]

 [[0.11052213 0.10291927]]

 [[0.03666428 0.03065772]]]
rmse: 0.0694
------------------------------------------------------------------
rmse mean: 0.0715
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.0001
seed:0
[[[0.04693682 0.03795764]]

 [[0.12405132 0.10310003]]

 [[0.12748928 0.12146508]]

 [[0.03530375 0.02866105]]]
rmse: 0.0781
------------------------------------------------------------------
seed:1
[[[0.03351889 0.02454945]]

 [[0.09659437 0.07318425]]

 [[0.12254874 0.11697016]]

 [[0.0371141  0.03092556]]]
rmse: 0.0669
------------------------------------------------------------------
seed:2
[[[0.03344404 0.02609324]]

 [[0.11586344 0.09882706]]

 [[0.11046244 0.10290114]]

 [[0.03666444 0.03065437]]]
rmse: 0.0694
------------------------------------------------------------------
rmse mean: 0.0715
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.001
seed:0
[[[0.04699894 0.03797937]]

 [[0.12390157 0.10302342]]

 [[0.12457703 0.11918243]]

 [[0.03528564 0.02864332]]]
rmse: 0.0774
------------------------------------------------------------------
seed:1
[[[0.0333626  0.02445948]]

 [[0.10782488 0.08628784]]

 [[0.11606483 0.11137738]]

 [[0.03791066 0.02933301]]]
rmse: 0.0683
------------------------------------------------------------------
seed:2
[[[0.03299865 0.02591786]]

 [[0.11606445 0.10182513]]

 [[0.10006115 0.0951944 ]]

 [[0.04125002 0.03395649]]]
rmse: 0.0684
------------------------------------------------------------------
rmse mean: 0.0714
===================================================================
K=32, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.01
seed:0
[[[0.04768585 0.03848191]]

 [[0.11371139 0.09904019]]

 [[0.11069012 0.10479081]]

 [[0.03671066 0.02910287]]]
rmse: 0.0725
------------------------------------------------------------------
seed:1
[[[0.03384914 0.02400242]]

 [[0.11205937 0.09653928]]

 [[0.09339408 0.08810264]]

 [[0.03959396 0.03124186]]]
rmse: 0.0648
------------------------------------------------------------------
seed:2
[[[0.03430952 0.02609465]]

 [[0.10482396 0.08483894]]

 [[0.097309   0.0933374 ]]

 [[0.03878477 0.03171943]]]
rmse: 0.0639
------------------------------------------------------------------
rmse mean: 0.0671
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.0
seed:0
[[[0.03332006 0.02263053]]

 [[0.12380651 0.10273081]]

 [[0.11556788 0.11351498]]

 [[0.03801454 0.03076121]]]
rmse: 0.0725
------------------------------------------------------------------
seed:1
[[[0.03318524 0.02420733]]

 [[0.11345411 0.09730406]]

 [[0.1172332  0.1132053 ]]

 [[0.04234513 0.03509978]]]
rmse: 0.0720
------------------------------------------------------------------
seed:2
[[[0.05425386 0.04534778]]

 [[0.12573686 0.10361135]]

 [[0.11774773 0.11237631]]

 [[0.04370081 0.03565472]]]
rmse: 0.0798
------------------------------------------------------------------
rmse mean: 0.0748
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.0001
seed:0
[[[0.03331939 0.02263009]]

 [[0.1238064  0.10273074]]

 [[0.11556853 0.11351543]]

 [[0.03801452 0.03076131]]]
rmse: 0.0725
------------------------------------------------------------------
seed:1
[[[0.03318574 0.02420792]]

 [[0.11346419 0.09731589]]

 [[0.117233   0.11320516]]

 [[0.04234466 0.03509943]]]
rmse: 0.0720
------------------------------------------------------------------
seed:2
[[[0.05425063 0.04534488]]

 [[0.12573665 0.10361118]]

 [[0.11774759 0.11237641]]

 [[0.04370127 0.03565502]]]
rmse: 0.0798
------------------------------------------------------------------
rmse mean: 0.0748
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.001
seed:0
[[[0.03331291 0.02262567]]

 [[0.12379668 0.10272525]]

 [[0.11557478 0.11351999]]

 [[0.03801468 0.03076091]]]
rmse: 0.0725
------------------------------------------------------------------
seed:1
[[[0.03318283 0.02420497]]

 [[0.11341141 0.09730108]]

 [[0.11723381 0.11320544]]

 [[0.04234223 0.03509741]]]
rmse: 0.0720
------------------------------------------------------------------
seed:2
[[[0.05424673 0.04534183]]

 [[0.12573572 0.10361038]]

 [[0.11774478 0.11237482]]

 [[0.04370391 0.03565574]]]
rmse: 0.0798
------------------------------------------------------------------
rmse mean: 0.0748
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.01
seed:0
[[[0.03324825 0.02257762]]

 [[0.12372559 0.10268325]]

 [[0.11564107 0.11356987]]

 [[0.03802067 0.03076366]]]
rmse: 0.0725
------------------------------------------------------------------
seed:1
[[[0.03316028 0.02418323]]

 [[0.11349127 0.09734017]]

 [[0.11723936 0.11321065]]

 [[0.04228431 0.03505017]]]
rmse: 0.0720
------------------------------------------------------------------
seed:2
[[[0.05417176 0.04528237]]

 [[0.12574163 0.10361218]]

 [[0.11772409 0.11236707]]

 [[0.04368204 0.03561278]]]
rmse: 0.0798
------------------------------------------------------------------
rmse mean: 0.0748
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.0
seed:0
[[[0.03534208 0.02565486]]

 [[0.15861877 0.14039678]]

 [[0.12018842 0.11527515]]

 [[0.04097281 0.03402642]]]
rmse: 0.0838
------------------------------------------------------------------
seed:1
[[[0.02915161 0.02075872]]

 [[0.12408666 0.10263188]]

 [[0.11679793 0.11194924]]

 [[0.0385657  0.03097937]]]
rmse: 0.0719
------------------------------------------------------------------
seed:2
[[[0.03037309 0.02041467]]

 [[0.12040843 0.09969519]]

 [[0.11235758 0.11004776]]

 [[0.03443198 0.02777476]]]
rmse: 0.0694
------------------------------------------------------------------
rmse mean: 0.0750
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.0001
seed:0
[[[0.03534114 0.02565399]]

 [[0.04737773 0.0401658 ]]

 [[0.1164617  0.11420325]]

 [[0.03440668 0.02768921]]]
rmse: 0.0552
------------------------------------------------------------------
seed:1
[[[0.02915036 0.02075788]]

 [[0.12408771 0.10263325]]

 [[0.11679755 0.11194897]]

 [[0.03856534 0.03098106]]]
rmse: 0.0719
------------------------------------------------------------------
seed:2
[[[0.03037215 0.0204149 ]]

 [[0.12040756 0.09969449]]

 [[0.11235811 0.1100486 ]]

 [[0.03443208 0.02777481]]]
rmse: 0.0694
------------------------------------------------------------------
rmse mean: 0.0655
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.001
seed:0
[[[0.03533286 0.02564638]]

 [[0.05392549 0.0431496 ]]

 [[0.11667916 0.11425673]]

 [[0.0405618  0.03373372]]]
rmse: 0.0579
------------------------------------------------------------------
seed:1
[[[0.02915047 0.02075822]]

 [[0.12408925 0.10263338]]

 [[0.11679648 0.11194916]]

 [[0.03855624 0.0309719 ]]]
rmse: 0.0719
------------------------------------------------------------------
seed:2
[[[0.03036534 0.02041692]]

 [[0.12041019 0.09969841]]

 [[0.11236211 0.11005279]]

 [[0.03443351 0.02777483]]]
rmse: 0.0694
------------------------------------------------------------------
rmse mean: 0.0664
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.01
seed:0
[[[0.03526052 0.02558   ]]

 [[0.12046383 0.09967906]]

 [[0.11715429 0.11379959]]

 [[0.03981126 0.03304005]]]
rmse: 0.0731
------------------------------------------------------------------
seed:1
[[[0.02912743 0.0207504 ]]

 [[0.12401554 0.10259403]]

 [[0.11677415 0.11194655]]

 [[0.03845251 0.03084138]]]
rmse: 0.0718
------------------------------------------------------------------
seed:2
[[[0.03187646 0.02183189]]

 [[0.1229545  0.10203114]]

 [[0.11832476 0.11318108]]

 [[0.03634804 0.02994838]]]
rmse: 0.0721
------------------------------------------------------------------
rmse mean: 0.0723
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.0
seed:0
[[[0.0294286  0.02190402]]

 [[0.12400354 0.10165532]]

 [[0.12272408 0.11654356]]

 [[0.03685075 0.02966131]]]
rmse: 0.0728
------------------------------------------------------------------
seed:1
[[[0.0405919  0.02921619]]

 [[0.12502532 0.10312009]]

 [[0.1181711  0.11442086]]

 [[0.03770066 0.03023736]]]
rmse: 0.0748
------------------------------------------------------------------
seed:2
[[[0.03811447 0.02660118]]

 [[0.121238   0.10114715]]

 [[0.12049666 0.11556313]]

 [[0.03558346 0.02766479]]]
rmse: 0.0733
------------------------------------------------------------------
rmse mean: 0.0737
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.0001
seed:0
[[[0.02942841 0.02190383]]

 [[0.12399681 0.1016537 ]]

 [[0.12272435 0.11654429]]

 [[0.05459235 0.04320117]]]
rmse: 0.0768
------------------------------------------------------------------
seed:1
[[[0.0405941  0.02922082]]

 [[0.12502451 0.10311945]]

 [[0.11817093 0.11442053]]

 [[0.03769981 0.03023671]]]
rmse: 0.0748
------------------------------------------------------------------
seed:2
[[[0.03811246 0.02659883]]

 [[0.12123753 0.10114672]]

 [[0.12049792 0.1155646 ]]

 [[0.03558332 0.0276647 ]]]
rmse: 0.0733
------------------------------------------------------------------
rmse mean: 0.0750
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.001
seed:0
[[[0.02941926 0.02190012]]

 [[0.12392279 0.10164137]]

 [[0.1173564  0.11441444]]

 [[0.04157084 0.03465116]]]
rmse: 0.0731
------------------------------------------------------------------
seed:1
[[[0.04060198 0.02926378]]

 [[0.12505339 0.10313715]]

 [[0.11813234 0.11440112]]

 [[0.03770903 0.03024462]]]
rmse: 0.0748
------------------------------------------------------------------
seed:2
[[[0.03810721 0.02659445]]

 [[0.12124065 0.10114845]]

 [[0.1204958  0.11556406]]

 [[0.03558547 0.02766631]]]
rmse: 0.0733
------------------------------------------------------------------
rmse mean: 0.0737
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.01
seed:0
[[[0.02946129 0.02194745]]

 [[0.12418536 0.10191929]]

 [[0.12275244 0.11654856]]

 [[0.08406701 0.0754165 ]]]
rmse: 0.0845
------------------------------------------------------------------
seed:1
[[[0.02799332 0.02191107]]

 [[0.1169079  0.10129464]]

 [[0.11858284 0.11370497]]

 [[0.04400819 0.03615326]]]
rmse: 0.0726
------------------------------------------------------------------
seed:2
[[[0.03806993 0.02656535]]

 [[0.12126021 0.10115822]]

 [[0.12048764 0.1155645 ]]

 [[0.03560607 0.02768277]]]
rmse: 0.0733
------------------------------------------------------------------
rmse mean: 0.0768
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.0
seed:0
[[[0.02877828 0.02076705]]

 [[0.11959501 0.10069795]]

 [[0.11620682 0.11407502]]

 [[0.03330108 0.02680102]]]
rmse: 0.0700
------------------------------------------------------------------
seed:1
[[[0.02690477 0.0194655 ]]

 [[0.12245217 0.10184842]]

 [[0.11935607 0.11211528]]

 [[0.03825797 0.02997218]]]
rmse: 0.0713
------------------------------------------------------------------
seed:2
[[[0.0299003  0.02144678]]

 [[0.0837366  0.06172886]]

 [[0.10778081 0.10641922]]

 [[0.04022387 0.03260744]]]
rmse: 0.0605
------------------------------------------------------------------
rmse mean: 0.0673
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.0001
seed:0
[[[0.02878104 0.02076662]]

 [[0.11938488 0.1005969 ]]

 [[0.11621012 0.11407747]]

 [[0.03330387 0.02680721]]]
rmse: 0.0700
------------------------------------------------------------------
seed:1
[[[0.0269048  0.01946519]]

 [[0.12245114 0.10184787]]

 [[0.11934827 0.1121097 ]]

 [[0.03825643 0.02997117]]]
rmse: 0.0713
------------------------------------------------------------------
seed:2
[[[0.02989827 0.0214451 ]]

 [[0.10258104 0.08029122]]

 [[0.10778301 0.10642134]]

 [[0.04022348 0.03260725]]]
rmse: 0.0652
------------------------------------------------------------------
rmse mean: 0.0688
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.001
seed:0
[[[0.02877637 0.02076811]]

 [[0.11987507 0.10081864]]

 [[0.11622988 0.11409067]]

 [[0.03332564 0.02682773]]]
rmse: 0.0701
------------------------------------------------------------------
seed:1
[[[0.02691209 0.01947218]]

 [[0.12244313 0.10184332]]

 [[0.11935656 0.1121143 ]]

 [[0.03824477 0.02996104]]]
rmse: 0.0713
------------------------------------------------------------------
seed:2
[[[0.02990001 0.02144575]]

 [[0.12559582 0.10120567]]

 [[0.11982053 0.11455728]]

 [[0.03890149 0.0315395 ]]]
rmse: 0.0729
------------------------------------------------------------------
rmse mean: 0.0714
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.01
seed:0
[[[0.03147262 0.02313356]]

 [[0.1497948  0.13285032]]

 [[0.11912268 0.11471745]]

 [[0.03645958 0.03019969]]]
rmse: 0.0797
------------------------------------------------------------------
seed:1
[[[0.02690515 0.01946839]]

 [[0.1224286  0.10183743]]

 [[0.11937761 0.11211602]]

 [[0.03823884 0.0299538 ]]]
rmse: 0.0713
------------------------------------------------------------------
seed:2
[[[0.02988816 0.02143382]]

 [[0.19580527 0.17158456]]

 [[0.09223147 0.07841167]]

 [[0.03397858 0.02713959]]]
rmse: 0.0813
------------------------------------------------------------------
rmse mean: 0.0774
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.0
seed:0
[[[0.02817362 0.02044576]]

 [[0.1072017  0.08742965]]

 [[0.10970714 0.10798444]]

 [[0.03365299 0.02833721]]]
rmse: 0.0654
------------------------------------------------------------------
seed:1
[[[0.05814715 0.05157099]]

 [[0.11605046 0.09896002]]

 [[0.11986985 0.11474182]]

 [[0.03495282 0.03013161]]]
rmse: 0.0781
------------------------------------------------------------------
seed:2
[[[0.02736449 0.01783038]]

 [[0.11073005 0.09792315]]

 [[0.1119603  0.11026098]]

 [[0.03439724 0.02885342]]]
rmse: 0.0674
------------------------------------------------------------------
rmse mean: 0.0703
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.0001
seed:0
[[[0.02817119 0.02044474]]

 [[0.06667215 0.04952024]]

 [[0.09900787 0.09633028]]

 [[0.03711148 0.02795835]]]
rmse: 0.0532
------------------------------------------------------------------
seed:1
[[[0.05835316 0.05176367]]

 [[0.11605355 0.0989618 ]]

 [[0.11984682 0.11473248]]

 [[0.03495292 0.03013207]]]
rmse: 0.0781
------------------------------------------------------------------
seed:2
[[[0.02736408 0.01783095]]

 [[0.11064845 0.09793233]]

 [[0.1119612  0.11026173]]

 [[0.03439767 0.02885582]]]
rmse: 0.0674
------------------------------------------------------------------
rmse mean: 0.0662
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.001
seed:0
[[[0.02810226 0.0204325 ]]

 [[0.0526023  0.03963614]]

 [[0.10392367 0.10185474]]

 [[0.03430149 0.02748027]]]
rmse: 0.0510
------------------------------------------------------------------
seed:1
[[[0.04720682 0.04091436]]

 [[0.11584211 0.09885655]]

 [[0.11994242 0.11477637]]

 [[0.03495315 0.03013261]]]
rmse: 0.0753
------------------------------------------------------------------
seed:2
[[[0.02736029 0.01784266]]

 [[0.11051619 0.09795723]]

 [[0.11194526 0.11024844]]

 [[0.03439819 0.02887607]]]
rmse: 0.0674
------------------------------------------------------------------
rmse mean: 0.0646
===================================================================
K=32, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.01
seed:0
[[[0.02820679 0.02054257]]

 [[0.07553114 0.065873  ]]

 [[0.1032738  0.10010952]]

 [[0.03480659 0.02790995]]]
rmse: 0.0570
------------------------------------------------------------------
seed:1
[[[0.03290271 0.02336675]]

 [[0.11500637 0.09624963]]

 [[0.04712472 0.03599202]]

 [[0.0340481  0.02830189]]]
rmse: 0.0516
------------------------------------------------------------------
seed:2
[[[0.02735501 0.01853231]]

 [[0.11055308 0.09797399]]

 [[0.11194146 0.11024717]]

 [[0.03441445 0.02905473]]]
rmse: 0.0675
------------------------------------------------------------------
rmse mean: 0.0587
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.0
seed:0
[[[0.08009432 0.07079498]]

 [[0.11452797 0.1000851 ]]

 [[0.10718856 0.1055583 ]]

 [[0.03457296 0.0290619 ]]]
rmse: 0.0802
------------------------------------------------------------------
seed:1
[[[0.03306871 0.02311317]]

 [[0.10969275 0.10160651]]

 [[0.03469823 0.02087735]]

 [[0.03612081 0.02967844]]]
rmse: 0.0486
------------------------------------------------------------------
seed:2
[[[0.03866585 0.02531269]]

 [[0.10420908 0.09114824]]

 [[0.07375155 0.05973231]]

 [[0.03653365 0.02979833]]]
rmse: 0.0574
------------------------------------------------------------------
rmse mean: 0.0621
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.0001
seed:0
[[[0.09134086 0.0802961 ]]

 [[0.11774151 0.10286675]]

 [[0.19066122 0.18881654]]

 [[0.036723   0.02975325]]]
rmse: 0.1048
------------------------------------------------------------------
seed:1
[[[0.12286962 0.10630318]]

 [[0.11636522 0.09505744]]

 [[0.0775115  0.07352999]]

 [[0.10848664 0.09886083]]]
rmse: 0.0999
------------------------------------------------------------------
seed:2
[[[0.03866646 0.02531352]]

 [[0.10423837 0.09116065]]

 [[0.07354805 0.0595899 ]]

 [[0.03653116 0.02979522]]]
rmse: 0.0574
------------------------------------------------------------------
rmse mean: 0.0873
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.001
seed:0
[[[0.11969209 0.10496086]]

 [[0.09317127 0.0832276 ]]

 [[0.16200696 0.16028441]]

 [[0.03672625 0.02975666]]]
rmse: 0.0987
------------------------------------------------------------------
seed:1
[[[0.05332553 0.04355257]]

 [[0.12358263 0.10785988]]

 [[0.07971507 0.07227592]]

 [[0.05662959 0.04870977]]]
rmse: 0.0732
------------------------------------------------------------------
seed:2
[[[0.03872135 0.02539771]]

 [[0.10421489 0.09115757]]

 [[0.07111773 0.05804274]]

 [[0.03627686 0.02941672]]]
rmse: 0.0568
------------------------------------------------------------------
rmse mean: 0.0762
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.01
seed:0
[[[0.08060314 0.07648632]]

 [[0.11074932 0.09553746]]

 [[0.11465498 0.11241772]]

 [[0.04083381 0.03239861]]]
rmse: 0.0830
------------------------------------------------------------------
seed:1
[[[0.03181355 0.02205684]]

 [[0.14670896 0.11902155]]

 [[0.07694172 0.06443237]]

 [[0.03269176 0.027233  ]]]
rmse: 0.0651
------------------------------------------------------------------
seed:2
[[[0.03867183 0.02532355]]

 [[0.10416591 0.09117175]]

 [[0.0747671  0.06042143]]

 [[0.03618965 0.02929296]]]
rmse: 0.0575
------------------------------------------------------------------
rmse mean: 0.0685
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.0
seed:0
[[[0.03530501 0.02862405]]

 [[0.17350142 0.14402984]]

 [[0.10189856 0.08602589]]

 [[0.04118405 0.03251773]]]
rmse: 0.0804
------------------------------------------------------------------
seed:1
[[[0.04994913 0.04018234]]

 [[0.10821003 0.09531913]]

 [[0.1241356  0.11550489]]

 [[0.06668689 0.05906865]]]
rmse: 0.0824
------------------------------------------------------------------
seed:2
[[[0.06707933 0.05903757]]

 [[0.12683107 0.10402174]]

 [[0.0881853  0.07325087]]

 [[0.03629798 0.02865682]]]
rmse: 0.0729
------------------------------------------------------------------
rmse mean: 0.0786
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.0001
seed:0
[[[0.02866846 0.0198248 ]]

 [[0.1243967  0.1072927 ]]

 [[0.1240984  0.1160371 ]]

 [[0.04291282 0.031288  ]]]
rmse: 0.0743
------------------------------------------------------------------
seed:1
[[[0.04996742 0.04020616]]

 [[0.09423426 0.07496568]]

 [[0.12164598 0.11893146]]

 [[0.04285336 0.034165  ]]]
rmse: 0.0721
------------------------------------------------------------------
seed:2
[[[0.0474362  0.03473137]]

 [[0.12708144 0.10419173]]

 [[0.11880305 0.11362185]]

 [[0.03658476 0.02994041]]]
rmse: 0.0765
------------------------------------------------------------------
rmse mean: 0.0743
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.001
seed:0
[[[0.09474565 0.08084509]]

 [[0.09413147 0.07768281]]

 [[0.12889443 0.12000594]]

 [[0.03410907 0.02834052]]]
rmse: 0.0823
------------------------------------------------------------------
seed:1
[[[0.05005546 0.04032507]]

 [[0.11157036 0.09092662]]

 [[0.09676899 0.09484289]]

 [[0.03086497 0.02535631]]]
rmse: 0.0676
------------------------------------------------------------------
seed:2
[[[0.03127372 0.02511563]]

 [[0.12599823 0.10358919]]

 [[0.05623695 0.04643059]]

 [[0.08557419 0.07672727]]]
rmse: 0.0689
------------------------------------------------------------------
rmse mean: 0.0729
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.01
seed:0
[[[0.06170817 0.0556657 ]]

 [[0.07717894 0.05718762]]

 [[0.12921986 0.12055432]]

 [[0.03409909 0.02828678]]]
rmse: 0.0705
------------------------------------------------------------------
seed:1
[[[0.04472121 0.03748534]]

 [[0.17117065 0.14181999]]

 [[0.13078629 0.10186154]]

 [[0.09134565 0.07961371]]]
rmse: 0.0999
------------------------------------------------------------------
seed:2
[[[0.05631064 0.04935311]]

 [[0.04222782 0.03311883]]

 [[0.11082014 0.10902827]]

 [[0.06862355 0.06059818]]]
rmse: 0.0663
------------------------------------------------------------------
rmse mean: 0.0789
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.0
seed:0
[[[0.04003574 0.03444896]]

 [[0.0614964  0.05085729]]

 [[0.12417452 0.11320244]]

 [[0.04697352 0.03414627]]]
rmse: 0.0632
------------------------------------------------------------------
seed:1
[[[0.03951005 0.02771507]]

 [[0.13244452 0.10368577]]

 [[0.14244656 0.138669  ]]

 [[0.09075773 0.07906258]]]
rmse: 0.0943
------------------------------------------------------------------
seed:2
[[[0.04192694 0.03399796]]

 [[0.11212541 0.09703513]]

 [[0.15122043 0.11756265]]

 [[0.04278536 0.03282891]]]
rmse: 0.0787
------------------------------------------------------------------
rmse mean: 0.0787
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.0001
seed:0
[[[0.04333896 0.03806139]]

 [[0.06155912 0.05092739]]

 [[0.12414403 0.11316832]]

 [[0.03496756 0.02820717]]]
rmse: 0.0618
------------------------------------------------------------------
seed:1
[[[0.0387259  0.02886737]]

 [[0.16975712 0.14020933]]

 [[0.15177016 0.11776419]]

 [[0.09166225 0.0799113 ]]]
rmse: 0.1023
------------------------------------------------------------------
seed:2
[[[0.02870476 0.02370987]]

 [[0.11448    0.09480311]]

 [[0.1098233  0.10783695]]

 [[0.03879107 0.03180185]]]
rmse: 0.0687
------------------------------------------------------------------
rmse mean: 0.0776
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.001
seed:0
[[[0.03650796 0.03024951]]

 [[0.09057039 0.08567305]]

 [[0.11225922 0.11044945]]

 [[0.05711131 0.04022641]]]
rmse: 0.0704
------------------------------------------------------------------
seed:1
[[[0.06736514 0.0575497 ]]

 [[0.115891   0.09750134]]

 [[0.1517536  0.11778703]]

 [[0.04916468 0.03446364]]]
rmse: 0.0864
------------------------------------------------------------------
seed:2
[[[0.02738612 0.02189115]]

 [[0.17405096 0.14464659]]

 [[0.15279608 0.11804265]]

 [[0.04097965 0.03497726]]]
rmse: 0.0893
------------------------------------------------------------------
rmse mean: 0.0821
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.01
seed:0
[[[0.02834214 0.02101403]]

 [[0.07479319 0.05992717]]

 [[0.08780156 0.0854498 ]]

 [[0.09026229 0.07860035]]]
rmse: 0.0658
------------------------------------------------------------------
seed:1
[[[0.06612396 0.06215327]]

 [[0.25720677 0.21634269]]

 [[0.04922816 0.04215055]]

 [[0.11059687 0.10135721]]]
rmse: 0.1131
------------------------------------------------------------------
seed:2
[[[0.05639933 0.04710979]]

 [[0.09822803 0.07680245]]

 [[0.03382255 0.02136216]]

 [[0.05504658 0.04915212]]]
rmse: 0.0547
------------------------------------------------------------------
rmse mean: 0.0779
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.0
seed:0
[[[0.03888424 0.02889008]]

 [[0.0771687  0.06780492]]

 [[0.15228963 0.11792239]]

 [[0.08919584 0.0775648 ]]]
rmse: 0.0812
------------------------------------------------------------------
seed:1
[[[0.12820633 0.11107891]]

 [[0.17253452 0.14305571]]

 [[0.14452199 0.11155749]]

 [[0.09155542 0.07981485]]]
rmse: 0.1228
------------------------------------------------------------------
seed:2
[[[0.36424027 0.33847252]]

 [[0.17374214 0.14431627]]

 [[0.15258467 0.1180226 ]]

 [[0.03756397 0.03018649]]]
rmse: 0.1699
------------------------------------------------------------------
rmse mean: 0.1246
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.0001
seed:0
[[[0.02454131 0.01661926]]

 [[0.1653425  0.13659318]]

 [[0.15245204 0.11795057]]

 [[0.09159487 0.07984776]]]
rmse: 0.0981
------------------------------------------------------------------
seed:1
[[[0.12756711 0.11055022]]

 [[0.17253319 0.1430543 ]]

 [[0.14351693 0.1107669 ]]

 [[0.09162938 0.07988395]]]
rmse: 0.1224
------------------------------------------------------------------
seed:2
[[[0.15465344 0.13401143]]

 [[0.17374919 0.14432283]]

 [[0.03565775 0.02664563]]

 [[0.09073911 0.07900813]]]
rmse: 0.1048
------------------------------------------------------------------
rmse mean: 0.1085
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.001
seed:0
[[[0.03117419 0.02382355]]

 [[0.17410449 0.14469369]]

 [[0.08741619 0.07777162]]

 [[0.07438506 0.06364761]]]
rmse: 0.0846
------------------------------------------------------------------
seed:1
[[[0.12980129 0.11230797]]

 [[0.17253696 0.14305821]]

 [[0.14440798 0.11146815]]

 [[0.09159023 0.07984784]]]
rmse: 0.1231
------------------------------------------------------------------
seed:2
[[[0.13515087 0.11657122]]

 [[0.17372428 0.14429963]]

 [[0.15164693 0.11777707]]

 [[0.08653965 0.07501115]]]
rmse: 0.1251
------------------------------------------------------------------
rmse mean: 0.1109
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.01
seed:0
[[[0.04006864 0.03071478]]

 [[0.17159768 0.14211101]]

 [[0.12504466 0.09786275]]

 [[0.08904637 0.07749885]]]
rmse: 0.0967
------------------------------------------------------------------
seed:1
[[[0.1239307  0.10753609]]

 [[0.17255341 0.14307481]]

 [[0.1445406  0.11157212]]

 [[0.83478375 0.82997551]]]
rmse: 0.3085
------------------------------------------------------------------
seed:2
[[[0.13316015 0.11498187]]

 [[0.17371483 0.14429084]]

 [[0.1413848  0.11628676]]

 [[0.08820611 0.076591  ]]]
rmse: 0.1236
------------------------------------------------------------------
rmse mean: 0.1763
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.0
seed:0
[[[0.13492878 0.11650224]]

 [[1.12318337 1.11368448]]

 [[0.15259744 0.1180253 ]]

 [[0.24489741 0.22875491]]]
rmse: 0.4041
------------------------------------------------------------------
seed:1
[[[0.13532018 0.11680274]]

 [[0.17430437 0.14489641]]

 [[0.07441481 0.06608785]]

 [[0.09162731 0.079879  ]]]
rmse: 0.1104
------------------------------------------------------------------
seed:2
[[[0.13538611 0.11685679]]

 [[0.29826643 0.26565148]]

 [[0.15252598 0.11800909]]

 [[0.09138393 0.07965243]]]
rmse: 0.1572
------------------------------------------------------------------
rmse mean: 0.2239
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.0001
seed:0
[[[0.13536913 0.11685137]]

 [[0.81250973 0.79933488]]

 [[0.1525975  0.11802537]]

 [[1.24522313 1.2445803 ]]]
rmse: 0.5781
------------------------------------------------------------------
seed:1
[[[0.135322   0.11680414]]

 [[0.17687858 0.14726829]]

 [[0.15234107 0.11795982]]

 [[0.09166154 0.07991387]]]
rmse: 0.1273
------------------------------------------------------------------
seed:2
[[[0.13536181 0.11669613]]

 [[0.15038517 0.12341251]]

 [[0.15249813 0.11799876]]

 [[0.09672982 0.08569219]]]
rmse: 0.1223
------------------------------------------------------------------
rmse mean: 0.2759
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.001
seed:0
[[[0.13530535 0.11683623]]

 [[0.15539487 0.12452674]]

 [[0.15259528 0.11802525]]

 [[0.17435326 0.15742939]]]
rmse: 0.1418
------------------------------------------------------------------
seed:1
[[[0.13532208 0.11680443]]

 [[0.19698206 0.16629249]]

 [[0.15258842 0.11802353]]

 [[0.09169007 0.07994089]]]
rmse: 0.1322
------------------------------------------------------------------
seed:2
[[[0.13530978 0.11678991]]

 [[0.16012094 0.13210414]]

 [[1.22893349 1.22770396]]

 [[0.09156104 0.07981838]]]
rmse: 0.3965
------------------------------------------------------------------
rmse mean: 0.2235
===================================================================
K=32, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.01
seed:0
[[[0.13523804 0.11673454]]

 [[0.51998594 0.49916238]]

 [[0.15259828 0.11802603]]

 [[0.09019319 0.07854182]]]
rmse: 0.2138
------------------------------------------------------------------
seed:1
[[[0.13532256 0.11680472]]

 [[0.17407283 0.1446626 ]]

 [[0.15242786 0.11797709]]

 [[0.05502311 0.04560162]]]
rmse: 0.1177
------------------------------------------------------------------
seed:2
[[[0.11936952 0.10326077]]

 [[0.15010267 0.12103065]]

 [[0.15249209 0.11799785]]

 [[0.0432306  0.03638312]]]
rmse: 0.1055
------------------------------------------------------------------
rmse mean: 0.1457
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.0
seed:0
[[[0.05098472 0.04276758]]

 [[0.09600758 0.0794247 ]]

 [[0.06755043 0.05467814]]

 [[0.12221486 0.11042029]]]
rmse: 0.0780
------------------------------------------------------------------
seed:1
[[[0.05877801 0.04910229]]

 [[0.09818773 0.08178811]]

 [[0.11026765 0.09767024]]

 [[0.04167184 0.03336008]]]
rmse: 0.0714
------------------------------------------------------------------
seed:2
[[[0.0442943  0.03412713]]

 [[0.11015845 0.09607128]]

 [[0.09800475 0.09127957]]

 [[0.04501422 0.03581753]]]
rmse: 0.0693
------------------------------------------------------------------
rmse mean: 0.0729
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.0001
seed:0
[[[0.05077751 0.04251366]]

 [[0.09598106 0.07939761]]

 [[0.06752955 0.05466063]]

 [[0.0534813  0.04217609]]]
rmse: 0.0608
------------------------------------------------------------------
seed:1
[[[0.05875085 0.04907672]]

 [[0.0981793  0.08178157]]

 [[0.11022817 0.0976431 ]]

 [[0.04166632 0.03335171]]]
rmse: 0.0713
------------------------------------------------------------------
seed:2
[[[0.04427213 0.03410651]]

 [[0.1101698  0.09608272]]

 [[0.09799347 0.09126954]]

 [[0.0450017  0.03580853]]]
rmse: 0.0693
------------------------------------------------------------------
rmse mean: 0.0672
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.001
seed:0
[[[0.05066432 0.04242827]]

 [[0.09594432 0.07936853]]

 [[0.06737565 0.054531  ]]

 [[0.05348851 0.04217402]]]
rmse: 0.0607
------------------------------------------------------------------
seed:1
[[[0.05853626 0.04887561]]

 [[0.11985637 0.10807955]]

 [[0.0974224  0.09130866]]

 [[0.04723226 0.03925207]]]
rmse: 0.0763
------------------------------------------------------------------
seed:2
[[[0.0440651  0.03391645]]

 [[0.11021944 0.09614426]]

 [[0.0978677  0.09115889]]

 [[0.04500538 0.03582198]]]
rmse: 0.0693
------------------------------------------------------------------
rmse mean: 0.0688
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=16, alpha=0.01
seed:0
[[[0.03685799 0.03019501]]

 [[0.11386028 0.10293835]]

 [[0.08821089 0.07970896]]

 [[0.04038552 0.03335394]]]
rmse: 0.0657
------------------------------------------------------------------
seed:1
[[[0.03815738 0.03120363]]

 [[0.11862058 0.10823146]]

 [[0.10527133 0.10008521]]

 [[0.04778707 0.03833509]]]
rmse: 0.0735
------------------------------------------------------------------
seed:2
[[[0.03948283 0.03135297]]

 [[0.1198187  0.10599119]]

 [[0.08937806 0.07994836]]

 [[0.04135233 0.03248522]]]
rmse: 0.0675
------------------------------------------------------------------
rmse mean: 0.0689
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.0
seed:0
[[[0.05169753 0.04234459]]

 [[0.10595942 0.0876955 ]]

 [[0.10185956 0.09509583]]

 [[0.04532444 0.03547854]]]
rmse: 0.0707
------------------------------------------------------------------
seed:1
[[[0.05262353 0.04255119]]

 [[0.07360899 0.05603982]]

 [[0.12207549 0.11193649]]

 [[0.04645639 0.03804143]]]
rmse: 0.0679
------------------------------------------------------------------
seed:2
[[[0.04047459 0.02940543]]

 [[0.10403363 0.08123336]]

 [[0.09049191 0.08296065]]

 [[0.04515785 0.03720792]]]
rmse: 0.0639
------------------------------------------------------------------
rmse mean: 0.0675
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.0001
seed:0
[[[0.05163353 0.04230072]]

 [[0.10594974 0.08769937]]

 [[0.10184491 0.09508687]]

 [[0.04532749 0.03549105]]]
rmse: 0.0707
------------------------------------------------------------------
seed:1
[[[0.05260909 0.04253976]]

 [[0.07360199 0.05603402]]

 [[0.12205247 0.11194445]]

 [[0.04643255 0.03802264]]]
rmse: 0.0679
------------------------------------------------------------------
seed:2
[[[0.04043092 0.02934432]]

 [[0.10402718 0.08122648]]

 [[0.09047966 0.08294699]]

 [[0.04514254 0.03719252]]]
rmse: 0.0638
------------------------------------------------------------------
rmse mean: 0.0675
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.001
seed:0
[[[0.05118481 0.04197493]]

 [[0.10585027 0.08769   ]]

 [[0.10489658 0.10051587]]

 [[0.04678615 0.03746708]]]
rmse: 0.0720
------------------------------------------------------------------
seed:1
[[[0.05258218 0.04250856]]

 [[0.11552752 0.09840474]]

 [[0.11283921 0.10566336]]

 [[0.04126516 0.03510692]]]
rmse: 0.0755
------------------------------------------------------------------
seed:2
[[[0.0398753  0.0287567 ]]

 [[0.10597388 0.08320184]]

 [[0.10302084 0.0961635 ]]

 [[0.03767741 0.03107939]]]
rmse: 0.0657
------------------------------------------------------------------
rmse mean: 0.0711
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=32, alpha=0.01
seed:0
[[[0.04279159 0.03246049]]

 [[0.11466752 0.10306441]]

 [[0.10572639 0.10074272]]

 [[0.04202993 0.03341409]]]
rmse: 0.0719
------------------------------------------------------------------
seed:1
[[[0.04324279 0.03480348]]

 [[0.11323961 0.0995365 ]]

 [[0.08086208 0.0713449 ]]

 [[0.04783617 0.03538729]]]
rmse: 0.0658
------------------------------------------------------------------
seed:2
[[[0.0464385  0.03468574]]

 [[0.10774911 0.09740442]]

 [[0.06992424 0.06447069]]

 [[0.04620434 0.03823838]]]
rmse: 0.0631
------------------------------------------------------------------
rmse mean: 0.0669
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.0
seed:0
[[[0.03670643 0.02819379]]

 [[0.10988656 0.09719897]]

 [[0.09300626 0.08403473]]

 [[0.05536105 0.04878742]]]
rmse: 0.0691
------------------------------------------------------------------
seed:1
[[[0.04073223 0.03032508]]

 [[0.11945124 0.09799228]]

 [[0.09127979 0.08256678]]

 [[0.04734036 0.03753061]]]
rmse: 0.0684
------------------------------------------------------------------
seed:2
[[[0.03548434 0.02671929]]

 [[0.10366142 0.08358867]]

 [[0.08818144 0.08350187]]

 [[0.04796443 0.03966593]]]
rmse: 0.0636
------------------------------------------------------------------
rmse mean: 0.0670
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.0001
seed:0
[[[0.03669312 0.02817953]]

 [[0.10990805 0.09722888]]

 [[0.09299088 0.08402374]]

 [[0.05535445 0.04878027]]]
rmse: 0.0691
------------------------------------------------------------------
seed:1
[[[0.04087701 0.03045845]]

 [[0.11946422 0.09800677]]

 [[0.09126369 0.08255158]]

 [[0.04731171 0.03751293]]]
rmse: 0.0684
------------------------------------------------------------------
seed:2
[[[0.03548095 0.02672069]]

 [[0.1036414  0.08358859]]

 [[0.0881868  0.08350862]]

 [[0.04796633 0.03966909]]]
rmse: 0.0636
------------------------------------------------------------------
rmse mean: 0.0671
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.001
seed:0
[[[0.03667429 0.02813511]]

 [[0.10993467 0.09733432]]

 [[0.09280896 0.08386256]]

 [[0.05422303 0.04410783]]]
rmse: 0.0684
------------------------------------------------------------------
seed:1
[[[0.04096033 0.0305419 ]]

 [[0.102199   0.08630312]]

 [[0.1311833  0.12573097]]

 [[0.04303962 0.03269415]]]
rmse: 0.0741
------------------------------------------------------------------
seed:2
[[[0.03565872 0.02710106]]

 [[0.10345679 0.08358997]]

 [[0.0880117  0.08333881]]

 [[0.04800478 0.03970424]]]
rmse: 0.0636
------------------------------------------------------------------
rmse mean: 0.0687
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=64, alpha=0.01
seed:0
[[[0.03655153 0.02787728]]

 [[0.11437223 0.10419591]]

 [[0.08838801 0.07965342]]

 [[0.04009783 0.03395621]]]
rmse: 0.0656
------------------------------------------------------------------
seed:1
[[[0.03724617 0.02838967]]

 [[0.11880371 0.11118556]]

 [[0.08775561 0.08141276]]

 [[0.04626591 0.03713495]]]
rmse: 0.0685
------------------------------------------------------------------
seed:2
[[[0.0361188  0.027763  ]]

 [[0.10028025 0.09131633]]

 [[0.09272802 0.08593964]]

 [[0.04892546 0.0402371 ]]]
rmse: 0.0654
------------------------------------------------------------------
rmse mean: 0.0665
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.0
seed:0
[[[0.06203738 0.04864598]]

 [[0.0523472  0.04146916]]

 [[0.08794052 0.08064858]]

 [[0.04257162 0.03320581]]]
rmse: 0.0561
------------------------------------------------------------------
seed:1
[[[0.04990725 0.03772319]]

 [[0.10416311 0.08839075]]

 [[0.10253079 0.08939285]]

 [[0.04386767 0.03492364]]]
rmse: 0.0689
------------------------------------------------------------------
seed:2
[[[0.06052295 0.04729338]]

 [[0.08199468 0.06492012]]

 [[0.11268772 0.1075185 ]]

 [[0.04662345 0.03599614]]]
rmse: 0.0697
------------------------------------------------------------------
rmse mean: 0.0649
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.0001
seed:0
[[[0.06193411 0.04856139]]

 [[0.0523589  0.04147292]]

 [[0.08794983 0.08065584]]

 [[0.04257217 0.03320357]]]
rmse: 0.0561
------------------------------------------------------------------
seed:1
[[[0.0498959  0.03771672]]

 [[0.10414915 0.08837913]]

 [[0.10251225 0.08938217]]

 [[0.04386372 0.03492017]]]
rmse: 0.0689
------------------------------------------------------------------
seed:2
[[[0.06052918 0.04729754]]

 [[0.08198332 0.06491245]]

 [[0.11266963 0.10749787]]

 [[0.0465714  0.03594746]]]
rmse: 0.0697
------------------------------------------------------------------
rmse mean: 0.0649
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.001
seed:0
[[[0.06105787 0.04784788]]

 [[0.05234853 0.04144567]]

 [[0.08791079 0.08061061]]

 [[0.04258491 0.03315338]]]
rmse: 0.0559
------------------------------------------------------------------
seed:1
[[[0.04981988 0.03767547]]

 [[0.10408283 0.08831488]]

 [[0.10232074 0.08924727]]

 [[0.04382087 0.0348835 ]]]
rmse: 0.0688
------------------------------------------------------------------
seed:2
[[[0.04454472 0.03418018]]

 [[0.12023964 0.10497893]]

 [[0.10701063 0.09807094]]

 [[0.03678514 0.02995639]]]
rmse: 0.0720
------------------------------------------------------------------
rmse mean: 0.0655
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=128, alpha=0.01
seed:0
[[[0.0416844  0.03376896]]

 [[0.11646112 0.10604757]]

 [[0.10254181 0.09598026]]

 [[0.05068271 0.04031097]]]
rmse: 0.0734
------------------------------------------------------------------
seed:1
[[[0.04055925 0.03236542]]

 [[0.11878614 0.11061179]]

 [[0.09779138 0.08876866]]

 [[0.0448521  0.03808753]]]
rmse: 0.0715
------------------------------------------------------------------
seed:2
[[[0.03929944 0.02905764]]

 [[0.0945357  0.08176639]]

 [[0.0849602  0.07759757]]

 [[0.04463626 0.03598173]]]
rmse: 0.0610
------------------------------------------------------------------
rmse mean: 0.0686
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.0
seed:0
[[[0.039703   0.03189537]]

 [[0.1195124  0.0996314 ]]

 [[0.11623015 0.11084332]]

 [[0.04541164 0.03479073]]]
rmse: 0.0748
------------------------------------------------------------------
seed:1
[[[0.03449727 0.02722601]]

 [[0.10721476 0.08841668]]

 [[0.1177029  0.11400171]]

 [[0.04270915 0.03422874]]]
rmse: 0.0707
------------------------------------------------------------------
seed:2
[[[0.04195985 0.0321877 ]]

 [[0.15576733 0.13417669]]

 [[0.10076142 0.0962698 ]]

 [[0.04269693 0.03433189]]]
rmse: 0.0798
------------------------------------------------------------------
rmse mean: 0.0751
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.0001
seed:0
[[[0.03967852 0.03187941]]

 [[0.11949878 0.09963563]]

 [[0.11622549 0.11083924]]

 [[0.04541099 0.03478959]]]
rmse: 0.0747
------------------------------------------------------------------
seed:1
[[[0.0345171  0.02725136]]

 [[0.10721846 0.08844785]]

 [[0.11766395 0.11396898]]

 [[0.04269731 0.03421991]]]
rmse: 0.0707
------------------------------------------------------------------
seed:2
[[[0.04190986 0.03214667]]

 [[0.15576884 0.13417804]]

 [[0.10080584 0.09631082]]

 [[0.04269561 0.03433465]]]
rmse: 0.0798
------------------------------------------------------------------
rmse mean: 0.0751
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.001
seed:0
[[[0.03943322 0.03170657]]

 [[0.12033269 0.10342121]]

 [[0.11847298 0.11463163]]

 [[0.04040908 0.03166514]]]
rmse: 0.0750
------------------------------------------------------------------
seed:1
[[[0.03461889 0.02736739]]

 [[0.10710767 0.08850171]]

 [[0.11734622 0.11371198]]

 [[0.04258547 0.03413637]]]
rmse: 0.0707
------------------------------------------------------------------
seed:2
[[[0.04159864 0.03184877]]

 [[0.15582008 0.13423553]]

 [[0.10093781 0.09643068]]

 [[0.0426261  0.03431304]]]
rmse: 0.0797
------------------------------------------------------------------
rmse mean: 0.0751
===================================================================
K=64, lr=0.001, num_layers=1, hidden_dim=256, alpha=0.01
seed:0
[[[0.0374298  0.0303026 ]]

 [[0.12259533 0.11325756]]

 [[0.10988338 0.10688854]]

 [[0.04270131 0.03386448]]]
rmse: 0.0746
------------------------------------------------------------------
seed:1
[[[0.03572391 0.02802735]]

 [[0.10298737 0.08985843]]

 [[0.0985839  0.09445266]]

 [[0.04674831 0.03878   ]]]
rmse: 0.0669
------------------------------------------------------------------
seed:2
[[[0.03836771 0.02865967]]

 [[0.10486469 0.09413264]]

 [[0.10067199 0.09468164]]

 [[0.04689026 0.0388564 ]]]
rmse: 0.0684
------------------------------------------------------------------
rmse mean: 0.0700
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.0
seed:0
[[[0.03424598 0.02712757]]

 [[0.10141319 0.08172909]]

 [[0.10569061 0.10082323]]

 [[0.04307516 0.03349352]]]
rmse: 0.0659
------------------------------------------------------------------
seed:1
[[[0.04965943 0.04101911]]

 [[0.11794168 0.10442629]]

 [[0.09372701 0.08470727]]

 [[0.03720286 0.02938672]]]
rmse: 0.0698
------------------------------------------------------------------
seed:2
[[[0.06346486 0.05301039]]

 [[0.11156384 0.09120748]]

 [[0.11103752 0.10533885]]

 [[0.03716666 0.02931393]]]
rmse: 0.0753
------------------------------------------------------------------
rmse mean: 0.0703
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.0001
seed:0
[[[0.03419798 0.02709127]]

 [[0.1013306  0.0816523 ]]

 [[0.10567412 0.10080836]]

 [[0.04331441 0.0337477 ]]]
rmse: 0.0660
------------------------------------------------------------------
seed:1
[[[0.04960571 0.04096709]]

 [[0.11785    0.10434078]]

 [[0.09373116 0.08471192]]

 [[0.03720469 0.02939403]]]
rmse: 0.0697
------------------------------------------------------------------
seed:2
[[[0.06334302 0.05292152]]

 [[0.1115611  0.09119911]]

 [[0.11148923 0.10576784]]

 [[0.03719856 0.02935429]]]
rmse: 0.0754
------------------------------------------------------------------
rmse mean: 0.0704
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.001
seed:0
[[[0.03384629 0.02680185]]

 [[0.102246   0.0825189 ]]

 [[0.10551698 0.10066729]]

 [[0.04305471 0.03345719]]]
rmse: 0.0660
------------------------------------------------------------------
seed:1
[[[0.04914725 0.04055345]]

 [[0.11812028 0.10465583]]

 [[0.10506242 0.09931913]]

 [[0.04600037 0.03798498]]]
rmse: 0.0751
------------------------------------------------------------------
seed:2
[[[0.06230846 0.05210299]]

 [[0.10089972 0.08378302]]

 [[0.11758594 0.11293986]]

 [[0.05360502 0.04368719]]]
rmse: 0.0784
------------------------------------------------------------------
rmse mean: 0.0732
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=16, alpha=0.01
seed:0
[[[0.0331633  0.0261546 ]]

 [[0.10748905 0.09231259]]

 [[0.07913882 0.07278522]]

 [[0.04062729 0.0338253 ]]]
rmse: 0.0607
------------------------------------------------------------------
seed:1
[[[0.04175225 0.03356631]]

 [[0.12236876 0.11316135]]

 [[0.07876922 0.07000567]]

 [[0.03863121 0.03255859]]]
rmse: 0.0664
------------------------------------------------------------------
seed:2
[[[0.05359664 0.04422841]]

 [[0.12542831 0.11845257]]

 [[0.08980237 0.08274498]]

 [[0.03717835 0.0294164 ]]]
rmse: 0.0726
------------------------------------------------------------------
rmse mean: 0.0665
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.0
seed:0
[[[0.04594538 0.0329802 ]]

 [[0.11111422 0.09743902]]

 [[0.11448983 0.10689902]]

 [[0.05100246 0.04156651]]]
rmse: 0.0752
------------------------------------------------------------------
seed:1
[[[0.04472407 0.03716676]]

 [[0.12782842 0.1016594 ]]

 [[0.09706466 0.09225743]]

 [[0.04282898 0.03335226]]]
rmse: 0.0721
------------------------------------------------------------------
seed:2
[[[0.02985312 0.02310162]]

 [[0.126257   0.10239199]]

 [[0.11050846 0.10762692]]

 [[0.04106133 0.03170644]]]
rmse: 0.0716
------------------------------------------------------------------
rmse mean: 0.0730
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.0001
seed:0
[[[0.04588474 0.03292509]]

 [[0.11117666 0.09748249]]

 [[0.11449128 0.10690482]]

 [[0.05099053 0.04155573]]]
rmse: 0.0752
------------------------------------------------------------------
seed:1
[[[0.04468761 0.03713657]]

 [[0.12780849 0.10164452]]

 [[0.0970646  0.09225694]]

 [[0.0428281  0.03335234]]]
rmse: 0.0721
------------------------------------------------------------------
seed:2
[[[0.03029499 0.02233293]]

 [[0.11776955 0.09174975]]

 [[0.08303644 0.07912674]]

 [[0.04831628 0.03777896]]]
rmse: 0.0638
------------------------------------------------------------------
rmse mean: 0.0704
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.001
seed:0
[[[0.04535306 0.03243862]]

 [[0.11056564 0.09762123]]

 [[0.11567701 0.10829871]]

 [[0.04387428 0.03436129]]]
rmse: 0.0735
------------------------------------------------------------------
seed:1
[[[0.04441542 0.03691904]]

 [[0.12921866 0.10066668]]

 [[0.10373414 0.09927638]]

 [[0.04915841 0.03926904]]]
rmse: 0.0753
------------------------------------------------------------------
seed:2
[[[0.04469604 0.03231792]]

 [[0.11653579 0.10143973]]

 [[0.10687424 0.10215836]]

 [[0.04115026 0.03077142]]]
rmse: 0.0720
------------------------------------------------------------------
rmse mean: 0.0736
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=32, alpha=0.01
seed:0
[[[0.04076638 0.03287688]]

 [[0.11573152 0.09391088]]

 [[0.08613877 0.08027364]]

 [[0.03729039 0.03174252]]]
rmse: 0.0648
------------------------------------------------------------------
seed:1
[[[0.03502014 0.02806218]]

 [[0.10638187 0.09490158]]

 [[0.10052199 0.09657543]]

 [[0.03893901 0.03028909]]]
rmse: 0.0663
------------------------------------------------------------------
seed:2
[[[0.0353141  0.02734915]]

 [[0.11207355 0.10404815]]

 [[0.1061825  0.1010897 ]]

 [[0.042063   0.0360296 ]]]
rmse: 0.0705
------------------------------------------------------------------
rmse mean: 0.0672
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.0
seed:0
[[[0.0366537  0.03058234]]

 [[0.1161869  0.09432066]]

 [[0.09846039 0.09265463]]

 [[0.04257441 0.03205512]]]
rmse: 0.0679
------------------------------------------------------------------
seed:1
[[[0.0382017  0.02833796]]

 [[0.12226177 0.0999878 ]]

 [[0.1310305  0.12506729]]

 [[0.04307198 0.03416497]]]
rmse: 0.0778
------------------------------------------------------------------
seed:2
[[[0.0442897  0.03212302]]

 [[0.11362128 0.09132415]]

 [[0.10340927 0.10014616]]

 [[0.04100023 0.03156522]]]
rmse: 0.0697
------------------------------------------------------------------
rmse mean: 0.0718
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.0001
seed:0
[[[0.03664247 0.03057497]]

 [[0.11611726 0.09427671]]

 [[0.09845529 0.09265049]]

 [[0.04255882 0.03204487]]]
rmse: 0.0679
------------------------------------------------------------------
seed:1
[[[0.03818365 0.02833002]]

 [[0.12227764 0.10000647]]

 [[0.13095296 0.12498095]]

 [[0.04305631 0.03415375]]]
rmse: 0.0777
------------------------------------------------------------------
seed:2
[[[0.0442261  0.03204954]]

 [[0.11358205 0.09132257]]

 [[0.10336191 0.10009423]]

 [[0.04097692 0.03153675]]]
rmse: 0.0696
------------------------------------------------------------------
rmse mean: 0.0718
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.001
seed:0
[[[0.03656187 0.03052225]]

 [[0.11552191 0.09391175]]

 [[0.09848516 0.09268557]]

 [[0.04240308 0.03193809]]]
rmse: 0.0678
------------------------------------------------------------------
seed:1
[[[0.03803487 0.02826201]]

 [[0.12244396 0.10022702]]

 [[0.09860639 0.09409676]]

 [[0.05133843 0.03877602]]]
rmse: 0.0715
------------------------------------------------------------------
seed:2
[[[0.04319592 0.03104077]]

 [[0.11307908 0.09126224]]

 [[0.10292124 0.09962715]]

 [[0.04135211 0.03192134]]]
rmse: 0.0693
------------------------------------------------------------------
rmse mean: 0.0695
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=64, alpha=0.01
seed:0
[[[0.03617851 0.03041526]]

 [[0.11069474 0.09140405]]

 [[0.09682359 0.09295161]]

 [[0.03845475 0.03160974]]]
rmse: 0.0661
------------------------------------------------------------------
seed:1
[[[0.037155   0.02769948]]

 [[0.11879505 0.09892065]]

 [[0.0757914  0.06673987]]

 [[0.03897657 0.0311225 ]]]
rmse: 0.0619
------------------------------------------------------------------
seed:2
[[[0.03864114 0.02667742]]

 [[0.11159074 0.09141107]]

 [[0.08629926 0.08054817]]

 [[0.03919115 0.03153964]]]
rmse: 0.0632
------------------------------------------------------------------
rmse mean: 0.0637
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.0
seed:0
[[[0.03250204 0.02556184]]

 [[0.11326178 0.09660289]]

 [[0.11856512 0.11439748]]

 [[0.03798483 0.03005294]]]
rmse: 0.0711
------------------------------------------------------------------
seed:1
[[[0.04825574 0.03523093]]

 [[0.11015243 0.08595656]]

 [[0.12104885 0.11668509]]

 [[0.04783969 0.03871704]]]
rmse: 0.0755
------------------------------------------------------------------
seed:2
[[[0.034149   0.02563278]]

 [[0.10680694 0.08774474]]

 [[0.11098343 0.10657549]]

 [[0.04281891 0.03413618]]]
rmse: 0.0686
------------------------------------------------------------------
rmse mean: 0.0717
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.0001
seed:0
[[[0.03245984 0.02551783]]

 [[0.11340096 0.09667947]]

 [[0.11853635 0.11437557]]

 [[0.03798587 0.03005607]]]
rmse: 0.0711
------------------------------------------------------------------
seed:1
[[[0.04822933 0.03519988]]

 [[0.11082734 0.0863192 ]]

 [[0.12105642 0.11669005]]

 [[0.04783749 0.03871628]]]
rmse: 0.0756
------------------------------------------------------------------
seed:2
[[[0.03416625 0.02564638]]

 [[0.10664352 0.08765944]]

 [[0.1109709  0.10657015]]

 [[0.04279849 0.03411812]]]
rmse: 0.0686
------------------------------------------------------------------
rmse mean: 0.0718
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.001
seed:0
[[[0.03215933 0.02521872]]

 [[0.11337293 0.09669653]]

 [[0.11836526 0.11430368]]

 [[0.03796504 0.03005437]]]
rmse: 0.0710
------------------------------------------------------------------
seed:1
[[[0.04801229 0.03494794]]

 [[0.12891653 0.10106248]]

 [[0.06348454 0.0570968 ]]

 [[0.04093099 0.03217143]]]
rmse: 0.0633
------------------------------------------------------------------
seed:2
[[[0.03429896 0.02574906]]

 [[0.10596791 0.08733765]]

 [[0.11086148 0.10649746]]

 [[0.04257554 0.03392174]]]
rmse: 0.0684
------------------------------------------------------------------
rmse mean: 0.0676
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=128, alpha=0.01
seed:0
[[[0.05005502 0.03953847]]

 [[0.11274326 0.09724666]]

 [[0.11249882 0.10758939]]

 [[0.03898647 0.03159995]]]
rmse: 0.0738
------------------------------------------------------------------
seed:1
[[[0.04639825 0.03334006]]

 [[0.10524064 0.08484893]]

 [[0.11502022 0.11176383]]

 [[0.03814866 0.03179027]]]
rmse: 0.0708
------------------------------------------------------------------
seed:2
[[[0.03521196 0.0266006 ]]

 [[0.1165129  0.09578937]]

 [[0.10049583 0.09799242]]

 [[0.04085799 0.03348853]]]
rmse: 0.0684
------------------------------------------------------------------
rmse mean: 0.0710
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.0
seed:0
[[[0.03107598 0.02281557]]

 [[0.11190404 0.09063903]]

 [[0.12034818 0.1153186 ]]

 [[0.047423   0.04005721]]]
rmse: 0.0724
------------------------------------------------------------------
seed:1
[[[0.03748881 0.02763224]]

 [[0.11210609 0.0984066 ]]

 [[0.12793305 0.11560767]]

 [[0.03900482 0.03091919]]]
rmse: 0.0736
------------------------------------------------------------------
seed:2
[[[0.02944722 0.02190771]]

 [[0.12631375 0.10194289]]

 [[0.0724766  0.06576011]]

 [[0.03995225 0.03252666]]]
rmse: 0.0613
------------------------------------------------------------------
rmse mean: 0.0691
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.0001
seed:0
[[[0.02857166 0.02132435]]

 [[0.10926924 0.09051808]]

 [[0.11682264 0.11116308]]

 [[0.03762487 0.02957679]]]
rmse: 0.0681
------------------------------------------------------------------
seed:1
[[[0.03742154 0.02758757]]

 [[0.11205667 0.09839977]]

 [[0.12788573 0.11558439]]

 [[0.03891517 0.03090637]]]
rmse: 0.0736
------------------------------------------------------------------
seed:2
[[[0.02941852 0.02189089]]

 [[0.1262042  0.10187819]]

 [[0.07128479 0.06460331]]

 [[0.03995265 0.03252752]]]
rmse: 0.0610
------------------------------------------------------------------
rmse mean: 0.0676
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.001
seed:0
[[[0.03035979 0.02235192]]

 [[0.1106132  0.09651014]]

 [[0.11303625 0.11026651]]

 [[0.04718788 0.03981019]]]
rmse: 0.0713
------------------------------------------------------------------
seed:1
[[[0.03704098 0.0272678 ]]

 [[0.11827173 0.09953184]]

 [[0.11829531 0.11353938]]

 [[0.03938571 0.03090294]]]
rmse: 0.0730
------------------------------------------------------------------
seed:2
[[[0.02898147 0.02146164]]

 [[0.12569821 0.10159445]]

 [[0.11511447 0.11214428]]

 [[0.03626908 0.02879375]]]
rmse: 0.0713
------------------------------------------------------------------
rmse mean: 0.0719
===================================================================
K=64, lr=0.001, num_layers=2, hidden_dim=256, alpha=0.01
seed:0
[[[0.03617729 0.02809219]]

 [[0.11695542 0.09710897]]

 [[0.10317124 0.10119966]]

 [[0.04002701 0.03140069]]]
rmse: 0.0693
------------------------------------------------------------------
seed:1
[[[0.03435514 0.02500453]]

 [[0.11424118 0.09876574]]

 [[0.11534255 0.11032772]]

 [[0.0400623  0.03063305]]]
rmse: 0.0711
------------------------------------------------------------------
seed:2
[[[0.03324239 0.02477723]]

 [[0.12275564 0.10039122]]

 [[0.10743358 0.10485397]]

 [[0.038286   0.03128075]]]
rmse: 0.0704
------------------------------------------------------------------
rmse mean: 0.0702
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.0
seed:0
[[[0.05416025 0.04458662]]

 [[0.12981073 0.10702106]]

 [[0.08549915 0.07437667]]

 [[0.04343487 0.0357731 ]]]
rmse: 0.0718
------------------------------------------------------------------
seed:1
[[[0.04613388 0.03657698]]

 [[0.12010444 0.09843215]]

 [[0.11437803 0.1101861 ]]

 [[0.04123893 0.03097258]]]
rmse: 0.0748
------------------------------------------------------------------
seed:2
[[[0.04256513 0.03276562]]

 [[0.10759051 0.08589018]]

 [[0.12222746 0.11657763]]

 [[0.03907658 0.03123393]]]
rmse: 0.0722
------------------------------------------------------------------
rmse mean: 0.0729
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.0001
seed:0
[[[0.05416033 0.04458691]]

 [[0.12980939 0.10701981]]

 [[0.08549742 0.07437478]]

 [[0.0434347  0.03577301]]]
rmse: 0.0718
------------------------------------------------------------------
seed:1
[[[0.04613385 0.03657717]]

 [[0.12010273 0.09843002]]

 [[0.11437751 0.11018569]]

 [[0.04123744 0.03097122]]]
rmse: 0.0748
------------------------------------------------------------------
seed:2
[[[0.04256497 0.03276544]]

 [[0.10762946 0.08592869]]

 [[0.1222278  0.11657807]]

 [[0.03907354 0.03122938]]]
rmse: 0.0722
------------------------------------------------------------------
rmse mean: 0.0729
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.001
seed:0
[[[0.05416044 0.04458685]]

 [[0.12981163 0.10702147]]

 [[0.08548506 0.07436218]]

 [[0.04343387 0.03577241]]]
rmse: 0.0718
------------------------------------------------------------------
seed:1
[[[0.04613316 0.03657819]]

 [[0.12002018 0.0983492 ]]

 [[0.1143743  0.11018413]]

 [[0.04123663 0.03096984]]]
rmse: 0.0747
------------------------------------------------------------------
seed:2
[[[0.04256349 0.03276403]]

 [[0.10778416 0.08608419]]

 [[0.12222893 0.11657927]]

 [[0.03904692 0.0311896 ]]]
rmse: 0.0723
------------------------------------------------------------------
rmse mean: 0.0729
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=16, alpha=0.01
seed:0
[[[0.05417708 0.04461597]]

 [[0.12980873 0.10701978]]

 [[0.08534235 0.07421067]]

 [[0.04342037 0.03576007]]]
rmse: 0.0718
------------------------------------------------------------------
seed:1
[[[0.04612851 0.0365825 ]]

 [[0.1188448  0.09716753]]

 [[0.11435484 0.11017418]]

 [[0.04117263 0.03089726]]]
rmse: 0.0744
------------------------------------------------------------------
seed:2
[[[0.04254952 0.03275053]]

 [[0.10962657 0.0879476 ]]

 [[0.12224874 0.11660168]]

 [[0.03881231 0.03087   ]]]
rmse: 0.0727
------------------------------------------------------------------
rmse mean: 0.0730
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.0
seed:0
[[[0.06736412 0.05616719]]

 [[0.10573963 0.09384666]]

 [[0.11503607 0.10975465]]

 [[0.04230327 0.03273379]]]
rmse: 0.0779
------------------------------------------------------------------
seed:1
[[[0.02760684 0.01754064]]

 [[0.12608249 0.10327259]]

 [[0.11794989 0.11162065]]

 [[0.04057206 0.03277273]]]
rmse: 0.0722
------------------------------------------------------------------
seed:2
[[[0.06040302 0.04991287]]

 [[0.12151095 0.10081537]]

 [[0.11485777 0.10972547]]

 [[0.04246872 0.03408634]]]
rmse: 0.0792
------------------------------------------------------------------
rmse mean: 0.0764
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.0001
seed:0
[[[0.06736403 0.05616707]]

 [[0.10573667 0.09384201]]

 [[0.1150343  0.1097526 ]]

 [[0.04230295 0.03273357]]]
rmse: 0.0779
------------------------------------------------------------------
seed:1
[[[0.02760984 0.01754195]]

 [[0.12609289 0.10330333]]

 [[0.11795002 0.11162038]]

 [[0.04057013 0.0327703 ]]]
rmse: 0.0722
------------------------------------------------------------------
seed:2
[[[0.06040293 0.0499129 ]]

 [[0.12151155 0.10081592]]

 [[0.11485868 0.10972638]]

 [[0.04246852 0.03408608]]]
rmse: 0.0792
------------------------------------------------------------------
rmse mean: 0.0764
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.001
seed:0
[[[0.06736384 0.05616696]]

 [[0.10571608 0.09380912]]

 [[0.11503456 0.10975248]]

 [[0.04229945 0.03273108]]]
rmse: 0.0779
------------------------------------------------------------------
seed:1
[[[0.02752598 0.02129842]]

 [[0.11931559 0.09703895]]

 [[0.11361446 0.1074801 ]]

 [[0.04052685 0.03293735]]]
rmse: 0.0700
------------------------------------------------------------------
seed:2
[[[0.06039883 0.04990933]]

 [[0.12151227 0.10081625]]

 [[0.11485943 0.10972588]]

 [[0.04246664 0.03408429]]]
rmse: 0.0792
------------------------------------------------------------------
rmse mean: 0.0757
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=32, alpha=0.01
seed:0
[[[0.06735925 0.05616256]]

 [[0.10553509 0.09350124]]

 [[0.11503783 0.10975074]]

 [[0.04226282 0.03270522]]]
rmse: 0.0778
------------------------------------------------------------------
seed:1
[[[0.0662097  0.05552609]]

 [[0.12113437 0.09880131]]

 [[0.11921498 0.11362525]]

 [[0.04414546 0.0366358 ]]]
rmse: 0.0819
------------------------------------------------------------------
seed:2
[[[0.06035381 0.04986936]]

 [[0.12151533 0.10081568]]

 [[0.11486663 0.10972225]]

 [[0.04244832 0.03406667]]]
rmse: 0.0792
------------------------------------------------------------------
rmse mean: 0.0796
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.0
seed:0
[[[0.04512793 0.03376083]]

 [[0.16158468 0.13580281]]

 [[0.12306878 0.1166546 ]]

 [[0.04023382 0.03023263]]]
rmse: 0.0858
------------------------------------------------------------------
seed:1
[[[0.04887764 0.03734469]]

 [[0.11952151 0.09919268]]

 [[0.12121113 0.11589988]]

 [[0.04101458 0.03355635]]]
rmse: 0.0771
------------------------------------------------------------------
seed:2
[[[0.05310701 0.04108914]]

 [[0.09869289 0.07856385]]

 [[0.12569552 0.12069859]]

 [[0.04533041 0.03720459]]]
rmse: 0.0750
------------------------------------------------------------------
rmse mean: 0.0793
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.0001
seed:0
[[[0.04512596 0.0337633 ]]

 [[0.14361147 0.11592219]]

 [[0.11853786 0.11324987]]

 [[0.04053147 0.03350973]]]
rmse: 0.0805
------------------------------------------------------------------
seed:1
[[[0.04423823 0.03186722]]

 [[0.12158637 0.10073599]]

 [[0.12087435 0.11575886]]

 [[0.04154123 0.03270595]]]
rmse: 0.0762
------------------------------------------------------------------
seed:2
[[[0.0531068  0.0410978 ]]

 [[0.09904786 0.07890216]]

 [[0.12577205 0.12079064]]

 [[0.04533092 0.03720544]]]
rmse: 0.0752
------------------------------------------------------------------
rmse mean: 0.0773
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.001
seed:0
[[[0.04512313 0.03375646]]

 [[0.11050701 0.08503254]]

 [[0.11882439 0.1134088 ]]

 [[0.040531   0.03350915]]]
rmse: 0.0726
------------------------------------------------------------------
seed:1
[[[0.04970813 0.03990197]]

 [[0.12787602 0.10435163]]

 [[0.11675595 0.11052652]]

 [[0.04319546 0.03503167]]]
rmse: 0.0784
------------------------------------------------------------------
seed:2
[[[0.05310048 0.04118111]]

 [[0.13146292 0.11156631]]

 [[0.11905471 0.11389117]]

 [[0.04539014 0.03713368]]]
rmse: 0.0816
------------------------------------------------------------------
rmse mean: 0.0775
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=64, alpha=0.01
seed:0
[[[0.04507581 0.03369589]]

 [[0.1313228  0.1079263 ]]

 [[0.12059796 0.11615244]]

 [[0.04551762 0.03695905]]]
rmse: 0.0797
------------------------------------------------------------------
seed:1
[[[0.04316654 0.03228677]]

 [[0.13127495 0.1106734 ]]

 [[0.12131348 0.11546702]]

 [[0.0400781  0.03304379]]]
rmse: 0.0784
------------------------------------------------------------------
seed:2
[[[0.05318562 0.04195465]]

 [[0.10710623 0.08618769]]

 [[0.1218788  0.11728255]]

 [[0.04039202 0.03204341]]]
rmse: 0.0750
------------------------------------------------------------------
rmse mean: 0.0777
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.0
seed:0
[[[0.05338583 0.04284652]]

 [[0.12271557 0.10279688]]

 [[0.11245138 0.1077808 ]]

 [[0.0462181  0.03700616]]]
rmse: 0.0782
------------------------------------------------------------------
seed:1
[[[0.04438713 0.03507872]]

 [[0.13078468 0.10616415]]

 [[0.11717646 0.11155049]]

 [[0.04233886 0.0305943 ]]]
rmse: 0.0773
------------------------------------------------------------------
seed:2
[[[0.05294986 0.0438733 ]]

 [[0.11026615 0.09070297]]

 [[0.12089839 0.11468247]]

 [[0.0405841  0.03148821]]]
rmse: 0.0757
------------------------------------------------------------------
rmse mean: 0.0770
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.0001
seed:0
[[[0.05359162 0.04308745]]

 [[0.12302046 0.10320237]]

 [[0.11184773 0.10689547]]

 [[0.04621786 0.03700612]]]
rmse: 0.0781
------------------------------------------------------------------
seed:1
[[[0.04438742 0.0350797 ]]

 [[0.13078359 0.1061632 ]]

 [[0.11717479 0.11154944]]

 [[0.0428515  0.03103674]]]
rmse: 0.0774
------------------------------------------------------------------
seed:2
[[[0.05295493 0.04387801]]

 [[0.11065246 0.09122698]]

 [[0.12089513 0.11467964]]

 [[0.04061981 0.0315269 ]]]
rmse: 0.0758
------------------------------------------------------------------
rmse mean: 0.0771
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.001
seed:0
[[[0.05709867 0.0467494 ]]

 [[0.12225718 0.10213541]]

 [[0.11166347 0.10663215]]

 [[0.04088243 0.03151994]]]
rmse: 0.0774
------------------------------------------------------------------
seed:1
[[[0.0443884  0.03507928]]

 [[0.13077983 0.10616067]]

 [[0.11716599 0.11154707]]

 [[0.04230748 0.03056654]]]
rmse: 0.0772
------------------------------------------------------------------
seed:2
[[[0.05295313 0.04387633]]

 [[0.11033145 0.09082678]]

 [[0.12088176 0.11467205]]

 [[0.04069872 0.03161333]]]
rmse: 0.0757
------------------------------------------------------------------
rmse mean: 0.0768
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=128, alpha=0.01
seed:0
[[[0.0569545  0.04660725]]

 [[0.12498218 0.10295416]]

 [[0.11207657 0.10803444]]

 [[0.04614416 0.03693946]]]
rmse: 0.0793
------------------------------------------------------------------
seed:1
[[[0.04438928 0.03508789]]

 [[0.13075599 0.1061436 ]]

 [[0.11693598 0.11144401]]

 [[0.04301609 0.03119621]]]
rmse: 0.0774
------------------------------------------------------------------
seed:2
[[[0.05295603 0.04387965]]

 [[0.11401438 0.09331824]]

 [[0.12075354 0.11461325]]

 [[0.04063856 0.03154307]]]
rmse: 0.0765
------------------------------------------------------------------
rmse mean: 0.0777
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.0
seed:0
[[[0.02896942 0.0208567 ]]

 [[0.16426448 0.13471699]]

 [[0.1242992  0.11985985]]

 [[0.03679065 0.0277398 ]]]
rmse: 0.0822
------------------------------------------------------------------
seed:1
[[[0.05160221 0.04121845]]

 [[0.11529634 0.09167258]]

 [[0.11284535 0.106675  ]]

 [[0.04227984 0.03383225]]]
rmse: 0.0744
------------------------------------------------------------------
seed:2
[[[0.0474338  0.03855859]]

 [[0.15352286 0.12511842]]

 [[0.10698423 0.10185412]]

 [[0.04008227 0.03160964]]]
rmse: 0.0806
------------------------------------------------------------------
rmse mean: 0.0791
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.0001
seed:0
[[[0.02900515 0.02091376]]

 [[0.16597111 0.13634792]]

 [[0.12430388 0.11986289]]

 [[0.03678944 0.02774214]]]
rmse: 0.0826
------------------------------------------------------------------
seed:1
[[[0.0638988  0.05338996]]

 [[0.12498976 0.10381167]]

 [[0.0491417  0.03879107]]

 [[0.03586326 0.02884468]]]
rmse: 0.0623
------------------------------------------------------------------
seed:2
[[[0.04856167 0.03985608]]

 [[0.15606955 0.12734122]]

 [[0.10767199 0.10244274]]

 [[0.0459171  0.03754728]]]
rmse: 0.0832
------------------------------------------------------------------
rmse mean: 0.0760
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.001
seed:0
[[[0.02855812 0.02046068]]

 [[0.16596771 0.13634279]]

 [[0.12457126 0.1199916 ]]

 [[0.03643889 0.02795774]]]
rmse: 0.0825
------------------------------------------------------------------
seed:1
[[[0.03984341 0.0287472 ]]

 [[0.11531903 0.09169243]]

 [[0.11284913 0.10667788]]

 [[0.03837705 0.02991022]]]
rmse: 0.0704
------------------------------------------------------------------
seed:2
[[[0.05004435 0.04142471]]

 [[0.16329376 0.13395309]]

 [[0.1352758  0.1318335 ]]

 [[0.03846486 0.03145166]]]
rmse: 0.0907
------------------------------------------------------------------
rmse mean: 0.0812
===================================================================
K=64, lr=0.01, num_layers=1, hidden_dim=256, alpha=0.01
seed:0
[[[0.02835106 0.02066304]]

 [[0.14716313 0.12131061]]

 [[0.12462823 0.12001848]]

 [[0.03680151 0.02766883]]]
rmse: 0.0783
------------------------------------------------------------------
seed:1
[[[0.04263411 0.03242212]]

 [[0.17162895 0.14219881]]

 [[0.13406119 0.11348584]]

 [[0.04402049 0.03678455]]]
rmse: 0.0897
------------------------------------------------------------------
seed:2
[[[0.07547585 0.05956211]]

 [[0.11005187 0.09270547]]

 [[0.11337239 0.10832706]]

 [[0.318957   0.30609926]]]
rmse: 0.1481
------------------------------------------------------------------
rmse mean: 0.1053
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.0
seed:0
[[[0.13417648 0.11583934]]

 [[0.17369372 0.14428191]]

 [[0.14545863 0.11615487]]

 [[0.09068696 0.07899516]]]
rmse: 0.1249
------------------------------------------------------------------
seed:1
[[[0.04811456 0.04383224]]

 [[0.17338522 0.14403325]]

 [[0.15245597 0.1179826 ]]

 [[0.15695655 0.14134347]]]
rmse: 0.1223
------------------------------------------------------------------
seed:2
[[[0.05403352 0.04284887]]

 [[0.17300469 0.14356898]]

 [[0.15213347 0.11794229]]

 [[0.03748014 0.03264154]]]
rmse: 0.0942
------------------------------------------------------------------
rmse mean: 0.1138
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.0001
seed:0
[[[0.13418107 0.11584329]]

 [[0.17370004 0.14428712]]

 [[0.14884855 0.11699092]]

 [[0.09034168 0.07868296]]]
rmse: 0.1254
------------------------------------------------------------------
seed:1
[[[0.10279132 0.09348102]]

 [[0.1731447  0.14370847]]

 [[0.15189694 0.11782062]]

 [[0.03568864 0.02816701]]]
rmse: 0.1058
------------------------------------------------------------------
seed:2
[[[0.03222617 0.02307358]]

 [[0.1694567  0.13999132]]

 [[0.15076275 0.11744027]]

 [[0.08907243 0.07741017]]]
rmse: 0.0999
------------------------------------------------------------------
rmse mean: 0.1104
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.001
seed:0
[[[0.13417114 0.11583491]]

 [[0.17372117 0.14430602]]

 [[0.14528296 0.11627954]]

 [[1.17287586 1.1693364 ]]]
rmse: 0.3965
------------------------------------------------------------------
seed:1
[[[0.13320275 0.1150633 ]]

 [[0.17330335 0.14387046]]

 [[0.15051686 0.11748954]]

 [[0.08691849 0.0753416 ]]]
rmse: 0.1245
------------------------------------------------------------------
seed:2
[[[0.16246821 0.13163008]]

 [[0.21790099 0.18627569]]

 [[0.15075659 0.11743858]]

 [[0.03896829 0.03099115]]]
rmse: 0.1296
------------------------------------------------------------------
rmse mean: 0.2168
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=16, alpha=0.01
seed:0
[[[0.13419021 0.11585121]]

 [[0.17368759 0.14427713]]

 [[0.14543886 0.11615745]]

 [[0.08755203 0.06812743]]]
rmse: 0.1232
------------------------------------------------------------------
seed:1
[[[0.13339891 0.11522035]]

 [[0.17330121 0.14386823]]

 [[0.15054089 0.1175029 ]]

 [[0.08682426 0.07525064]]]
rmse: 0.1245
------------------------------------------------------------------
seed:2
[[[0.18961154 0.15202967]]

 [[0.17248639 0.14301416]]

 [[0.15055659 0.11746149]]

 [[0.09083233 0.07912312]]]
rmse: 0.1369
------------------------------------------------------------------
rmse mean: 0.1282
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.0
seed:0
[[[0.1329426  0.11485889]]

 [[0.1736019  0.14415089]]

 [[0.13006253 0.11476005]]

 [[0.03510128 0.02871905]]]
rmse: 0.1093
------------------------------------------------------------------
seed:1
[[[0.1348966  0.11643451]]

 [[0.17057501 0.14103933]]

 [[0.15451854 0.11993113]]

 [[0.09108079 0.07935012]]]
rmse: 0.1260
------------------------------------------------------------------
seed:2
[[[0.13342754 0.11442761]]

 [[0.17433391 0.14492052]]

 [[0.15237139 0.11795814]]

 [[0.09084411 0.07915056]]]
rmse: 0.1259
------------------------------------------------------------------
rmse mean: 0.1204
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.0001
seed:0
[[[0.13293666 0.11485421]]

 [[0.17364462 0.14419815]]

 [[0.14524799 0.11609084]]

 [[0.10965157 0.10223162]]]
rmse: 0.1299
------------------------------------------------------------------
seed:1
[[[0.1348964  0.11643375]]

 [[0.17024587 0.14069833]]

 [[0.15541129 0.12079054]]

 [[0.09107928 0.07935063]]]
rmse: 0.1261
------------------------------------------------------------------
seed:2
[[[0.13641062 0.1170944 ]]

 [[0.17435162 0.14493732]]

 [[0.15237634 0.11796041]]

 [[0.09078601 0.07909176]]]
rmse: 0.1266
------------------------------------------------------------------
rmse mean: 0.1275
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.001
seed:0
[[[0.13292935 0.11484976]]

 [[0.17362013 0.14417801]]

 [[0.14546481 0.11613975]]

 [[0.08568797 0.0743293 ]]]
rmse: 0.1234
------------------------------------------------------------------
seed:1
[[[0.13489636 0.11642898]]

 [[0.17150383 0.1420028 ]]

 [[0.15659288 0.12187836]]

 [[0.0910773  0.07934716]]]
rmse: 0.1267
------------------------------------------------------------------
seed:2
[[[0.1195413  0.10351069]]

 [[0.17477025 0.14532713]]

 [[0.15236634 0.11795721]]

 [[0.0908834  0.07919039]]]
rmse: 0.1229
------------------------------------------------------------------
rmse mean: 0.1244
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=32, alpha=0.01
seed:0
[[[0.13289535 0.11482527]]

 [[0.17362462 0.14418003]]

 [[0.14639721 0.11636965]]

 [[0.08465644 0.07319742]]]
rmse: 0.1233
------------------------------------------------------------------
seed:1
[[[0.13489604 0.1164311 ]]

 [[0.17013009 0.14058011]]

 [[0.15425751 0.11969825]]

 [[0.09116383 0.07942723]]]
rmse: 0.1258
------------------------------------------------------------------
seed:2
[[[0.13335377 0.11522061]]

 [[0.17441764 0.14499862]]

 [[0.15238513 0.11796521]]

 [[0.09091851 0.0792184 ]]]
rmse: 0.1261
------------------------------------------------------------------
rmse mean: 0.1251
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.0
seed:0
[[[0.1353747  0.11674232]]

 [[0.33081133 0.29821538]]

 [[0.15256946 0.11802026]]

 [[0.09156215 0.07979163]]]
rmse: 0.1654
------------------------------------------------------------------
seed:1
[[[0.13541956 0.11689453]]

 [[0.17430354 0.14488765]]

 [[0.15145865 0.11765396]]

 [[0.09132423 0.07959944]]]
rmse: 0.1264
------------------------------------------------------------------
seed:2
[[[0.13520452 0.11670253]]

 [[0.17421502 0.1448039 ]]

 [[0.15238732 0.1179685 ]]

 [[0.03630924 0.03104119]]]
rmse: 0.1136
------------------------------------------------------------------
rmse mean: 0.1351
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.0001
seed:0
[[[0.13542593 0.11686629]]

 [[0.32908322 0.29654739]]

 [[0.15253948 0.11805367]]

 [[0.09167387 0.07993873]]]
rmse: 0.1650
------------------------------------------------------------------
seed:1
[[[0.13542299 0.11689872]]

 [[0.17815306 0.14844255]]

 [[0.15159144 0.11776299]]

 [[0.09133015 0.07960564]]]
rmse: 0.1274
------------------------------------------------------------------
seed:2
[[[0.1351845  0.11667898]]

 [[0.17178157 0.14259957]]

 [[0.15240648 0.11797723]]

 [[0.13422092 0.12374959]]]
rmse: 0.1368
------------------------------------------------------------------
rmse mean: 0.1431
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.001
seed:0
[[[0.13543951 0.11690297]]

 [[0.42306414 0.39724909]]

 [[0.14877422 0.11677864]]

 [[0.09158658 0.07982159]]]
rmse: 0.1887
------------------------------------------------------------------
seed:1
[[[0.13541826 0.11689223]]

 [[0.17427409 0.14486011]]

 [[0.15174259 0.11790363]]

 [[0.09133258 0.07960826]]]
rmse: 0.1265
------------------------------------------------------------------
seed:2
[[[0.13520166 0.1166991 ]]

 [[0.17421957 0.1448081 ]]

 [[0.15240219 0.11797636]]

 [[0.04237062 0.0344632 ]]]
rmse: 0.1148
------------------------------------------------------------------
rmse mean: 0.1433
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=64, alpha=0.01
seed:0
[[[0.13559767 0.11723025]]

 [[0.14509601 0.1184405 ]]

 [[0.187309   0.15420629]]

 [[0.09158309 0.07980466]]]
rmse: 0.1287
------------------------------------------------------------------
seed:1
[[[0.13541714 0.11689049]]

 [[0.17385185 0.14446528]]

 [[0.15333028 0.119066  ]]

 [[0.09132053 0.07959333]]]
rmse: 0.1267
------------------------------------------------------------------
seed:2
[[[0.13518873 0.11668364]]

 [[0.1742266  0.14481492]]

 [[0.15237584 0.11796487]]

 [[0.05101287 0.04292557]]]
rmse: 0.1169
------------------------------------------------------------------
rmse mean: 0.1241
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.0
seed:0
[[[0.13488202 0.11642818]]

 [[0.17229639 0.14282172]]

 [[0.15262153 0.11807695]]

 [[0.09057127 0.07889091]]]
rmse: 0.1258
------------------------------------------------------------------
seed:1
[[[0.13541511 0.11688478]]

 [[0.17436217 0.14495359]]

 [[0.15260646 0.11803379]]

 [[0.09136837 0.07963851]]]
rmse: 0.1267
------------------------------------------------------------------
seed:2
[[[0.13523763 0.11673041]]

 [[0.17426528 0.14485688]]

 [[0.15257315 0.11801727]]

 [[0.09158351 0.07983211]]]
rmse: 0.1266
------------------------------------------------------------------
rmse mean: 0.1264
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.0001
seed:0
[[[0.13487844 0.11642546]]

 [[0.1723589  0.14288501]]

 [[0.15259272 0.11805221]]

 [[0.09020327 0.07820958]]]
rmse: 0.1257
------------------------------------------------------------------
seed:1
[[[0.13541534 0.11688488]]

 [[0.17435915 0.14495077]]

 [[0.15260614 0.11803211]]

 [[0.09134316 0.07961915]]]
rmse: 0.1267
------------------------------------------------------------------
seed:2
[[[0.13521481 0.11671084]]

 [[0.17426501 0.14485666]]

 [[0.15257311 0.11801711]]

 [[0.0916059  0.07986087]]]
rmse: 0.1266
------------------------------------------------------------------
rmse mean: 0.1263
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.001
seed:0
[[[0.13488271 0.11642938]]

 [[0.17155961 0.14208894]]

 [[0.15263921 0.11809209]]

 [[0.39263771 0.38269002]]]
rmse: 0.2014
------------------------------------------------------------------
seed:1
[[[0.13541513 0.11688463]]

 [[0.17436206 0.14495348]]

 [[0.15259595 0.11802154]]

 [[0.09134358 0.07961963]]]
rmse: 0.1266
------------------------------------------------------------------
seed:2
[[[0.13522361 0.11671863]]

 [[0.17426526 0.1448569 ]]

 [[0.15257475 0.11801879]]

 [[0.09158404 0.07983313]]]
rmse: 0.1266
------------------------------------------------------------------
rmse mean: 0.1516
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=128, alpha=0.01
seed:0
[[[0.13489136 0.11643657]]

 [[0.17236581 0.14286055]]

 [[0.15257437 0.1180363 ]]

 [[0.12254251 0.11030442]]]
rmse: 0.1338
------------------------------------------------------------------
seed:1
[[[0.1354157  0.1168853 ]]

 [[0.17436578 0.14495696]]

 [[0.15260037 0.11802623]]

 [[0.09135207 0.07962753]]]
rmse: 0.1267
------------------------------------------------------------------
seed:2
[[[0.13521458 0.11671079]]

 [[0.17426558 0.1448572 ]]

 [[0.1525752  0.11801782]]

 [[0.09160186 0.07984791]]]
rmse: 0.1266
------------------------------------------------------------------
rmse mean: 0.1290
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.0
seed:0
[[[0.13540474 0.11687546]]

 [[0.1742035  0.14480125]]

 [[0.15259417 0.11802556]]

 [[0.09166772 0.07991824]]]
rmse: 0.1267
------------------------------------------------------------------
seed:1
[[[0.13541283 0.1168401 ]]

 [[0.17430401 0.14489534]]

 [[1.19498653 1.19021664]]

 [[0.09167339 0.0799234 ]]]
rmse: 0.3910
------------------------------------------------------------------
seed:2
[[[0.13543584 0.11690189]]

 [[0.17429711 0.14488763]]

 [[0.21209466 0.18336104]]

 [[0.09168196 0.07993175]]]
rmse: 0.1423
------------------------------------------------------------------
rmse mean: 0.2200
===================================================================
K=64, lr=0.01, num_layers=2, hidden_dim=256, alpha=0.0001
seed:0
[[[0.13540624 0.11687691]]

 [[0.17483981 0.14539353]]

 [[0.15259469 0.11802602]]

 [[0.09166858 0.07991908]]]
rmse: 0.1268
------------------------------------------------------------------
seed:1
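The sweep log above is long, so it helps to rank configurations programmatically. A small sketch, assuming the printed output has been captured to a file (the name sweep_log.txt is a placeholder, not part of the original code):

import re

log_text = open('sweep_log.txt').read()  # hypothetical capture of the log above

# Pair each "K=..., lr=..." header with the "rmse mean" that closes its block.
pattern = re.compile(
    r"(K=\d+, lr=[\d.]+, num_layers=\d+, hidden_dim=\d+, alpha=[\d.]+)"
    r".*?rmse mean: ([\d.]+)",
    re.S,
)
ranked = sorted(pattern.findall(log_text), key=lambda item: float(item[1]))
for config, rmse in ranked[:5]:  # the five best configurations
    print(rmse, config)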

5. Summary

From the results above, for the four selected batteries the GRU model outperforms the Transformer model. The main reasons are as follows:

1. Data characteristics

Capacity-versus-cycle data is strongly time-dependent and sequential. The GRU (Gated Recurrent Unit) is designed specifically for such time-series data and can effectively capture both long- and short-term dependencies.

2. Model structure

The GRU and Transformer models differ significantly in structure:

  • GRU model:

    • Designed for time-series data; it handles the temporal dependencies in a sequence effectively.
    • Thanks to its gating mechanism, a GRU can capture dependencies over long time spans without being as prone to vanishing gradients.
    • Adding noise improves its robustness without over-complicating the model.
  • Transformer model:

    • Typically used for natural language processing and image tasks; built around the self-attention mechanism.
    • Needs larger datasets and more compute to realize its advantages.
    • May be overly complex for a relatively simple time-series prediction task, where it can overfit or underperform.

3. Model complexity and data volume

The parameter-count sketch after this list makes the size gap concrete.

  • GRU model:

    • Relatively simple, with few parameters; performs well on small datasets.
    • Easier to train and converges faster.
  • Transformer model:

    • Structurally complex, with many parameters; suited to large datasets and more demanding tasks.
    • Requires more compute and data, and may underperform the GRU on small datasets.
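A minimal sketch of the two architectures as one-step capacity regressors, for comparison only: the layer sizes, the seq_len=16 input window, and the pooling choices are illustrative assumptions, not the tuned models from the sweep above.

import torch
import torch.nn as nn

class GRURegressor(nn.Module):
    # Recurrent baseline: the GRU reads the window step by step and the last
    # hidden state summarizes it for a one-step capacity prediction.
    def __init__(self, hidden_dim=64, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(1, hidden_dim, num_layers, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x):             # x: (batch, seq_len, 1)
        out, _ = self.gru(x)
        return self.head(out[:, -1])  # last time step only

class TransformerRegressor(nn.Module):
    # Attention model: every position attends to every other, so ordering
    # must be injected separately (a learned positional embedding here).
    def __init__(self, d_model=64, nhead=4, num_layers=1, seq_len=16):
        super().__init__()
        self.embed = nn.Linear(1, d_model)
        self.pos = nn.Parameter(torch.zeros(1, seq_len, d_model))
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, x):             # x: (batch, seq_len, 1)
        h = self.encoder(self.embed(x) + self.pos)
        return self.head(h.mean(dim=1))  # average-pool over positions

x = torch.randn(8, 16, 1)  # a dummy batch: 8 windows of 16 capacity values
for model in (GRURegressor(), TransformerRegressor()):
    n_params = sum(p.numel() for p in model.parameters())
    print(type(model).__name__, tuple(model(x).shape), 'parameters:', n_params)

Even at the same hidden width, the Transformer's feed-forward sublayer (2048 units by default in nn.TransformerEncoderLayer) gives it many more parameters than the GRU, which is one reason it needs more data.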

4. Hyperparameter tuning

Although we tried to align the hyperparameter settings of the two models, tuning a Transformer is more involved than tuning a GRU and may require more careful optimization.

5. Noise and regularization

Adding noise to the GRU's inputs can improve its robustness, whereas the same treatment may not help the Transformer as much.


To summarize:

  • Data characteristics: capacity-prediction data is strongly time-dependent, and the GRU has a natural advantage on such tasks.
  • Model structure: the GRU is simpler and more direct, well suited to time series, while the Transformer may be less efficient on this kind of data.
  • Model complexity: the GRU has fewer parameters and suits small datasets better; the more complex Transformer can overfit or converge slowly.
  • Training and regularization: adding noise improves the GRU's robustness but may have no significant effect on the Transformer.

For this particular task and dataset, the GRU likely outperforms the Transformer because its design better matches time-series data, its structure is simpler with fewer parameters, and it handles small datasets better. The Transformer's complexity and appetite for large datasets may hold it back here.

Suggestions for improving performance. If you want to further optimize the Transformer model, you can try the following:

  1. More data: collect additional data if possible.
  2. Hyperparameter tuning: further tune the Transformer's hyperparameters, such as the number of layers, hidden units, and learning rate.
  3. A simpler model: trim the Transformer where appropriate to reduce its parameter count.
  4. Data augmentation: generate additional samples to improve the model's generalization; see the sketch after this list.
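As an illustration of point 4, a minimal augmentation sketch for this dataset: jitter each (denoised) capacity curve with small Gaussian noise, so the degradation trend is kept while the sample count grows. The noise scale sigma=0.005 and the copy count are assumptions chosen for illustration, not tuned values.

import numpy as np

def augment_capacity(capacity, n_copies=5, sigma=0.005, seed=0):
    # Return n_copies jittered variants of one capacity curve; the small
    # Gaussian perturbation preserves the overall fade trajectory.
    rng = np.random.default_rng(seed)
    capacity = np.asarray(capacity, dtype=float)
    return [capacity + rng.normal(0.0, sigma, size=capacity.shape)
            for _ in range(n_copies)]

# Example on a toy curve fading from 2.0 Ah toward the 1.4 Ah EOL threshold.
toy_curve = np.linspace(2.0, 1.4, 168)
variants = augment_capacity(toy_curve)
print(len(variants), 'variants; first values:', variants[0][:3])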

6. Exercise

Our dataset contains data for more than thirty batteries. Use all of them for model training and prediction, holding out about ten percent of the raw data as the test set. Keep using the GRU and Transformer models, and compare their performance again after the training data has been enlarged. One possible split is sketched below.
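One way to realize the ten-percent hold-out is a battery-level split, so that no cycles from a test battery leak into training. A sketch, assuming the thirty-plus .mat files sit in a directory named ./NASA_data (a placeholder path; adjust it to your copy of the dataset):

import os
import random

data_dir = './NASA_data'  # assumed location of the B*.mat files
battery_ids = sorted(f[:-4] for f in os.listdir(data_dir) if f.endswith('.mat'))

random.seed(0)            # fix the split for reproducibility
random.shuffle(battery_ids)
n_test = max(1, round(0.1 * len(battery_ids)))
test_ids, train_ids = battery_ids[:n_test], battery_ids[n_test:]
print('train batteries:', len(train_ids), '| test batteries:', test_ids)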
