Bohrium
robot
新建

空间站广场

论文
Notebooks
比赛
课程
Apps
我的主页
我的Notebooks
我的论文库
我的足迹

我的工作空间

任务
节点
文件
数据集
镜像
项目
数据库
公开
【数据科学导论】-数智教学案例-基于XGBoost的 股票预测
数据科学导论
数据科学导论
Mancn
更新于 2024-09-19
推荐镜像 :Basic Image:bohrium-notebook:2023-04-07
推荐机型 :c32_m64_cpu
2
1. 介绍 (Introduction)
1.1 案例背景
1.2 学习目标
2. 数据集介绍 (Dataset Introduction)
3. 导入数据与初步探索 (Importing Data and Initial Exploration)
4. 数据预处理 (Data Preprocessing)
5. 数据可视化 (Data Visualization)
5.1 收盘价随时间的变化 (Closing Price Over Time)
5.2 交易量与平均价格的关系 (Volume vs Average Price)
5.3 交易金额随时间的变化 (Trading Amount Over Time)
5.4 百分比变化的分布 (Distribution of Percentage Change)
5.5 百分比变化的直方图 (Histogram of Percentage Change)
6. 特征工程 (Feature Engineering)
6.1 滞后特征的构造与可视化 (Creating and Visualizing Lag Features)
6.2 移动均线的计算与可视化 (Moving Averages Visualization)
6.3 相对强弱指数 (RSI) 的计算与可视化 (RSI Calculation and Visualization)
6.4 MACD 指标的计算与可视化 (MACD Calculation and Visualization)
7. 建模与预测 (Modeling and Prediction)
7.1 数据预处理 (Data Preprocessing)
7.2 数据集拆分 (Train-Test Split)
7.3 网格搜索优化模型 (Grid Search to Optimize XGBoost Model)
7.4 使用最佳参数重新训练模型 (Train the Model with Best Parameters)
8. 结果评估 (Results Evaluation)
8.1 重要特征可视化 (Feature Importance Visualization)
8.2 打印模型的均方误差 (Mean Squared Error Calculation)
8.3 真实值与预测值的可视化 (Actual vs Predicted Prices)
9 总结 (Summary)

1. 介绍 (Introduction)

1.1 案例背景

本案例将通过分析中国黄金市场的数据,结合机器学习方法对股票价格进行预测。黄金作为一种避险资产,长期以来在金融市场中占有重要地位。通过对中国黄金市场的历史数据进行深入分析和价格预测,可以为投资者、政策制定者以及研究人员提供有价值的市场洞察。

随着大数据和机器学习技术的不断发展,量化金融分析已成为市场预测的重要工具。本案例将展示如何利用这些技术构建时间序列预测模型,不仅适用于黄金市场,还可以扩展应用到其他股票,例如茅台股票和中证500指数。通过学习本案例,您将掌握时间序列分析的基本概念,机器学习模型的应用,并了解如何进行数据的探索性分析、特征工程、建模、评估和优化。

1.2 学习目标

通过本案例,您将能够掌握以下技能:

  1. 数据探索与可视化 (Data Exploration and Visualization):
    • 了解如何加载和处理金融数据,并使用图表对数据进行探索性分析。
    • 掌握如何利用 seaborn 和 matplotlib 等工具进行时间序列数据的可视化,帮助识别市场趋势和数据特征。
  2. 特征工程 (Feature Engineering):
    • 学习如何为时间序列数据创建滞后特征、移动均线(SMA、EMA)等技术指标,用以提升预测模型的表现。
    • 探索如何进行数据清理和预处理,确保模型输入数据的准确性和一致性。
  3. 机器学习建模 (Machine Learning Modeling):
    • 掌握如何使用 XGBoost 等机器学习算法对股票价格进行预测,并通过网格搜索(GridSearchCV)优化模型超参数。
    • 了解如何在不同的数据集上进行模型训练、验证和测试。
  4. 模型评估与优化 (Model Evaluation and Optimization):
    • 掌握模型性能评估方法,如均方误差(MSE)等,并学习如何解释模型结果。
    • 了解如何通过特征重要性分析来识别影响模型预测的关键变量。
