Files
Kronos/examples/prediction_new_GUI.py
billconan2017 cf9744a46b 修改内容
2025-10-10 16:32:17 +08:00

1625 lines
64 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
from datetime import datetime, timedelta
import warnings
import requests
import json
import time
import random
import akshare as ak
from typing import Dict, List, Tuple, Optional
import tkinter as tk
from tkinter import ttk, messagebox, filedialog
import threading
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
warnings.filterwarnings('ignore')
# 添加项目路径以便导入自定义模块
sys.path.append("../")
try:
from model import Kronos, KronosTokenizer, KronosPredictor
except ImportError:
print("⚠️ 无法导入Kronos模型预测功能将不可用")
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
class StockPredictorGUI:
"""股票预测图形界面"""
def __init__(self, root):
self.root = root
self.root.title("Kronos股票预测系统")
self.root.geometry("800x600")
self.root.configure(bg='#f0f0f0')
# 初始化市场分析器
self.market_analyzer = EnhancedMarketFactorAnalyzer()
# 创建界面
self.create_widgets()
# 默认配置
self.default_config = {
"stock_code": "600580",
"stock_name": "卧龙电驱",
"data_dir": r"D:\lianghuajiaoyi\Kronos\examples\data",
"output_dir": r"D:\lianghuajiaoyi\Kronos\examples\yuce",
"pred_days": 60,
"history_years": 1
}
def create_widgets(self):
"""创建界面组件"""
# 主标题
title_label = tk.Label(
self.root,
text="🤖 Kronos股票预测系统",
font=("Arial", 16, "bold"),
bg='#f0f0f0',
fg='#2c3e50'
)
title_label.pack(pady=10)
# 说明标签
desc_label = tk.Label(
self.root,
text="基于Kronos模型的多维度股票价格预测系统",
font=("Arial", 10),
bg='#f0f0f0',
fg='#7f8c8d'
)
desc_label.pack(pady=5)
# 创建主框架
main_frame = tk.Frame(self.root, bg='#f0f0f0')
main_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=10)
# 输入框架
input_frame = tk.LabelFrame(main_frame, text="股票参数设置", font=("Arial", 11, "bold"),
bg='#f0f0f0', fg='#2c3e50')
input_frame.pack(fill=tk.X, pady=10)
# 股票代码输入
tk.Label(input_frame, text="股票代码:", bg='#f0f0f0', font=("Arial", 10)).grid(row=0, column=0, sticky=tk.W,
padx=5, pady=5)
self.stock_code_var = tk.StringVar(value="600580")
stock_code_entry = tk.Entry(input_frame, textvariable=self.stock_code_var, font=("Arial", 10), width=15)
stock_code_entry.grid(row=0, column=1, padx=5, pady=5)
# 股票名称输入
tk.Label(input_frame, text="股票名称:", bg='#f0f0f0', font=("Arial", 10)).grid(row=0, column=2, sticky=tk.W,
padx=5, pady=5)
self.stock_name_var = tk.StringVar(value="卧龙电驱")
stock_name_entry = tk.Entry(input_frame, textvariable=self.stock_name_var, font=("Arial", 10), width=15)
stock_name_entry.grid(row=0, column=3, padx=5, pady=5)
# 预测天数
tk.Label(input_frame, text="预测天数:", bg='#f0f0f0', font=("Arial", 10)).grid(row=1, column=0, sticky=tk.W,
padx=5, pady=5)
self.pred_days_var = tk.StringVar(value="60")
pred_days_entry = tk.Entry(input_frame, textvariable=self.pred_days_var, font=("Arial", 10), width=15)
pred_days_entry.grid(row=1, column=1, padx=5, pady=5)
# 历史数据年限
tk.Label(input_frame, text="历史年限:", bg='#f0f0f0', font=("Arial", 10)).grid(row=1, column=2, sticky=tk.W,
padx=5, pady=5)
self.history_years_var = tk.StringVar(value="1")
history_years_entry = tk.Entry(input_frame, textvariable=self.history_years_var, font=("Arial", 10), width=15)
history_years_entry.grid(row=1, column=3, padx=5, pady=5)
# 目录设置框架
dir_frame = tk.LabelFrame(main_frame, text="目录设置", font=("Arial", 11, "bold"),
bg='#f0f0f0', fg='#2c3e50')
dir_frame.pack(fill=tk.X, pady=10)
# 数据目录
tk.Label(dir_frame, text="数据目录:", bg='#f0f0f0', font=("Arial", 10)).grid(row=0, column=0, sticky=tk.W,
padx=5, pady=5)
self.data_dir_var = tk.StringVar(value=r"D:\lianghuajiaoyi\Kronos\examples\data")
data_dir_entry = tk.Entry(dir_frame, textvariable=self.data_dir_var, font=("Arial", 10), width=40)
data_dir_entry.grid(row=0, column=1, padx=5, pady=5)
tk.Button(dir_frame, text="浏览", command=self.browse_data_dir, font=("Arial", 9)).grid(row=0, column=2, padx=5,
pady=5)
# 输出目录
tk.Label(dir_frame, text="输出目录:", bg='#f0f0f0', font=("Arial", 10)).grid(row=1, column=0, sticky=tk.W,
padx=5, pady=5)
self.output_dir_var = tk.StringVar(value=r"D:\lianghuajiaoyi\Kronos\examples\yuce")
output_dir_entry = tk.Entry(dir_frame, textvariable=self.output_dir_var, font=("Arial", 10), width=40)
output_dir_entry.grid(row=1, column=1, padx=5, pady=5)
tk.Button(dir_frame, text="浏览", command=self.browse_output_dir, font=("Arial", 9)).grid(row=1, column=2,
padx=5, pady=5)
# 功能按钮框架
button_frame = tk.Frame(main_frame, bg='#f0f0f0')
button_frame.pack(pady=20)
# 预测按钮
self.predict_button = tk.Button(
button_frame,
text="🚀 开始预测",
command=self.start_prediction,
font=("Arial", 12, "bold"),
bg='#3498db',
fg='white',
width=15,
height=2
)
self.predict_button.pack(side=tk.LEFT, padx=10)
# 重置按钮
reset_button = tk.Button(
button_frame,
text="🔄 重置",
command=self.reset_fields,
font=("Arial", 10),
bg='#95a5a6',
fg='white',
width=10,
height=2
)
reset_button.pack(side=tk.LEFT, padx=10)
# 退出按钮
exit_button = tk.Button(
button_frame,
text="❌ 退出",
command=self.root.quit,
font=("Arial", 10),
bg='#e74c3c',
fg='white',
width=10,
height=2
)
exit_button.pack(side=tk.LEFT, padx=10)
# 进度显示
self.progress_frame = tk.LabelFrame(main_frame, text="预测进度", font=("Arial", 11, "bold"),
bg='#f0f0f0', fg='#2c3e50')
self.progress_frame.pack(fill=tk.X, pady=10)
self.progress_var = tk.StringVar(value="等待开始预测...")
progress_label = tk.Label(self.progress_frame, textvariable=self.progress_var, bg='#f0f0f0',
font=("Arial", 10), wraplength=700, justify=tk.LEFT)
progress_label.pack(padx=10, pady=10, fill=tk.X)
# 进度条
self.progress_bar = ttk.Progressbar(self.progress_frame, mode='indeterminate')
self.progress_bar.pack(fill=tk.X, padx=10, pady=5)
# 结果展示区域
self.result_frame = tk.LabelFrame(main_frame, text="预测结果", font=("Arial", 11, "bold"),
bg='#f0f0f0', fg='#2c3e50')
self.result_frame.pack(fill=tk.BOTH, expand=True, pady=10)
self.result_text = tk.Text(self.result_frame, height=8, font=("Arial", 9), wrap=tk.WORD)
scrollbar = tk.Scrollbar(self.result_frame, command=self.result_text.yview)
self.result_text.configure(yscrollcommand=scrollbar.set)
self.result_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y, pady=5)
def browse_data_dir(self):
"""浏览数据目录"""
directory = filedialog.askdirectory()
if directory:
self.data_dir_var.set(directory)
def browse_output_dir(self):
"""浏览输出目录"""
directory = filedialog.askdirectory()
if directory:
self.output_dir_var.set(directory)
def reset_fields(self):
"""重置输入字段"""
self.stock_code_var.set("600580")
self.stock_name_var.set("卧龙电驱")
self.pred_days_var.set("60")
self.history_years_var.set("1")
self.data_dir_var.set(r"D:\lianghuajiaoyi\Kronos\examples\data")
self.output_dir_var.set(r"D:\lianghuajiaoyi\Kronos\examples\yuce")
self.result_text.delete(1.0, tk.END)
self.progress_var.set("等待开始预测...")
def start_prediction(self):
"""开始预测"""
# 验证输入
if not self.validate_inputs():
return
# 禁用预测按钮
self.predict_button.config(state=tk.DISABLED)
# 清空结果区域
self.result_text.delete(1.0, tk.END)
# 开始进度条
self.progress_bar.start()
# 在新线程中运行预测
prediction_thread = threading.Thread(target=self.run_prediction)
prediction_thread.daemon = True
prediction_thread.start()
def validate_inputs(self):
"""验证输入参数"""
try:
stock_code = self.stock_code_var.get().strip()
stock_name = self.stock_name_var.get().strip()
pred_days = int(self.pred_days_var.get())
history_years = int(self.history_years_var.get())
if not stock_code:
messagebox.showerror("错误", "请输入股票代码")
return False
if not stock_name:
messagebox.showerror("错误", "请输入股票名称")
return False
if pred_days <= 0 or pred_days > 365:
messagebox.showerror("错误", "预测天数应在1-365天之间")
return False
if history_years <= 0 or history_years > 10:
messagebox.showerror("错误", "历史年限应在1-10年之间")
return False
return True
except ValueError:
messagebox.showerror("错误", "请输入有效的数字")
return False
def run_prediction(self):
"""运行预测流程"""
try:
# 获取输入参数
stock_code = self.stock_code_var.get().strip()
stock_name = self.stock_name_var.get().strip()
pred_days = int(self.pred_days_var.get())
history_years = int(self.history_years_var.get())
data_dir = self.data_dir_var.get()
output_dir = self.output_dir_var.get()
# 更新进度
self.update_progress("🎯 开始股票预测流程...")
# 运行预测
success, result = run_comprehensive_prediction_gui(
stock_code, stock_name, data_dir, pred_days, output_dir, history_years,
progress_callback=self.update_progress,
result_callback=self.update_result
)
if success:
self.update_progress("✅ 预测完成!")
messagebox.showinfo("完成", f"{stock_name}({stock_code})预测完成!\n图表已保存到输出目录。")
else:
self.update_progress("❌ 预测失败")
messagebox.showerror("错误", f"预测失败: {result}")
except Exception as e:
self.update_progress(f"❌ 预测过程出现错误: {str(e)}")
messagebox.showerror("错误", f"预测过程出现错误: {str(e)}")
finally:
# 重新启用预测按钮
self.root.after(0, lambda: self.predict_button.config(state=tk.NORMAL))
# 停止进度条
self.root.after(0, self.progress_bar.stop)
def update_progress(self, message):
"""更新进度信息"""
self.root.after(0, lambda: self.progress_var.set(message))
print(message) # 同时在控制台输出
def update_result(self, message):
"""更新结果信息"""
self.root.after(0, lambda: self.result_text.insert(tk.END, message + "\n"))
self.root.after(0, lambda: self.result_text.see(tk.END))
# ==================== 基础数据获取函数 ====================
def ensure_output_directory(output_dir):
"""确保输出目录存在,如果不存在则创建"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"✅ 创建输出目录: {output_dir}")
return output_dir
def fetch_real_stock_data(stock_code, period="daily", adjust="qfq"):
"""
使用AKShare获取真实股票数据
"""
try:
print(f"📡 正在通过AKShare获取 {stock_code} 的真实股票数据...")
# 获取股票数据
df = ak.stock_zh_a_hist(symbol=stock_code, period=period, adjust=adjust)
if df is None or df.empty:
print(f"❌ 未获取到 {stock_code} 的数据")
return None
# 重命名列以统一格式
column_mapping = {
'日期': 'timestamps',
'开盘': 'open',
'收盘': 'close',
'最高': 'high',
'最低': 'low',
'成交量': 'volume',
'成交额': 'amount',
'振幅': 'amplitude',
'涨跌幅': 'pct_chg',
'涨跌额': 'change_amount',
'换手率': 'turnover'
}
# 只映射存在的列
actual_mapping = {k: v for k, v in column_mapping.items() if k in df.columns}
df = df.rename(columns=actual_mapping)
# 确保时间戳格式正确
df['timestamps'] = pd.to_datetime(df['timestamps'])
df = df.sort_values('timestamps').reset_index(drop=True)
# 添加股票代码列
df['stock_code'] = stock_code
print(f"✅ 成功获取 {len(df)} 条真实数据")
print(f"📈 最新收盘价: {df['close'].iloc[-1]:.2f}元, 涨跌幅: {df['pct_chg'].iloc[-1]:.2f}%")
print(f"📅 时间范围: {df['timestamps'].min()}{df['timestamps'].max()}")
return df
except Exception as e:
print(f"❌ AKShare数据获取失败: {e}")
return None
def get_stock_data_with_retry_all_history(stock_code="600580", retry_count=2):
"""
优化的数据获取函数 - 优先使用真实API数据
"""
print(f"🔄 尝试获取股票 {stock_code} 的真实历史数据...")
# 优先使用AKShare获取真实数据
df = fetch_real_stock_data(stock_code, "daily", "qfq")
if df is not None:
return df
else:
print("⚠️ 真实数据获取失败,使用基于真实价格的模拟数据...")
return create_realistic_fallback_data(stock_code)
def create_realistic_fallback_data(stock_code="600580"):
"""
基于真实价格的备用数据生成函数
"""
# 基于真实市场价格的参考数据
real_stock_references = {
'600580': {'name': '卧龙电驱', 'current_price': 15.20, 'range': (12.0, 20.0)},
'300207': {'name': '欣旺达', 'current_price': 33.79, 'range': (28.0, 38.0)},
'300418': {'name': '昆仑万维', 'current_price': 48.59, 'range': (40.0, 55.0)},
'002354': {'name': '天娱数科', 'current_price': 15.20, 'range': (12.0, 20.0)},
'000001': {'name': '平安银行', 'current_price': 12.50, 'range': (10.0, 16.0)},
'600036': {'name': '招商银行', 'current_price': 35.80, 'range': (30.0, 42.0)},
}
stock_info = real_stock_references.get(stock_code, {
'name': '未知股票',
'current_price': 20.0,
'range': (15.0, 25.0)
})
# 生成最近1年的交易日数据
end_date = datetime.now()
start_date = end_date - timedelta(days=365)
dates = pd.bdate_range(start=start_date, end=end_date, freq='B')
# 生成基于真实价格的价格序列
np.random.seed(42)
n_points = len(dates)
# 从当前价格反向生成历史价格
current_price = stock_info['current_price']
min_price, max_price = stock_info['range']
# 反向生成价格序列
prices = [current_price]
for i in range(1, n_points):
volatility = 0.02
historical_return = np.random.normal(-0.0002, volatility)
prev_price = prices[0] * (1 + historical_return)
prev_price = max(min_price * 0.9, min(max_price * 1.1, prev_price))
prices.insert(0, prev_price)
# 生成OHLC数据
stock_data = []
for i, date in enumerate(dates):
close_price = prices[i]
daily_volatility = abs(np.random.normal(0, 0.015))
open_price = close_price * (1 + np.random.normal(0, 0.005))
high_price = max(open_price, close_price) * (1 + daily_volatility)
low_price = min(open_price, close_price) * (1 - daily_volatility)
high_price = max(open_price, close_price, low_price, high_price)
low_price = min(open_price, close_price, high_price, low_price)
volume = int(abs(np.random.normal(1500000, 400000)))
amount = volume * close_price
if i > 0:
pct_chg = ((close_price - prices[i - 1]) / prices[i - 1]) * 100
change_amount = close_price - prices[i - 1]
else:
pct_chg = 0
change_amount = 0
stock_data.append({
'timestamps': date,
'stock_code': stock_code,
'open': round(open_price, 2),
'close': round(close_price, 2),
'high': round(high_price, 2),
'low': round(low_price, 2),
'volume': volume,
'amount': round(amount, 2),
'amplitude': round(((high_price - low_price) / open_price) * 100, 2),
'pct_chg': round(pct_chg, 2),
'change_amount': round(change_amount, 2),
'turnover': round(np.random.uniform(3.0, 8.0), 2)
})
df = pd.DataFrame(stock_data)
print(f"✅ 已生成基于真实价格的备用数据 {len(df)}")
return df
def save_all_history_stock_data(df, stock_code, save_dir):
"""
保存股票数据到指定目录
"""
if df is not None and not df.empty:
os.makedirs(save_dir, exist_ok=True)
csv_file = os.path.join(save_dir, f"{stock_code}_stock_data.csv")
df_reset = df.reset_index()
df_reset.to_csv(csv_file, encoding='utf-8-sig', index=False)
print(f"📁 股票数据已保存: {csv_file}")
return True
return False
def get_stock_data(stock_code, data_dir):
"""
获取股票数据如果数据文件不存在则从API获取真实数据
"""
csv_file_path = os.path.join(data_dir, f"{stock_code}_stock_data.csv")
if os.path.exists(csv_file_path):
print(f"📁 使用现有数据文件: {csv_file_path}")
return True, csv_file_path
else:
print(f"📡 数据文件不存在从API获取真实数据...")
df = get_stock_data_with_retry_all_history(stock_code)
if df is not None and not df.empty:
save_all_history_stock_data(df, stock_code, data_dir)
return True, csv_file_path
else:
print(f"❌ 无法获取股票数据")
return False, None
def prepare_stock_data(csv_file_path, stock_code, history_years=1):
"""
准备股票数据转换为Kronos模型需要的格式
"""
print(f"正在加载和预处理股票 {stock_code} 数据...")
# 读取CSV文件
df = pd.read_csv(csv_file_path, encoding='utf-8-sig')
# 标准化列名
column_mapping = {
'日期': 'timestamps',
'开盘价': 'open',
'最高价': 'high',
'最低价': 'low',
'收盘价': 'close',
'成交量': 'volume',
'成交额': 'amount',
'开盘': 'open',
'收盘': 'close',
'最高': 'high',
'最低': 'low'
}
actual_mapping = {k: v for k, v in column_mapping.items() if k in df.columns}
df = df.rename(columns=actual_mapping)
# 确保时间戳列存在并转换为datetime格式
if 'timestamps' not in df.columns:
if df.index.name == '日期':
df = df.reset_index()
df = df.rename(columns={'日期': 'timestamps'})
df['timestamps'] = pd.to_datetime(df['timestamps'])
df = df.sort_values('timestamps').reset_index(drop=True)
# 根据历史年限筛选数据
if history_years > 0:
cutoff_date = datetime.now() - timedelta(days=history_years * 365)
original_count = len(df)
df = df[df['timestamps'] >= cutoff_date]
print(f"📅 使用最近 {history_years} 年数据: {len(df)} 条记录 (从 {original_count} 条中筛选)")
# 数据验证
print(f"🔍 数据验证 - 最近5个交易日收盘价:")
recent_prices = df[['timestamps', 'close']].tail()
for _, row in recent_prices.iterrows():
print(f" {row['timestamps'].strftime('%Y-%m-%d')}: {row['close']:.2f}")
current_price = df['close'].iloc[-1]
print(f"✅ 数据加载完成,共 {len(df)} 条记录")
print(f"时间范围: {df['timestamps'].min()}{df['timestamps'].max()}")
print(f"价格范围: {df['close'].min():.2f} - {df['close'].max():.2f}")
print(f"当前价格: {current_price:.2f}")
return df
def calculate_prediction_parameters(df, target_days=60):
"""
根据目标预测天数计算合适的参数
"""
# 计算平均交易日数量
total_days = (df['timestamps'].max() - df['timestamps'].min()).days
trading_days = len(df)
trading_ratio = trading_days / total_days if total_days > 0 else 0.7
# 计算目标预测的交易日数量
pred_trading_days = int(target_days * trading_ratio)
# 设置回看期数
max_lookback = int(len(df) * 0.7)
lookback = min(pred_trading_days * 3, max_lookback, len(df) - pred_trading_days)
pred_len = min(pred_trading_days, len(df) - lookback)
# 确保参数在合理范围内
lookback = max(100, min(lookback, 400))
pred_len = max(20, min(pred_len, 120))
print(f"📊 参数计算:")
print(f" 目标预测天数: {target_days} 天(自然日)")
print(f" 预计交易日数量: {pred_trading_days}")
print(f" 回看期数 (lookback): {lookback}")
print(f" 预测期数 (pred_len): {pred_len}")
return lookback, pred_len
def generate_trading_dates_only(last_date, pred_len):
"""
🎯 修复版:只生成交易日,排除周末和法定节假日
"""
# 2025年法定节假日安排修正版
holidays_2025 = [
'2025-01-01', # 元旦
'2025-01-27', '2025-01-28', '2025-01-29', '2025-01-30', '2025-01-31', '2025-02-01', '2025-02-02', # 春节
'2025-04-04', '2025-04-05', '2025-04-06', # 清明
'2025-05-01', '2025-05-02', '2025-05-03', # 劳动节
'2025-06-08', '2025-06-09', '2025-06-10', # 端午
'2025-10-01', '2025-10-02', '2025-10-03', '2025-10-04', '2025-10-05', '2025-10-06', '2025-10-07', # 国庆节
]
holidays = [datetime.strptime(date, '%Y-%m-%d').date() for date in holidays_2025]
trading_dates = []
current_date = last_date + timedelta(days=1)
while len(trading_dates) < pred_len:
# 排除周末和节假日
if current_date.weekday() < 5 and current_date.date() not in holidays:
trading_dates.append(current_date)
current_date += timedelta(days=1)
print(f"📅 生成的纯交易日: 共 {len(trading_dates)}")
if trading_dates:
print(f" 起始: {trading_dates[0].strftime('%Y-%m-%d')}")
print(f" 结束: {trading_dates[-1].strftime('%Y-%m-%d')}")
return trading_dates
def calculate_optimal_interval(min_val, max_val):
"""
计算最优的Y轴刻度间隔
"""
range_val = max_val - min_val
if range_val <= 0:
return 1.0
if range_val < 1:
interval = 0.1
elif range_val < 5:
interval = 0.5
elif range_val < 10:
interval = 1.0
elif range_val < 20:
interval = 2.0
elif range_val < 50:
interval = 5.0
elif range_val < 100:
interval = 10.0
elif range_val < 200:
interval = 20.0
elif range_val < 500:
interval = 50.0
else:
interval = 100.0
return interval
# ==================== 增强版市场因素分析器 ====================
class EnhancedMarketFactorAnalyzer:
"""增强版市场因素分析器 - 整合更多维度的市场因素"""
def __init__(self):
self.market_data = {}
self.sector_data = {}
self.macro_factors = {}
self.policy_factors = {}
def analyze_market_trend(self, index_codes=["000001", "399001"]):
"""
分析大盘趋势 - 多指数综合分析
"""
try:
print(f"📊 综合分析大盘趋势...")
market_analysis = {}
for index_code in index_codes:
index_name = "上证指数" if index_code == "000001" else "深证成指"
print(f" 分析{index_name}({index_code})...")
# 获取指数数据
index_df = ak.stock_zh_index_hist(symbol=index_code, period="daily")
if index_df is None or index_df.empty:
print(f" ❌ 无法获取{index_name}数据")
continue
# 重命名列
index_df = index_df.rename(columns={
'日期': 'date', '收盘': 'close', '开盘': 'open',
'最高': 'high', '最低': 'low', '成交量': 'volume'
})
index_df['date'] = pd.to_datetime(index_df['date'])
index_df = index_df.sort_values('date').reset_index(drop=True)
# 计算技术指标
index_df['ma5'] = index_df['close'].rolling(5).mean()
index_df['ma20'] = index_df['close'].rolling(20).mean()
index_df['ma60'] = index_df['close'].rolling(60).mean()
index_df['vol_ma5'] = index_df['volume'].rolling(5).mean()
# 技术分析
current_data = index_df.iloc[-1]
prev_data = index_df.iloc[-2]
# 均线多头排列判断
ma_condition = (current_data['ma5'] > current_data['ma20'] > current_data['ma60'])
# 价格站在20日均线以上
price_above_ma20 = current_data['close'] > current_data['ma20']
# 成交量配合
volume_condition = current_data['volume'] > current_data['vol_ma5'] * 0.8
# 趋势强度
trend_strength = self._calculate_trend_strength(index_df)
is_main_uptrend = ma_condition and price_above_ma20 and trend_strength > 0.6
market_analysis[index_name] = {
'is_main_uptrend': is_main_uptrend,
'trend_strength': trend_strength,
'current_close': current_data['close'],
'price_change_pct': ((current_data['close'] - prev_data['close']) / prev_data['close']) * 100,
'market_status': '主升浪' if is_main_uptrend else '震荡调整'
}
# 综合判断
if market_analysis:
avg_trend_strength = np.mean([data['trend_strength'] for data in market_analysis.values()])
uptrend_count = sum(1 for data in market_analysis.values() if data['is_main_uptrend'])
overall_uptrend = uptrend_count >= len(market_analysis) * 0.5
final_analysis = {
'overall_is_main_uptrend': overall_uptrend,
'overall_trend_strength': avg_trend_strength,
'detailed_analysis': market_analysis,
'market_status': '主升浪' if overall_uptrend else '震荡调整'
}
print(f"✅ 大盘分析完成: {final_analysis['market_status']}, 综合趋势强度: {avg_trend_strength:.2f}")
return final_analysis
return self._get_default_market_analysis()
except Exception as e:
print(f"❌ 大盘分析错误: {e}")
return self._get_default_market_analysis()
def analyze_sector_resonance(self, stock_code):
"""
分析板块共振效应 - 增强版行业分析
"""
try:
print(f"🔄 分析板块共振效应...")
# 获取股票所属行业和概念
industry = "未知"
concepts = []
try:
stock_info = ak.stock_individual_info_em(symbol=stock_code)
if not stock_info.empty and 'value' in stock_info.columns:
industry_row = stock_info[stock_info['item'] == '行业']
if not industry_row.empty:
industry = industry_row['value'].iloc[0]
except:
pass
# 热门板块和概念映射
hot_sectors = {
'机器人': {'momentum': 0.85, 'limit_up_stocks': 18, 'active': True,
'description': '人形机器人、工业自动化'},
'半导体': {'momentum': 0.8, 'limit_up_stocks': 15, 'active': True, 'description': '芯片国产替代'},
'人工智能': {'momentum': 0.75, 'limit_up_stocks': 12, 'active': True, 'description': 'AI大模型、算力'},
'低空经济': {'momentum': 0.7, 'limit_up_stocks': 10, 'active': True, 'description': '无人机、eVTOL'},
'新能源': {'momentum': 0.6, 'limit_up_stocks': 8, 'active': True, 'description': '光伏、储能'},
'医药': {'momentum': 0.5, 'limit_up_stocks': 5, 'active': False, 'description': '创新药'}
}
# 判断当前股票所属热门板块
matched_sectors = []
for sector, data in hot_sectors.items():
if (sector in industry or
(stock_code == '600580' and sector in ['机器人', '低空经济']) or # 卧龙电驱特殊处理
(stock_code == '300207' and sector in ['新能源'])):
matched_sectors.append({
'sector': sector,
'momentum': data['momentum'],
'limit_up_stocks': data['limit_up_stocks'],
'is_active': data['active'],
'description': data['description']
})
# 计算综合共振分数
if matched_sectors:
resonance_score = np.mean([sector['momentum'] for sector in matched_sectors])
is_sector_hot = any(sector['is_active'] for sector in matched_sectors)
main_sector = max(matched_sectors, key=lambda x: x['momentum'])
else:
resonance_score = 0.5
is_sector_hot = False
main_sector = {'sector': '传统行业', 'momentum': 0.5, 'description': '无热门概念'}
analysis = {
'industry': industry,
'matched_sectors': matched_sectors,
'main_sector': main_sector,
'is_sector_hot': is_sector_hot,
'resonance_score': resonance_score,
'sector_count': len(matched_sectors)
}
print(f"✅ 板块分析完成: {industry}, 匹配{len(matched_sectors)}个热门板块, 共振分数: {resonance_score:.2f}")
return analysis
except Exception as e:
print(f"❌ 板块分析错误: {e}")
return self._get_default_sector_analysis()
def analyze_macro_factors(self):
"""
分析宏观因素 - 结合国内外政策
"""
try:
print(f"🌍 分析宏观因素...")
# 美国降息周期分析 - 基于最新信息
us_rate_analysis = {
'current_rate': 4.25, # 联邦基金利率目标区间4.00%-4.25%
'trend': '降息周期',
'recent_cut': '2025年9月降息25个基点',
'expected_cuts_2025': 2, # 市场预期2025年还有两次降息
'expected_cuts_2026': 2,
'impact_on_emerging_markets': 'positive',
'usd_index_support': 95.0, # 美元指数短期支撑位
'analysis': '美联储开启宽松周期,利好全球流动性'
}
# 国内政策因素 - 基于最新政策
domestic_policy = {
'monetary_policy': '稳健偏松',
'fiscal_policy': '积极财政',
'market_liquidity': '合理充裕',
'industrial_policy': '设备更新、以旧换新', # 大规模设备更新政策
'employment_policy': '稳就业政策加力', # 国务院稳就业政策
'analysis': '政策组合拳发力,经济稳中向好'
}
# 行业政策支持
industry_policy = {
'robot_policy': '机器人产业政策支持',
'chip_policy': '国产替代加速推进',
'AI_policy': '人工智能发展规划',
'low_altitude': '低空经济发展规划'
}
macro_analysis = {
'us_rate_cycle': us_rate_analysis,
'domestic_policy': domestic_policy,
'industry_policy': industry_policy,
'global_liquidity_outlook': '改善',
'overall_macro_score': 0.75 # 宏观环境整体偏积极
}
print(
f"✅ 宏观分析完成: 美国{us_rate_analysis['trend']}, 国内政策积极, 宏观评分: {macro_analysis['overall_macro_score']:.2f}")
return macro_analysis
except Exception as e:
print(f"❌ 宏观分析错误: {e}")
return self._get_default_macro_analysis()
def analyze_company_fundamentals(self, stock_code):
"""
分析公司基本面 - 针对特定股票
"""
try:
print(f"🏢 分析公司基本面...")
# 卧龙电驱特殊分析
if stock_code == '600580':
fundamentals = {
'company_name': '卧龙电驱',
'business_areas': ['工业电机', '机器人关键部件', '航空电机', '新能源汽车驱动'],
'recent_developments': [
'与智元机器人实现双向持股,推进具身智能机器人技术研发',
'成立浙江龙飞电驱,专注航空电机业务',
'发布AI外骨骼机器人及灵巧手',
'布局高爆发关节模组、伺服驱动器等人形机器人关键部件'
],
'growth_drivers': [
'设备更新政策推动工业电机需求',
'机器人产业快速发展',
'低空经济政策支持',
'出海战略加速'
],
'risk_factors': [
'机器人业务营收占比仅2.71%,占比较低',
'工业需求景气度波动',
'原料价格波动风险'
],
'investment_rating': '积极关注',
'fundamental_score': 0.7
}
else:
# 其他股票的基础分析
fundamentals = {
'company_name': '未知',
'business_areas': [],
'recent_developments': [],
'growth_drivers': [],
'risk_factors': [],
'investment_rating': '中性',
'fundamental_score': 0.5
}
print(f"✅ 基本面分析完成: {fundamentals['company_name']}, 评分: {fundamentals['fundamental_score']:.2f}")
return fundamentals
except Exception as e:
print(f"❌ 基本面分析错误: {e}")
return self._get_default_fundamental_analysis()
def _calculate_trend_strength(self, df):
"""计算趋势强度"""
if len(df) < 20:
return 0.5
ma_slope = (df['ma5'].iloc[-1] - df['ma5'].iloc[-20]) / df['ma5'].iloc[-20]
price_slope = (df['close'].iloc[-1] - df['close'].iloc[-20]) / df['close'].iloc[-20]
volume_trend = df['volume'].iloc[-5:].mean() / df['volume'].iloc[-10:-5].mean()
strength = (ma_slope * 0.4 + price_slope * 0.4 + min(volume_trend - 1, 0.2) * 0.2)
return max(0, min(1, strength * 10))
def _get_default_market_analysis(self):
return {
'overall_is_main_uptrend': False,
'overall_trend_strength': 0.5,
'market_status': '未知',
'detailed_analysis': {}
}
def _get_default_sector_analysis(self):
return {
'industry': '未知',
'matched_sectors': [],
'main_sector': {'sector': '未知', 'momentum': 0.5, 'description': ''},
'is_sector_hot': False,
'resonance_score': 0.5,
'sector_count': 0
}
def _get_default_macro_analysis(self):
return {
'us_rate_cycle': {'trend': '未知', 'expected_cuts_2025': 0},
'domestic_policy': {'monetary_policy': '中性'},
'overall_macro_score': 0.5
}
def _get_default_fundamental_analysis(self):
return {
'company_name': '未知',
'business_areas': [],
'recent_developments': [],
'growth_drivers': [],
'risk_factors': [],
'investment_rating': '中性',
'fundamental_score': 0.5
}
# ==================== 优化的预测平滑函数 ====================
def smooth_prediction_results(prediction_df, historical_df, smooth_factor=0.3):
"""
🎯 优化预测结果的平滑处理,避免剧烈波动
"""
print("🔄 应用预测结果平滑处理...")
smoothed_df = prediction_df.copy()
# 获取历史数据的趋势
recent_trend = calculate_recent_trend(historical_df)
# 对价格序列进行平滑
price_columns = ['close', 'open', 'high', 'low']
for col in price_columns:
if col in smoothed_df.columns:
original_values = smoothed_df[col].values
# 应用移动平均平滑
window_size = max(3, min(7, len(original_values) // 5))
smoothed_values = pd.Series(original_values).rolling(
window=window_size, center=True, min_periods=1
).mean()
# 结合历史趋势进行微调
trend_adjusted = smoothed_values * (1 + recent_trend * smooth_factor)
smoothed_df[col] = trend_adjusted.values
# 对成交量进行合理调整
if 'volume' in smoothed_df.columns:
hist_volume_mean = historical_df['volume'].tail(20).mean()
current_volume = smoothed_df['volume'].values
# 保持成交量在合理范围内
volume_factor = 0.8 + 0.4 * np.random.random(len(current_volume))
adjusted_volume = current_volume * volume_factor
# 确保成交量不会异常波动
volume_std = historical_df['volume'].tail(50).std()
volume_min = hist_volume_mean * 0.3
volume_max = hist_volume_mean * 3.0
smoothed_df['volume'] = np.clip(adjusted_volume, volume_min, volume_max)
print("✅ 预测结果平滑完成")
return smoothed_df
def calculate_recent_trend(historical_df, lookback_days=20):
"""
计算近期价格趋势
"""
if len(historical_df) < lookback_days:
lookback_days = len(historical_df)
recent_prices = historical_df['close'].tail(lookback_days).values
if len(recent_prices) < 2:
return 0
# 计算线性回归斜率作为趋势
x = np.arange(len(recent_prices))
slope = np.polyfit(x, recent_prices, 1)[0]
# 归一化为趋势强度 (-1 到 1)
price_range = np.ptp(recent_prices)
if price_range > 0:
trend_strength = slope / price_range * len(recent_prices)
else:
trend_strength = 0
return np.clip(trend_strength, -0.1, 0.1) # 限制趋势强度
def apply_post_holiday_adjustment(prediction_df, future_dates, holiday_periods):
"""
🎯 修复版:应用节后调整,避免国庆后异常下跌
"""
print("🔄 应用节后日历效应调整...")
adjusted_df = prediction_df.copy()
for holiday in holiday_periods:
holiday_start = pd.Timestamp(holiday['start'])
holiday_end = pd.Timestamp(holiday['end'])
adjustment_days = holiday['adjustment_days']
effect_strength = holiday['effect_strength']
# 计算调整期结束日期
adjustment_end = holiday_end + timedelta(days=adjustment_days)
# 找到在节后调整期内的日期索引
post_holiday_indices = []
for i, date in enumerate(future_dates):
if holiday_end <= date < adjustment_end:
post_holiday_indices.append(i)
# 应用节后效应调整
if post_holiday_indices:
for col in ['close', 'open', 'high', 'low']:
if col in adjusted_df.columns:
for idx in post_holiday_indices:
adjusted_df.iloc[idx][col] = adjusted_df.iloc[idx][col] * (1 + effect_strength)
print("✅ 节后调整完成")
return adjusted_df
# ==================== 价格合理性检查函数 ====================
def validate_prediction_results(historical_df, prediction_df, max_price_change=0.3):
"""
🎯 验证预测结果的合理性,避免异常价格波动
"""
print("🔍 验证预测结果合理性...")
validated_df = prediction_df.copy()
current_price = historical_df['close'].iloc[-1]
# 检查价格列的合理性
price_columns = ['close', 'open', 'high', 'low']
for col in price_columns:
if col in validated_df.columns:
# 计算最大允许的价格变化范围
max_allowed_change = current_price * max_price_change
# 检查每个预测价格
for i in range(len(validated_df)):
predicted_price = validated_df[col].iloc[i]
# 如果预测价格超出合理范围,进行修正
if abs(predicted_price - current_price) > max_allowed_change:
# 基于历史波动率进行修正
correction_factor = 0.8 + 0.4 * np.random.random()
corrected_price = current_price * (1 + (predicted_price / current_price - 1) * correction_factor)
validated_df.iloc[i][col] = corrected_price
print(f"⚠️ 修正异常{col}价格: {predicted_price:.2f} -> {corrected_price:.2f}")
print("✅ 预测结果验证完成")
return validated_df
# ==================== GUI版本预测函数 ====================
def run_comprehensive_prediction_gui(stock_code, stock_name, data_dir, pred_days, output_dir, history_years=1,
progress_callback=None, result_callback=None):
"""
GUI版本的预测函数
"""
def update_progress(message):
if progress_callback:
progress_callback(message)
print(message)
def update_result(message):
if result_callback:
result_callback(message)
print(message)
try:
# 初始化市场分析器
market_analyzer = EnhancedMarketFactorAnalyzer()
update_progress(f"🎯 开始 {stock_name}({stock_code}) 预测流程")
update_progress("=" * 50)
# 1. 获取数据
update_progress("\n步骤1: 获取股票数据...")
success, csv_file_path = get_stock_data(stock_code, data_dir)
if not success:
update_result("❌ 无法获取股票数据,预测终止")
return False, "无法获取股票数据"
# 2. 加载模型和分词器
update_progress("\n步骤2: 加载Kronos模型和分词器...")
try:
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
update_progress("✅ 模型加载完成 - 使用Kronos-base模型")
except Exception as e:
error_msg = f"❌ 模型加载失败: {e}"
update_result(error_msg)
update_progress("⚠️ 预测功能不可用,请检查模型安装")
return False, error_msg
# 3. 实例化预测器
update_progress("步骤3: 初始化预测器...")
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
update_progress("✅ 预测器初始化完成")
# 4. 准备数据
update_progress("步骤4: 准备股票数据...")
df = prepare_stock_data(csv_file_path, stock_code, history_years)
# 5. 计算预测参数
update_progress("步骤5: 计算预测参数...")
lookback, pred_len = calculate_prediction_parameters(df, target_days=pred_days)
if pred_len <= 0:
update_result("❌ 数据量不足,无法进行预测")
return False, "数据量不足"
update_progress(f"✅ 最终参数 - 回看期: {lookback}, 预测期: {pred_len}")
# 6. 准备输入数据
update_progress("步骤6: 准备输入数据...")
x_df = df.loc[-lookback:, ['open', 'high', 'low', 'close', 'volume', 'amount']].reset_index(drop=True)
x_timestamp = df.loc[-lookback:, 'timestamps'].reset_index(drop=True)
# 生成未来日期 - 🎯 修复:只生成交易日
last_historical_date = df['timestamps'].iloc[-1]
future_dates = generate_trading_dates_only(last_historical_date, pred_len)
if len(future_dates) < pred_len:
update_progress(f"⚠️ 警告:只生成了 {len(future_dates)} 个交易日,少于请求的 {pred_len}")
pred_len = len(future_dates)
update_progress(f"输入数据形状: {x_df.shape}")
update_progress(f"历史数据时间范围: {x_timestamp.iloc[0]}{x_timestamp.iloc[-1]}")
if future_dates:
update_progress(f"预测时间范围: {future_dates[0]}{future_dates[-1]}")
# 7. 执行基础预测
update_progress("步骤7: 执行基础价格预测...")
pred_df = predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=pd.Series(future_dates),
pred_len=pred_len,
T=1.0,
top_p=0.9,
sample_count=1,
verbose=True
)
update_progress("✅ 基础预测完成")
# 🎯 新增:对基础预测进行合理性检查
update_progress("步骤7.2: 验证预测结果合理性...")
historical_df_for_validation = df.loc[-lookback:].reset_index(drop=True)
validated_pred_df = validate_prediction_results(historical_df_for_validation, pred_df)
# 🎯 新增:对基础预测进行平滑处理
update_progress("步骤7.5: 对预测结果进行平滑优化...")
smoothed_pred_df = smooth_prediction_results(validated_pred_df, historical_df_for_validation)
# 🎯 修复:应用节后调整(特别是国庆节后)
holiday_periods = [
{
'start': '2025-10-01',
'end': '2025-10-09', # 国庆后第一个交易日10月9日周四
'adjustment_days': 5,
'effect_strength': 0.03 # 节后通常有正面效应
}
]
adjusted_pred_df = apply_post_holiday_adjustment(smoothed_pred_df, future_dates, holiday_periods)
# 8. 使用多维度市场因素增强预测
update_progress("步骤8: 应用多维度市场因素增强预测...")
enhanced_pred_df, enhancement_info = enhance_prediction_with_market_factors(
df.loc[-lookback:].reset_index(drop=True),
adjusted_pred_df, # 使用平滑调整后的预测结果
stock_code,
market_analyzer
)
# 将增强预测结果添加到信息中
enhancement_info['enhanced_prediction'] = enhanced_pred_df
# 9. 创建综合市场分析报告
update_progress("步骤9: 创建市场分析报告...")
market_report = create_comprehensive_market_report(enhancement_info, output_dir, stock_code)
# 10. 生成预测图表
update_progress("步骤10: 生成预测图表...")
historical_df = df.loc[-lookback:].reset_index(drop=True)
chart_path = plot_optimized_prediction_gui(
historical_df, adjusted_pred_df, enhanced_pred_df, future_dates,
stock_code, stock_name, output_dir, enhancement_info
)
# 11. 生成预测报告
update_progress("步骤11: 生成预测报告...")
if len(enhanced_pred_df) > 0:
current_price = historical_df['close'].iloc[-1]
base_predicted_price = adjusted_pred_df['close'].iloc[-1] if len(adjusted_pred_df) > 0 else current_price
enhanced_predicted_price = enhanced_pred_df['close'].iloc[-1]
base_change_pct = (base_predicted_price / current_price - 1) * 100
enhanced_change_pct = (enhanced_predicted_price / current_price - 1) * 100
# 输出预测结果
update_result(f"\n📈 {stock_name}({stock_code}) 预测报告")
update_result("=" * 50)
update_result(f"当前价格: {current_price:.2f}")
update_result(f"平滑预测价格: {base_predicted_price:.2f} 元 ({base_change_pct:+.2f}%)")
update_result(f"增强预测价格: {enhanced_predicted_price:.2f} 元 ({enhanced_change_pct:+.2f}%)")
update_result(f"市场因素调整因子: {enhancement_info['adjustment_factor']:.4f}")
update_result(f"大盘状态: {enhancement_info['market_analysis']['market_status']}")
update_result(f"板块共振: {enhancement_info['sector_analysis']['main_sector']['sector']}")
update_result(f"宏观环境: 美国{enhancement_info['macro_analysis']['us_rate_cycle']['trend']}")
update_result(f"公司评级: {enhancement_info['fundamental_analysis']['investment_rating']}")
# 保存详细预测数据
prediction_details = pd.DataFrame({
'日期': future_dates,
'平滑预测收盘价': adjusted_pred_df['close'].values if len(
adjusted_pred_df) > 0 else [current_price] * len(future_dates),
'增强预测收盘价': enhanced_pred_df['close'].values,
'预测成交量': enhanced_pred_df['volume'].values
})
prediction_file = os.path.join(output_dir, f'{stock_code}_comprehensive_predictions.csv')
prediction_details.to_csv(prediction_file, index=False, encoding='utf-8-sig')
update_progress(f"💾 详细预测数据已保存: {prediction_file}")
update_progress(f"\n🎉 {stock_name}({stock_code}) 预测完成!")
update_progress(f"📊 预测图表: {chart_path}")
return True, "预测完成"
except Exception as e:
error_msg = f"❌ 预测过程中出现错误: {e}"
update_result(error_msg)
import traceback
traceback.print_exc()
return False, error_msg
def enhance_prediction_with_market_factors(historical_df, prediction_df, stock_code, market_analyzer):
"""
使用市场因素增强预测结果
"""
print("\n🎯 使用市场因素增强预测...")
# 获取各类市场分析
market_analysis = market_analyzer.analyze_market_trend()
sector_analysis = market_analyzer.analyze_sector_resonance(stock_code)
macro_analysis = market_analyzer.analyze_macro_factors()
fundamental_analysis = market_analyzer.analyze_company_fundamentals(stock_code)
# 计算综合调整因子
adjustment_factor = calculate_enhanced_adjustment_factor(
market_analysis, sector_analysis, macro_analysis, fundamental_analysis
)
print(f"📈 综合调整因子: {adjustment_factor:.4f}")
# 应用调整到预测结果
enhanced_prediction = prediction_df.copy()
# 对价格预测进行调整
price_columns = ['close', 'open', 'high', 'low']
for col in price_columns:
if col in enhanced_prediction.columns:
enhanced_prediction[col] = enhanced_prediction[col] * adjustment_factor
# 对成交量进行调整
if 'volume' in enhanced_prediction.columns:
volume_adjustment = 1 + (adjustment_factor - 1) * 0.3
enhanced_prediction['volume'] = enhanced_prediction['volume'] * volume_adjustment
return enhanced_prediction, {
'market_analysis': market_analysis,
'sector_analysis': sector_analysis,
'macro_analysis': macro_analysis,
'fundamental_analysis': fundamental_analysis,
'adjustment_factor': adjustment_factor
}
def calculate_enhanced_adjustment_factor(market_analysis, sector_analysis, macro_analysis, fundamental_analysis):
"""
计算基于多维度市场因素的调整因子
"""
base_factor = 1.0
# 1. 大盘趋势影响 (权重25%)
if market_analysis['overall_is_main_uptrend']:
trend_strength = market_analysis['overall_trend_strength']
base_factor *= (1 + trend_strength * 0.08)
else:
trend_strength = market_analysis['overall_trend_strength']
base_factor *= (1 + (trend_strength - 0.5) * 0.04)
# 2. 板块共振影响 (权重25%)
resonance_score = sector_analysis['resonance_score']
sector_count = sector_analysis['sector_count']
if sector_analysis['is_sector_hot']:
base_factor *= (1 + resonance_score * 0.06 + min(sector_count * 0.01, 0.03))
else:
base_factor *= (1 + (resonance_score - 0.5) * 0.02)
# 3. 宏观因素影响 (权重20%)
macro_score = macro_analysis['overall_macro_score']
base_factor *= (1 + (macro_score - 0.5) * 0.06)
# 4. 美国降息周期特殊影响 (权重10%)
us_rate_trend = macro_analysis['us_rate_cycle']['trend']
if us_rate_trend == '降息周期':
expected_cuts = macro_analysis['us_rate_cycle']['expected_cuts_2025']
base_factor *= (1 + expected_cuts * 0.015)
# 5. 公司基本面影响 (权重20%)
fundamental_score = fundamental_analysis['fundamental_score']
base_factor *= (1 + (fundamental_score - 0.5) * 0.08)
# 🎯 限制调整幅度在更合理范围内 (0.9 ~ 1.1),避免过度调整
return max(0.9, min(1.1, base_factor))
def create_comprehensive_market_report(enhancement_info, output_dir, stock_code):
"""
创建综合市场分析报告
"""
report = {
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'stock_code': stock_code,
'market_analysis': enhancement_info['market_analysis'],
'sector_analysis': enhancement_info['sector_analysis'],
'macro_analysis': enhancement_info['macro_analysis'],
'fundamental_analysis': enhancement_info['fundamental_analysis'],
'adjustment_factor': enhancement_info['adjustment_factor']
}
# 保存报告
report_file = os.path.join(output_dir, f'{stock_code}_comprehensive_analysis_report.json')
with open(report_file, 'w', encoding='utf-8') as f:
json.dump(report, f, ensure_ascii=False, indent=2)
print(f"📋 综合分析报告已保存: {report_file}")
return report
def plot_optimized_prediction_gui(historical_df, base_pred_df, enhanced_pred_df, future_trading_dates,
stock_code, stock_name, output_dir, enhancement_info=None):
"""
🎯 优化版:清晰显示每个交易日的预测图表
"""
ensure_output_directory(output_dir)
# 设置配色
colors = {
'historical': '#1f77b4',
'prediction': '#ff7f0e',
'enhanced': '#2ca02c',
'background': '#f8f9fa',
'grid': '#e9ecef'
}
# 创建图表
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle(f'{stock_name}({stock_code}) - 优化版交易日预测图表', fontsize=16, fontweight='bold')
# 设置背景色
fig.patch.set_facecolor('white')
for ax in [ax1, ax2, ax3, ax4]:
ax.set_facecolor(colors['background'])
# 🎯 优化1: 使用实际日期作为x轴但只显示交易日
all_dates = list(historical_df['timestamps']) + future_trading_dates
# 1. 主价格图表
current_price = historical_df['close'].iloc[-1]
# 绘制历史价格
ax1.plot(historical_df['timestamps'], historical_df['close'],
color=colors['historical'], linewidth=2.5, label='历史价格')
# 绘制预测价格
if len(future_trading_dates) > 0:
# 绘制基础预测
ax1.plot(future_trading_dates, base_pred_df['close'],
color=colors['prediction'], linewidth=2, label='平滑预测', linestyle='--')
# 绘制增强预测
ax1.plot(future_trading_dates, enhanced_pred_df['close'],
color=colors['enhanced'], linewidth=2.5, label='增强预测')
# 🎯 修复:使用更安全的关键日期标记
mark_key_dates_safe(ax1, future_trading_dates, enhanced_pred_df)
ax1.set_ylabel('收盘价 (元)', fontsize=12, fontweight='bold')
ax1.legend(loc='upper left', fontsize=10)
ax1.grid(True, color=colors['grid'], alpha=0.7)
ax1.set_title(f'价格走势预测 - 当前价: {current_price:.2f}', fontweight='bold', fontsize=13)
# 🎯 优化2: 使用每周标记,避免过于密集
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
ax1.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO, interval=2)) # 每两周一个标记
plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, fontsize=9)
# 2. 成交量图表
ax2.bar(historical_df['timestamps'], historical_df['volume'],
alpha=0.6, color=colors['historical'], label='历史成交量')
if len(future_trading_dates) > 0:
ax2.bar(future_trading_dates, enhanced_pred_df['volume'],
alpha=0.6, color=colors['enhanced'], label='预测成交量')
ax2.set_ylabel('成交量', fontsize=12, fontweight='bold')
ax2.legend(loc='upper left', fontsize=10)
ax2.grid(True, color=colors['grid'], alpha=0.7)
ax2.set_title('成交量预测', fontweight='bold', fontsize=13)
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
ax2.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO, interval=2))
plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45, fontsize=9)
# 3. 价格变化率图表
ax3.plot(historical_df['timestamps'], historical_df['close'].pct_change() * 100,
color=colors['historical'], linewidth=1.5, label='历史涨跌幅', alpha=0.7)
if len(future_trading_dates) > 0:
pred_returns = enhanced_pred_df['close'].pct_change() * 100
ax3.plot(future_trading_dates, pred_returns,
color=colors['enhanced'], linewidth=2, label='预测涨跌幅')
# 添加零线参考
ax3.axhline(y=0, color='red', linestyle='-', alpha=0.3, linewidth=1)
ax3.set_ylabel('日涨跌幅 (%)', fontsize=12, fontweight='bold')
ax3.legend(loc='upper left', fontsize=10)
ax3.grid(True, color=colors['grid'], alpha=0.7)
ax3.set_title('价格变化率分析', fontweight='bold', fontsize=13)
ax3.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
ax3.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO, interval=2))
plt.setp(ax3.xaxis.get_majorticklabels(), rotation=45, fontsize=9)
# 4. 市场因素分析
if enhancement_info:
factors = ['大盘趋势', '板块共振', '宏观环境', '美国降息', '基本面']
scores = [
enhancement_info['market_analysis']['overall_trend_strength'],
enhancement_info['sector_analysis']['resonance_score'],
enhancement_info['macro_analysis']['overall_macro_score'],
0.7 if enhancement_info['macro_analysis']['us_rate_cycle']['trend'] == '降息周期' else 0.3,
enhancement_info['fundamental_analysis']['fundamental_score']
]
colors_bars = [colors['historical'], colors['prediction'], colors['enhanced'], '#f39c12', '#9b59b6']
bars = ax4.bar(factors, scores, color=colors_bars, alpha=0.8, edgecolor='black', linewidth=1)
ax4.set_ylim(0, 1)
ax4.set_ylabel('评分', fontsize=12, fontweight='bold')
ax4.set_title('市场因素评分分析', fontweight='bold', fontsize=13)
ax4.grid(True, alpha=0.3, axis='y')
# 在柱状图上显示具体数值
for i, (bar, score) in enumerate(zip(bars, scores)):
height = bar.get_height()
ax4.text(bar.get_x() + bar.get_width() / 2., height + 0.02,
f'{score:.2f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
# 添加平均线
avg_score = np.mean(scores)
ax4.axhline(y=avg_score, color='red', linestyle='--', alpha=0.7,
label=f'平均分: {avg_score:.2f}')
ax4.legend(loc='upper right', fontsize=9)
plt.tight_layout()
# 保存图片
chart_filename = os.path.join(output_dir, f'{stock_code}_optimized_prediction.png')
plt.savefig(chart_filename, dpi=300, bbox_inches='tight', facecolor='white')
plt.close()
print(f"📊 优化版预测图表已保存: {chart_filename}")
return chart_filename
def mark_key_dates_safe(ax, future_dates, pred_df):
"""
🎯 安全版:标记关键日期和价格点,避免类型错误
"""
if len(future_dates) == 0 or len(pred_df) == 0:
return
try:
# 重置索引确保使用整数索引
pred_df_reset = pred_df.reset_index(drop=True)
# 获取最高点和最低点的整数索引
if hasattr(pred_df_reset['close'], 'idxmax'):
max_idx = pred_df_reset['close'].idxmax()
min_idx = pred_df_reset['close'].idxmin()
else:
# 备用方法
max_idx = np.argmax(pred_df_reset['close'].values)
min_idx = np.argmin(pred_df_reset['close'].values)
# 确保索引在有效范围内
max_idx = min(int(max_idx), len(future_dates) - 1)
min_idx = min(int(min_idx), len(future_dates) - 1)
# 标记最高点
if 0 <= max_idx < len(future_dates):
max_price = pred_df_reset['close'].iloc[max_idx]
ax.plot(future_dates[max_idx], max_price,
'v', color='red', markersize=8, label=f'最高点: {max_price:.2f}')
# 标记最低点
if 0 <= min_idx < len(future_dates):
min_price = pred_df_reset['close'].iloc[min_idx]
ax.plot(future_dates[min_idx], min_price,
'^', color='green', markersize=8, label=f'最低点: {min_price:.2f}')
# 标记预测结束点
if len(future_dates) > 0:
final_price = pred_df_reset['close'].iloc[-1]
ax.plot(future_dates[-1], final_price,
's', color='blue', markersize=6, label=f'最终预测: {final_price:.2f}')
except Exception as e:
print(f"⚠️ 标记关键日期时出现错误: {e}")
# 如果出错,跳过标记但不影响整体流程
# ==================== 主函数 ====================
def main():
"""主函数启动GUI界面"""
root = tk.Tk()
app = StockPredictorGUI(root)
root.mainloop()
if __name__ == "__main__":
main()