Python数据可视化魔法:用Matplotlib+机器学习预测A股趋势
·
Python数据可视化魔法:用Matplotlib+机器学习预测A股趋势
实战指南:从动态K线图到AI预测的完整流程
一、A股可视化:数据科学家的水晶球
金融数据可视化的重要性:
- 人类大脑处理图像比数字快6万倍
- 专业交易员90%决策基于图表分析
- 正确可视化可提升预测准确率40%

二、Matplotlib基础:金融图表的利器
1. 核心组件解析

2. 安装与环境配置
# 安装必要库
pip install matplotlib pandas mplfinance akshare plotly seaborn
三、获取真实A股数据
1. 使用AKShare库(推荐)
import akshare as ak
import pandas as pd
def fetch_stock_data(symbol, start_date, end_date):
"""
获取A股历史数据
:param symbol: 股票代码 (如 '600519' 贵州茅台)
:param start_date: 开始日期 (格式 'YYYYMMDD')
:param end_date: 结束日期
:return: DataFrame包含OHLCV数据
"""
# 使用AKShare获取数据
data = ak.stock_zh_a_hist(symbol=symbol, period="daily",
start_date=start_date, end_date=end_date,
adjust="qfq")
# 确保数据格式正确
if data.empty:
raise ValueError(f"未找到{symbol}的数据")
# 重命名列
data = data.rename(columns={
"日期": "Date",
"开盘": "Open",
"收盘": "Close",
"最高": "High",
"最低": "Low",
"成交量": "Volume"
})
# 设置日期索引
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)
return data
# 获取贵州茅台2023年数据
gzmt = fetch_stock_data('600519', '20240101', '20241231')
print(gzmt.head())
2. 备用数据源:Tushare
import tushare as ts
def fetch_tushare_data(symbol, start_date, end_date, token="YOUR_TOKEN"):
"""
使用Tushare API获取数据
"""
# 设置token
ts.set_token(token)
pro = ts.pro_api()
# 获取数据
data = pro.daily(ts_code=symbol+'.SH',
start_date=start_date,
end_date=end_date)
# 检查数据
if data.empty:
raise ValueError("未找到时间序列数据")
# 处理数据
data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d')
data.set_index('trade_date', inplace=True)
data = data.rename(columns={
"open": "Open",
"high": "High",
"low": "Low",
"close": "Close",
"vol": "Volume"
})
return data[['Open', 'High', 'Low', 'Close', 'Volume']]
# 使用示例(需要替换为你的API密钥)
# gzmt = fetch_tushare_data("600519", "20240101", "20241231", token="YOUR_TOKEN")
3. 数据清洗与预处理
def clean_stock_data(data):
"""
清洗股票数据
"""
# 处理缺失值
data = data.fillna(method='ffill')
# 添加技术指标
data['SMA_20'] = data['Close'].rolling(window=20).mean()
data['SMA_50'] = data['Close'].rolling(window=50).mean()
# 计算日收益率
data['Daily_Return'] = data['Close'].pct_change()
# 添加波动率
data['Volatility'] = data['Daily_Return'].rolling(window=20).std() * np.sqrt(252)
return data
# 清洗数据
gzmt_clean = clean_stock_data(gzmt)
四、绘制专业K线图
1. 基础K线图
import matplotlib.pyplot as plt
from mplfinance.original_flavor import candlestick_ohlc
import matplotlib.dates as mdates
def plot_basic_candlestick(data):
"""
绘制基础K线图
"""
# 准备数据
df = data.reset_index()
df['Date'] = df['Date'].map(mdates.date2num)
# 创建图表
fig, ax = plt.subplots(figsize=(14, 7))
# 绘制K线
candlestick_ohlc(ax, df[['Date', 'Open', 'High', 'Low', 'Close']].values,
width=0.6, colorup='r', colordown='g') # 中国股市红色涨绿色跌
# 设置日期格式
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.xticks(rotation=45)
# 添加网格
ax.grid(True)
# 添加标题
plt.title('贵州茅台股票价格 (2023)')
plt.xlabel('日期')
plt.ylabel('价格 (元)')
plt.tight_layout()
plt.show()
# 绘制K线图
plot_basic_candlestick(gzmt_clean)
2. 添加技术指标
def plot_with_indicators(data):
"""
绘制带技术指标的K线图
"""
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), gridspec_kw={'height_ratios': [3, 1]})
# 准备K线数据
df = data.reset_index()
df['Date'] = df['Date'].map(mdates.date2num)
# 绘制K线
candlestick_ohlc(ax1, df[['Date', 'Open', 'High', 'Low', 'Close']].values,
width=0.6, colorup='r', colordown='g')
# 添加移动平均线
ax1.plot(df['Date'], df['SMA_20'], label='20日均线', color='blue')
ax1.plot(df['Date'], df['SMA_50'], label='50日均线', color='orange')
# 设置日期格式
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
# 添加标题和标签
ax1.set_title('贵州茅台股票分析')
ax1.set_ylabel('价格 (元)')
ax1.legend()
ax1.grid(True)
# 绘制成交量
ax2.bar(df['Date'], df['Volume'], width=0.6, color='gray')
ax2.set_ylabel('成交量')
ax2.grid(True)
plt.tight_layout()
plt.show()
# 绘制带指标的K线图
plot_with_indicators(gzmt_clean)
五、动态交互式图表
1. 使用Plotly创建交互图表
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def interactive_candlestick(data):
"""
创建交互式K线图
"""
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
vertical_spacing=0.05,
row_heights=[0.7, 0.3])
# 添加K线
fig.add_trace(go.Candlestick(
x=data.index,
open=data['Open'],
high=data['High'],
low=data['Low'],
close=data['Close'],
name='价格',
increasing_line_color='red', # 中国股市红色涨
decreasing_line_color='green' # 绿色跌
), row=1, col=1)
# 添加移动平均线
fig.add_trace(go.Scatter(
x=data.index,
y=data['SMA_20'],
name='20日均线',
line=dict(color='blue', width=2)
), row=1, col=1)
fig.add_trace(go.Scatter(
x=data.index,
y=data['SMA_50'],
name='50日均线',
line=dict(color='orange', width=2)
), row=1, col=1)
# 添加成交量
fig.add_trace(go.Bar(
x=data.index,
y=data['Volume'],
name='成交量',
marker_color='gray'
), row=2, col=1)
# 更新布局
fig.update_layout(
title='贵州茅台股票分析',
yaxis_title='价格 (元)',
xaxis_rangeslider_visible=False,
hovermode='x unified',
height=700
)
# 添加技术指标按钮
fig.update_layout(
updatemenus=[
dict(
type="buttons",
direction="right",
x=0.3,
y=1.15,
buttons=[
dict(label="MACD",
method="update",
args=[{"visible": [True, True, True, True, False, False]},
{"title": "贵州茅台MACD分析"}]),
dict(label="KDJ",
method="update",
args=[{"visible": [True, True, True, True, True, False]},
{"title": "贵州茅台KDJ分析"}]),
])
])
fig.show()
# 创建交互式图表
interactive_candlestick(gzmt_clean)
2. 添加技术指标
def add_technical_indicators(data):
"""
添加更多技术指标
"""
# 计算MACD
exp12 = data['Close'].ewm(span=12, adjust=False).mean()
exp26 = data['Close'].ewm(span=26, adjust=False).mean()
data['MACD'] = exp12 - exp26
data['Signal'] = data['MACD'].ewm(span=9, adjust=False).mean()
# 计算KDJ
low_min = data['Low'].rolling(window=9).min()
high_max = data['High'].rolling(window=9).max()
rsv = (data['Close'] - low_min) / (high_max - low_min) * 100
data['K'] = rsv.ewm(com=2).mean()
data['D'] = data['K'].ewm(com=2).mean()
data['J'] = 3 * data['K'] - 2 * data['D']
return data
# 添加技术指标
gzmt_full = add_technical_indicators(gzmt_clean)
六、趋势预测:从可视化到AI
1. 移动平均交叉策略
def moving_average_crossover(data):
"""
移动平均交叉策略
"""
# 生成信号
data['Signal'] = 0
data['Signal'] = np.where(data['SMA_20'] > data['SMA_50'], 1, 0)
data['Position'] = data['Signal'].diff()
# 可视化信号
plt.figure(figsize=(14, 7))
# 绘制价格和移动平均线
plt.plot(data['Close'], label='价格', alpha=0.5)
plt.plot(data['SMA_20'], label='20日均线')
plt.plot(data['SMA_50'], label='50日均线')
# 标记买入信号
plt.plot(data[data['Position'] == 1].index,
data['SMA_20'][data['Position'] == 1],
'^', markersize=10, color='r', label='买入信号')
# 标记卖出信号
plt.plot(data[data['Position'] == -1].index,
data['SMA_20'][data['Position'] == -1],
'v', markersize=10, color='g', label='卖出信号')
plt.title('移动平均交叉策略')
plt.legend()
plt.grid(True)
plt.show()
return data
# 应用策略
gzmt_signals = moving_average_crossover(gzmt_full)
2. 机器学习预测(LSTM)
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
def predict_with_lstm(data, look_back=60, forecast_days=30):
"""
使用LSTM预测股价(修复版)
"""
# 准备数据 - 只使用收盘价
prices = data['Close'].values.reshape(-1, 1)
# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(prices)
# 创建训练数据
X, y = [], []
for i in range(look_back, len(scaled_data) - forecast_days):
X.append(scaled_data[i - look_back:i, 0])
y.append(scaled_data[i:i + forecast_days, 0])
X, y = np.array(X), np.array(y)
X = np.reshape(X, (X.shape[0], X.shape[1], 1))
# 创建LSTM模型
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
model.add(LSTM(units=50))
model.add(Dense(forecast_days)) # 输出预测天数
model.compile(optimizer='adam', loss='mean_squared_error')
# 训练模型
model.fit(X, y, epochs=100, batch_size=32, verbose=0)
# 预测未来价格
last_sequence = scaled_data[-look_back:]
forecast = []
# 正确初始化current_batch
current_batch = last_sequence.reshape(1, look_back, 1)
# 预测未来forecast_days天
for i in range(forecast_days):
# 预测下一个点
current_pred = model.predict(current_batch, verbose=0)[0]
# 只取第一个预测值(下一个时间点)
next_value = current_pred[0]
forecast.append(next_value)
# 更新batch:移除第一个值,添加新预测值
current_batch = np.append(current_batch[:, 1:, :], [[[next_value]]], axis=1)
# 转换预测结果
forecast = np.array(forecast).reshape(-1, 1)
forecast = scaler.inverse_transform(forecast)
return forecast
# 预测未来30天
forecast = predict_with_lstm(gzmt_full)
3. 可视化预测结果
def plot_forecast(data, forecast):
"""
可视化预测结果
"""
plt.figure(figsize=(14, 7))
# 绘制历史价格
plt.plot(data.index, data['Close'], label='历史价格', color='blue')
# 绘制预测价格
last_date = data.index[-1]
forecast_dates = pd.date_range(start=last_date, periods=len(forecast)+1)[1:]
plt.plot(forecast_dates, forecast, label='预测价格', color='red', linestyle='--')
# 添加置信区间
plt.fill_between(forecast_dates,
forecast.flatten() * 0.95,
forecast.flatten() * 1.05,
color='pink', alpha=0.3, label='置信区间')
# 添加移动平均线
plt.plot(data.index, data['SMA_50'], label='50日均线', color='orange')
plt.title('贵州茅台股价预测')
plt.xlabel('日期')
plt.ylabel('价格 (元)')
plt.legend()
plt.grid(True)
plt.show()
# 可视化预测
plot_forecast(gzmt_full, forecast)
七、工业级应用:实时交易系统
1. 实时数据流处理架构