代码
文本

2. 数据集介绍 (Dataset Introduction)

本数据集来源于公开数据集,涵盖了2015年至2022年间中国黄金市场的每日交易数据。数据集包括以下变量:

  • trade_date: 交易日期
  • close: 收盘价
  • open: 开盘价
  • high: 最高价
  • low: 最低价
  • vol: 交易量

这些变量为我们提供了全面了解市场趋势的基础。接下来,我们将导入数据并进行探索性数据分析(EDA)。

代码
文本

3. 导入数据与初步探索 (Importing Data and Initial Exploration)

我们通过 head() 方法查看了数据集的前几行,数据中包括交易日期、收盘价、开盘价、交易量等信息。 这几行代码帮助我们了解数据的整体结构和变量的分布。

代码
文本
[35]
# ! pip install numpy pandas seaborn matplotlib scikit-learn xgboost plotly
代码
文本
[36]
# 导入所需的库
import numpy as np
import pandas as pd
import seaborn as sns
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error
from xgboost import XGBRegressor
import plotly.graph_objects as go
import plotly.offline as pyo

# 初始化 Plotly
pyo.init_notebook_mode()

# 读取数据集
df = pd.read_csv('Gold-Au99_95.csv')

# 查看前几行数据
df.head()
,
, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ,
ts_codetrade_datecloseopenhighlowprice_avgchangepct_changevolamount
0Au99.9520221230409.93408.8410.20408.80409.931.2831.3232.013118000
1Au99.9520221229408.65408.6409.00408.35408.73-0.21-5.1462.025341500
2Au99.9520221228408.86410.8410.80408.85409.271.5036.8278.031923100
3Au99.9520221227407.36407.3407.50407.30407.330.338.1178.031771800
4Au99.9520221226407.03407.0407.08407.00407.020.6115.01172.069961920
,
代码
文本

4. 数据预处理 (Data Preprocessing)

在金融数据中,日期通常是最重要的因素之一。我们将交易日期从字符串转换为日期时间类型,并按照日期进行排序。为了简化分析,我们还将移除不必要的 ts_code 列。

代码
文本
[37]
# 转换日期数据类型并进行排序
df['trade_date'] = pd.to_datetime(df['trade_date'].astype(str), format='%Y%m%d')
df.sort_values(by="trade_date", inplace=True)

# 删除无关列
df.drop('ts_code', axis=1, inplace=True)

# 确认数据变化
df.head()
, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ,
trade_datecloseopenhighlowprice_avgchangepct_changevolamount
19442015-01-05240.98239.00241.00238.80239.740.930.396256.01499814820
19432015-01-06242.30241.35242.50241.35241.871.320.553350.0810269520
19422015-01-07243.09243.50243.85243.00243.470.790.335904.01437456500
19412015-01-08242.52243.30243.35242.12242.71-0.57-0.235340.01296231300
19402015-01-09242.79241.90242.99241.90242.780.270.115816.01412011020
,
代码
文本

5. 数据可视化 (Data Visualization)

为了更好地理解数据趋势,我们将使用 seaborn 和 matplotlib 对不同特征进行可视化。

代码
文本

5.1 收盘价随时间的变化 (Closing Price Over Time)

这张折线图展示了股票的收盘价随时间的变化,帮助我们直观了解价格的波动趋势。

代码
文本
[38]
# 绘制收盘价随时间的变化
plt.figure(figsize=(12, 4), dpi=80)
sns.lineplot(data=df, x='trade_date', y='close')
plt.title('Closing Price Over Time')
plt.xlabel('Date')
plt.ylabel('Closing Price')
plt.show()
<Figure size 960x320 with 1 Axes>
代码
文本

5.2 交易量与平均价格的关系 (Volume vs Average Price)

通过这个回归图,我们可以查看交易量和收盘价之间是否存在相关性,进一步了解价格和市场行为之间的关系。