2. 实时仪表盘实现
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
# 创建Dash应用
app = dash.Dash(__name__)
# 布局
app.layout = html.Div([
html.H1("实时A股分析仪表盘"),
dcc.Dropdown(
id='stock-selector',
options=[
{'label': '贵州茅台 600519', 'value': '600519'},
{'label': '宁德时代 300750', 'value': '300750'},
{'label': '招商银行 600036', 'value': '600036'}
],
value='600519'
),
dcc.Dropdown(
id='indicator-selector',
options=[
{'label': 'MACD', 'value': 'macd'},
{'label': 'RSI', 'value': 'rsi'},
{'label': '成交量', 'value': 'volume'}
],
value=['macd', 'rsi'],
multi=True
),
dcc.Graph(id='live-candlestick'),
dcc.Interval(
id='interval-component',
interval=60 * 1000, # 每分钟更新
n_intervals=0
)
])
# 回调函数
@app.callback(
Output('live-candlestick', 'figure'),
[Input('stock-selector', 'value'),
Input('indicator-selector', 'value'),
Input('interval-component', 'n_intervals')]
)
def update_graph(stock, indicators, n):
try:
# 获取数据
end_date = datetime.datetime.now().strftime('%Y%m%d')
data = fetch_stock_data(stock, '20230101', end_date)
data = clean_stock_data(data)
data = add_technical_indicators(data)
# 创建图表
fig = create_candlestick_chart(data, stock, indicators)
return fig
except Exception as e:
print(f"更新图表失败: {str(e)}")
return go.Figure()
# 图表创建函数
def create_candlestick_chart(data, symbol, indicators):
"""创建K线图表"""
if data.empty:
return go.Figure()
# 创建子图
fig = make_subplots(
rows=2,
cols=1,
shared_xaxes=True,
vertical_spacing=0.05,
row_heights=[0.7, 0.3],
specs=[[{"secondary_y": True}], [{"secondary_y": False}]]
)
# 添加K线 (第一行)
fig.add_trace(go.Candlestick(
x=data.index,
open=data['Open'],
high=data['High'],
low=data['Low'],
close=data['Close'],
name='价格',
increasing_line_color='red',
decreasing_line_color='green'
), row=1, col=1)
# 添加移动平均线 (第一行)
fig.add_trace(go.Scatter(
x=data.index,
y=data['SMA_20'],
name='20日均线',
line=dict(color='blue', width=2)
), row=1, col=1)
fig.add_trace(go.Scatter(
x=data.index,
y=data['SMA_50'],
name='50日均线',
line=dict(color='orange', width=2)
), row=1, col=1)
# 添加RSI到次要Y轴 (第一行)
if 'rsi' in indicators:
fig.add_trace(go.Scatter(
x=data.index,
y=data['RSI'],
name='RSI',
line=dict(color='purple', width=2),
), row=1, col=1, secondary_y=True)
# 添加RSI参考线到次要Y轴 (第一行)
fig.add_hline(
y=30, line_dash="dash", line_color="green",
row=1, col=1, secondary_y=True
)
fig.add_hline(
y=70, line_dash="dash", line_color="red",
row=1, col=1, secondary_y=True
)
# 添加MACD到第二行
if 'macd' in indicators:
fig.add_trace(go.Bar(
x=data.index,
y=data['MACD'],
name='MACD',
marker_color=np.where(data['MACD'] > 0, 'green', 'red')
), row=2, col=1)
fig.add_trace(go.Scatter(
x=data.index,
y=data['Signal'],
name='信号线',
line=dict(color='blue', width=2)
), row=2, col=1)
# 添加成交量到第二行
if 'volume' in indicators and 'macd' not in indicators:
fig.add_trace(go.Bar(
x=data.index,
y=data['Volume'],
name='成交量',
marker_color='gray'
), row=2, col=1)
# 更新布局
fig.update_layout(
title=f'{symbol} 实时分析 (最后更新: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M")})',
yaxis_title='价格 (元)',
xaxis_rangeslider_visible=False,
height=700,
hovermode='x unified',
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
)
# 设置Y轴标签
fig.update_yaxes(title_text="价格 (元)", row=1, col=1)
if 'rsi' in indicators:
fig.update_yaxes(title_text="RSI", range=[0, 100], row=1, col=1, secondary_y=True)
if 'macd' in indicators:
fig.update_yaxes(title_text="MACD", row=2, col=1)
elif 'volume' in indicators:
fig.update_yaxes(title_text="成交量", row=2, col=1)
# 移除周末的空隙
fig.update_xaxes(
rangebreaks=[{'bounds': ['sat', 'mon']}]
)
return fig
if __name__ == '__main__':
app.run_server(debug=True)
八、避坑指南:金融可视化常见错误
1. 错误案例:错误的时间序列处理
# 反例:未正确处理时间序列
plt.plot(gzmt['Close']) # X轴为数字索引而非日期
# 正解:使用日期作为索引
gzmt = gzmt.set_index('Date')
plt.plot(gzmt.index, gzmt['Close'])
2. 错误案例:忽略市场事件
# 反例:未标记重要事件
plt.plot(gzmt['Close'])
# 正解:添加事件标记
plt.plot(gzmt['Close'])
plt.axvline(x=pd.Timestamp('2023-03-20'), color='r', linestyle='--', label='年报发布')
3. 错误案例:过度复杂可视化
# 反例:过多指标导致混乱
plt.plot(gzmt['Close'])
plt.plot(gzmt['SMA_20'])
plt.plot(gzmt['SMA_50'])
plt.plot(gzmt['RSI'])
plt.plot(gzmt['MACD'])
# ... 其他指标
# 正解:分层显示或交互式图表
# 使用子图或交互式切换
九、结语:成为A股可视化大师
通过本指南,您已掌握:
- 📈 专业K线图绘制技巧
- 🔍 技术指标计算方法
- 🤖 机器学习预测模型
- 🚀 实时交易仪表盘开发
- ⚠️ 常见错误规避方法
下一步行动:
- 应用这些技术分析你感兴趣的A股
- 开发个性化交易策略
- 构建实时监控系统
- 结合基本面分析
- 分享你的分析结果
"在金融市场中,可视化不仅是分析工具,更是洞察市场脉搏的听诊器。掌握它,你就拥有了预测未来的水晶球。"
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐





所有评论(0)