代码
文本
[39]
# 过滤掉交易量大于6000的数据
a = df[df['vol'] < 6000]

# 绘制交易量与平均价格的回归图
plt.figure(figsize=(16, 4), dpi=80)
sns.regplot(data=a, x='close', y='vol')
plt.title('Volume vs Closing Price')
plt.xlabel('Closing Price')
plt.ylabel('Volume')
plt.show()
<Figure size 1280x320 with 1 Axes>
代码
文本

5.3 交易金额随时间的变化 (Trading Amount Over Time)

这张图展示了交易金额随时间的变化,可以帮助我们了解市场的活跃程度和波动。

代码
文本
[40]
# 绘制交易金额随时间的变化图
plt.figure(figsize=(12, 4), dpi=80)
sns.lineplot(data=df, x='trade_date', y='amount')
plt.title('Trading Amount Over Time')
plt.xlabel('Date')
plt.ylabel('Amount')
plt.show()
<Figure size 960x320 with 1 Axes>
代码
文本

5.4 百分比变化的分布 (Distribution of Percentage Change)

箱线图展示了百分比变化的分布情况,包括数据的集中趋势和可能的异常值。

代码
文本
[41]
# 绘制百分比变化的箱线图
plt.figure(figsize=(12, 4), dpi=80)
sns.boxplot(data=df, x='pct_change')
plt.title('Boxplot of Percentage Change')
plt.xlabel('Percentage Change')
plt.show()
<Figure size 960x320 with 1 Axes>
代码
文本

5.5 百分比变化的直方图 (Histogram of Percentage Change)

直方图展示了百分比变化的分布,核密度估计 (KDE) 有助于平滑数据分布,提供更多的分布趋势信息。

代码
文本
[42]
# 绘制百分比变化的直方图,并添加核密度估计 (KDE)
plt.figure(figsize=(12, 4), dpi=80)
sns.histplot(data=df[(df['pct_change'] > -10) & (df['pct_change'] < 10)], x='pct_change', kde=True)
plt.title('Histogram of Percentage Change with KDE')
plt.xlabel('Percentage Change')
plt.show()
<Figure size 960x320 with 1 Axes>
代码
文本

6. 特征工程 (Feature Engineering)

接下来,我们通过构造一些特征(如移动均线、相对强弱指数等)来为机器学习模型准备输入。

这些特征通常在金融时间序列分析中非常有用,例如移动均线(SMA, EMA)帮助平滑价格波动,相对强弱指数(RSI)则用于判断超买和超卖。

代码
文本

6.1 滞后特征的构造与可视化 (Creating and Visualizing Lag Features)

滞后特征是在时间序列分析中常用的方法,用于分析过去几天的数据对当前数据的影响。这在构建预测模型时非常有用。

代码
文本
[28]
# 创建滞后特征
def lag_features(df, lags):
c = df.copy()
for lag in lags:
c[f'return_lag_{lag}'] = c['pct_change'].shift(lag)
c[f'vol_lag_{lag}'] = c['vol'].shift(lag)
return c

# 生成滞后特征
lags = [1, 2, 3]
a = lag_features(df, lags)

# 添加增量交易量特征
a['vol_incremental'] = a['vol_lag_1'] - a['vol_lag_2']

# 根据百分比变化创建标签 (1: 增长, 0: 下降)
a['label'] = a['pct_change'].apply(lambda x: 0 if x <= 0 else 1)
代码
文本

6.2 移动均线的计算与可视化 (Moving Averages Visualization)

移动平均线可以帮助我们平滑股票价格的波动,提供长期或短期趋势的视图。

代码
文本
[29]
# 计算不同时间的移动平均线 (SMA, EMA)
df['EMA_9'] = df['close'].ewm(9).mean().shift()
df['SMA_5'] = df['close'].rolling(5).mean().shift()
df['SMA_10'] = df['close'].rolling(10).mean().shift()
df['SMA_15'] = df['close'].rolling(15).mean().shift()
df['SMA_30'] = df['close'].rolling(30).mean().shift()

# 使用 Plotly 可视化这些移动平均线
t1 = go.Scatter(x=df['trade_date'], y=df['EMA_9'], name='EMA 9')
t2 = go.Scatter(x=df['trade_date'], y=df['SMA_5'], name='SMA 5')
t3 = go.Scatter(x=df['trade_date'], y=df['SMA_10'], name='SMA 10')
t4 = go.Scatter(x=df['trade_date'], y=df['SMA_15'], name='SMA 15')
t5 = go.Scatter(x=df['trade_date'], y=df['SMA_30'], name='SMA 30')
t6 = go.Scatter(x=df['trade_date'], y=df['close'], name='Close', opacity=0.2)

data = [t1, t2, t3, t4, t5, t6]
plt.close('all')
pyo.iplot(data, filename='basic-line')
代码
文本

6.3 相对强弱指数 (RSI) 的计算与可视化 (RSI Calculation and Visualization)

相对强弱指数 (RSI) 是技术分析中常用的指标之一,用来判断市场是否处于超买或超卖状态。

代码
文本
[30]
# 计算相对强弱指数 (RSI)
def relative_strength_idx(df, n=14):
close = df['close']
delta = close.diff()
delta = delta[1:]
gain = (delta.where(delta > 0, 0)).rolling(n).mean()
loss = (-delta.where(delta < 0, 0)).rolling(n).mean()
rs = gain / loss
rsi = 100.0 - (100.0 / (1.0 + rs))
return rsi

# 计算并添加 RSI 列
df['RSI'] = relative_strength_idx(df).fillna(0)

# 可视化 RSI
t1 = go.Scatter(x=df['trade_date'], y=df['RSI'], name='RSI')
data = [t1]
plt.close('all')
pyo.iplot(data, filename='basic-line')
代码
文本

6.4 MACD 指标的计算与可视化 (MACD Calculation and Visualization)

MACD 是用来识别价格趋势的反转信号的技术指标。通过计算和绘制 MACD 线和信号线,可以帮助我们分析市场的买入和卖出信号。

代码
文本
[31]
# 计算 MACD 指标
EMA_12 = pd.Series(df['close'].ewm(span=12, min_periods=12).mean())
EMA_26 = pd.Series(df['close'].ewm(span=26, min_periods=26).mean())
df['MACD'] = pd.Series(EMA_12 - EMA_26)
df['MACD_signal'] = pd.Series(df['MACD'].ewm(span=9, min_periods=9).mean())

# 可视化 MACD 和信号线
t1 = go.Scatter(x=df['trade_date'], y=df['close'], name='Close')
t2 = go.Scatter(x=df['trade_date'], y=EMA_12, name='EMA 12')
t3 = go.Scatter(x=df['trade_date'], y=EMA_26, name='EMA 26')
t4 = go.Scatter(x=df['trade_date'], y=df['MACD'], name='MACD')
t5 = go.Scatter(x=df['trade_date'], y=df['MACD_signal'], name='Signal Line')

data = [t1, t2, t3, t4, t5]
plt.close('all')
pyo.iplot(data, filename='basic-line')
代码
文本

7. 建模与预测 (Modeling and Prediction)

我们将使用 XGBoost 回归模型进行股票价格的预测。首先我们将数据集拆分为训练集和验证集,接着通过网格搜索(Grid Search)寻找最优模型参数。

代码
文本

7.1 数据预处理 (Data Preprocessing)

代码
文本
[32]
# 删除任何缺失值,确保数据完整性
df.dropna(how="any", inplace=True)

# 创建特征数据集,移除无关列
b = df.drop(['change', 'pct_change', 'amount', 'vol', 'high', 'open', 'low'], axis=1)

# 创建标签列,'label'表示下一个交易日的收盘价
b['label'] = b['close'].shift(-1)

# 删除含有空值的行(由于滞后特性引入的缺失值)
b.dropna(how="any", inplace=True)

# 将特征和标签分开
y = b['label'] # 标签是下一个交易日的收盘价
X = b.drop(columns=['label'], axis=1) # 特征是所有与标签无关的列
代码
文本

7.2 数据集拆分 (Train-Test Split)

我们将数据集拆分为训练集和验证集,通常使用 70% 的数据用于训练,30% 的数据用于验证。

代码
文本
[33]
# 拆分数据集,70% 数据用于训练,30% 用于验证
train_set, valid_set = np.split(b, [int(0.7 * len(b))])

# 从训练集中提取特征和标签
y_train = train_set['label']
X_train = train_set.drop(columns=['label', 'trade_date'], axis=1)

# 从验证集中提取特征和标签
y_valid = valid_set['label']
X_valid = valid_set.drop(columns=['label', 'trade_date'], axis=1)
/Users/mancn/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/numpy/_core/fromnumeric.py:57: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

代码
文本

7.3 网格搜索优化模型 (Grid Search to Optimize XGBoost Model)

为了找到最佳的模型参数,我们使用 GridSearchCV 进行网格搜索。XGBoost 模型的参数包括学习率(learning_rate)、树的数量(n_estimators)、树的最大深度(max_depth)、最小损失减少(gamma)等。

代码
文本
[34]
# 定义网格搜索的参数范围
grid = {
'n_estimators': [100, 200, 300, 400],
'learning_rate': [0.001, 0.005, 0.01, 0.05],
'max_depth': [8, 10, 12, 15],
'gamma': [0.001, 0.005, 0.01, 0.02],
'random_state': [42]
}

# 使用 GridSearchCV 进行网格搜索,寻找最佳参数组合
clf = GridSearchCV(estimator=XGBRegressor(),
param_grid=grid,
n_jobs=-1, # 使用所有可用CPU并行计算
cv=3,
verbose=3) # 3折交叉验证
clf.fit(X_train, y_train)

# 打印网格搜索找到的最佳参数和在验证集上的最佳得分
print(f'Best params: {clf.best_params_}')
print(f'Best validation score = {clf.best_score_}')
Fitting 3 folds for each of 256 candidates, totalling 768 fits
[CV 1/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=100, random_state=42;, score=-4.637 total time=   0.1s
[CV 2/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=100, random_state=42;, score=-10.386 total time=   0.1s
[CV 3/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=100, random_state=42;, score=-2.341 total time=   0.1s
[CV 1/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=200, random_state=42;, score=-4.090 total time=   0.2s
[CV 2/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=200, random_state=42;, score=-8.373 total time=   0.2s
[CV 1/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=300, random_state=42;, score=-3.619 total time=   0.3s
[CV 3/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=200, random_state=42;, score=-2.196 total time=   0.3s
[CV 1/3] END gamma=0.001, learning_rate=0.001, max_depth=10, n_estimators=100, random_state=42;, score=-4.637 total time=   0.1s
[CV 2/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=300, random_state=42;, score=-6.670 total time=   0.3s
[CV 2/3] END gamma=0.001, learning_rate=0.001, max_depth=10, n_estimators=100, random_state=42;, score=-10.386 total time=   0.1s
[CV 1/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=400, random_state=42;, score=-3.210 total time=   0.3s
[CV 3/3] END gamma=0.001, learning_rate=0.001, max_depth=10, n_estimators=100, random_state=42;, score=-2.341 total time=   0.2s
[CV 1/3] END gamma=0.001, learning_rate=0.001, max_depth=10, n_estimators=200, random_state=42;, score=-4.090 total time=   0.2s
[CV 3/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=300, random_state=42;, score=-2.069 total time=   0.4s
[CV 2/3] END gamma=0.001, learning_rate=0.001, max_depth=8, n_estimators=400, random_state=42;, score=-5.331 total time=   0.4s
[CV 2/3] END gamma=0.001, learning_rate=0.001, max_depth=10, n_estimators=200, random_state=42;, score=-8.373 total time=   0.3s
[CV 1/3] END gamma=0.001, learning_rate=0.001, max_depth=10, n_estimators=300, random_state=42;, score=-3.619 total time=   0.2s
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[34], line 16
     10 # 使用 GridSearchCV 进行网格搜索,寻找最佳参数组合
     11 clf = GridSearchCV(estimator=XGBRegressor(),
     12                    param_grid=grid,
     13                    n_jobs=-1,  # 使用所有可用CPU并行计算
     14                    cv=3,
     15                    verbose=3)  # 3折交叉验证
---> 16 clf.fit(X_train, y_train)
     18 # 打印网格搜索找到的最佳参数和在验证集上的最佳得分
     19 print(f'Best params: {clf.best_params_}')

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/sklearn/base.py:1473, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1466     estimator._validate_params()
   1468 with config_context(
   1469     skip_parameter_validation=(
   1470         prefer_skip_nested_validation or global_skip_validation
   1471     )
   1472 ):
-> 1473     return fit_method(estimator, *args, **kwargs)

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/sklearn/model_selection/_search.py:1019, in BaseSearchCV.fit(self, X, y, **params)
   1013     results = self._format_results(
   1014         all_candidate_params, n_splits, all_out, all_more_results
   1015     )
   1017     return results
-> 1019 self._run_search(evaluate_candidates)
   1021 # multimetric is determined here because in the case of a callable
   1022 # self.scoring the return type is only known after calling
   1023 first_test_score = all_out[0]["test_scores"]

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/sklearn/model_selection/_search.py:1573, in GridSearchCV._run_search(self, evaluate_candidates)
   1571 def _run_search(self, evaluate_candidates):
   1572     """Search all candidates in param_grid"""
-> 1573     evaluate_candidates(ParameterGrid(self.param_grid))

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/sklearn/model_selection/_search.py:965, in BaseSearchCV.fit.<locals>.evaluate_candidates(candidate_params, cv, more_results)
    957 if self.verbose > 0:
    958     print(
    959         "Fitting {0} folds for each of {1} candidates,"
    960         " totalling {2} fits".format(
    961             n_splits, n_candidates, n_candidates * n_splits
    962         )
    963     )
--> 965 out = parallel(
    966     delayed(_fit_and_score)(
    967         clone(base_estimator),
    968         X,
    969         y,
    970         train=train,
    971         test=test,
    972         parameters=parameters,
    973         split_progress=(split_idx, n_splits),
    974         candidate_progress=(cand_idx, n_candidates),
    975         **fit_and_score_kwargs,
    976     )
    977     for (cand_idx, parameters), (split_idx, (train, test)) in product(
    978         enumerate(candidate_params),
    979         enumerate(cv.split(X, y, **routed_params.splitter.split)),
    980     )
    981 )
    983 if len(out) < 1:
    984     raise ValueError(
    985         "No fits were performed. "
    986         "Was the CV iterator empty? "
    987         "Were there no candidates?"
    988     )

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/sklearn/utils/parallel.py:74, in Parallel.__call__(self, iterable)
     69 config = get_config()
     70 iterable_with_config = (
     71     (_with_config(delayed_func, config), args, kwargs)
     72     for delayed_func, args, kwargs in iterable
     73 )
---> 74 return super().__call__(iterable_with_config)

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/joblib/parallel.py:2007, in Parallel.__call__(self, iterable)
   2001 # The first item from the output is blank, but it makes the interpreter
   2002 # progress until it enters the Try/Except block of the generator and
   2003 # reaches the first `yield` statement. This starts the asynchronous
   2004 # dispatch of the tasks to the workers.
   2005 next(output)
-> 2007 return output if self.return_generator else list(output)

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/joblib/parallel.py:1650, in Parallel._get_outputs(self, iterator, pre_dispatch)
   1647     yield
   1649     with self._backend.retrieval_context():
-> 1650         yield from self._retrieve()
   1652 except GeneratorExit:
   1653     # The generator has been garbage collected before being fully
   1654     # consumed. This aborts the remaining tasks if possible and warn
   1655     # the user if necessary.
   1656     self._exception = True

File ~/anaconda3/envs/20240918_163852-wdnotebook/lib/python3.10/site-packages/joblib/parallel.py:1762, in Parallel._retrieve(self)
   1757 # If the next job is not ready for retrieval yet, we just wait for
   1758 # async callbacks to progress.
   1759 if ((len(self._jobs) == 0) or
   1760     (self._jobs[0].get_status(
   1761         timeout=self.timeout) == TASK_PENDING)):
-> 1762     time.sleep(0.01)
   1763     continue
   1765 # We need to be careful: the job list can be filling up as
   1766 # we empty it and Python list are not thread-safe by
   1767 # default hence the use of the lock

KeyboardInterrupt: 
代码
文本

7.4 使用最佳参数重新训练模型 (Train the Model with Best Parameters)

使用网格搜索得到的最佳参数重新训练模型,并在验证集上进行预测。

代码
文本
[88]
# 使用网格搜索找到的最佳参数训练XGBoost模型
model = XGBRegressor(**clf.best_params_, objective='reg:squarederror')
model.fit(X_train, y_train)

# 在验证集上进行预测
pred = model.predict(X_valid)
代码
文本

8. 结果评估 (Results Evaluation)

训练完成后,我们可以使用模型对验证集进行预测,并计算均方误差(MSE)。

代码
文本

8.1 重要特征可视化 (Feature Importance Visualization)

通过 XGBoost 模型的特征重要性功能,可以看到哪些特征对预测结果影响最大。我们可以使用 xgboost 内置的 plot_importance 方法进行可视化。

代码
文本
[94]
# 导入并使用 plot_importance 函数绘制特征重要性
from xgboost import plot_importance

plt.figure(figsize=(10, 8))
plot_importance(model)
plt.title("Feature Importance")
plt.show()
<Figure size 1000x800 with 0 Axes>
<Figure size 640x480 with 1 Axes>
代码
文本

8.2 打印模型的均方误差 (Mean Squared Error Calculation)

为了评估模型的表现,我们计算验证集上的均方误差(MSE)。

代码
文本
[90]
# 计算验证集上的均方误差 (MSE)
from sklearn.metrics import mean_squared_error

mse = mean_squared_error(y_valid, pred)
print(f'Mean Squared Error = {mse}')
Mean Squared Error = 37.49102220049418
代码
文本

8.3 真实值与预测值的可视化 (Actual vs Predicted Prices)

最后,我们将预测的结果与实际结果进行对比,并可视化。

代码
文本
[91]
# 将预测结果与真实值进行对比
i = len(pred)
predicted_prices = df.tail(i).copy()
predicted_prices['predicted_close'] = pred

# 使用 Plotly 可视化实际值与预测值
t1 = go.Scatter(x=df['trade_date'], y=df['close'], name='Actual Closing Price', marker_color='LightSkyBlue')
t2 = go.Scatter(x=predicted_prices['trade_date'], y=predicted_prices['predicted_close'], name='Predicted Closing Price', marker_color='MediumPurple')

data = [t1, t2]
plt.close('all')
pyo.iplot(data, filename='actual-vs-predicted')
代码
文本

在这个可视化中,蓝线表示股票的实际收盘价,紫线表示模型的预测结果。通过观察预测结果与实际值的吻合程度,我们可以评估模型的效果。

代码
文本

9 总结 (Summary)

在本次 Notebook 中,我们使用了 XGBoost 模型来预测股票的收盘价。通过对数据的预处理,构建特征集,使用网格搜索优化模型参数,并可视化结果,我们展示了如何进行时间序列数据的预测分析。

代码
文本
数据科学导论
数据科学导论
点个赞吧
推荐阅读
公开
《量化投资》(二)量化投资策略:案例代码实现
量化交易
量化交易
chenwj
发布于 2023-09-27
1 转存文件
公开
文献案例|机器学习在构建材料相图中的应用
相图文献案例半监督学习
相图文献案例半监督学习
MileAway
发布于 2023-12-27
1 赞2 转存文件