Files
Kronos/examples/run_backtest_kronos.py
billconan2017 cf9744a46b 修改内容
2025-10-10 16:32:17 +08:00

455 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# run_backtest_kronos.py - backtest Kronos model predictions against historical prices
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime, timedelta
import warnings
# Suppress every warning category for the whole process. Note that this also
# hides pandas/matplotlib deprecation notices raised by the code below.
warnings.filterwarnings('ignore')
# Use the SimHei font so the Chinese labels used in the charts can be rendered
# (assumes SimHei is installed on the machine - TODO confirm on non-Windows hosts).
plt.rcParams['font.sans-serif'] = ['SimHei']
# Draw the minus sign as an ASCII hyphen; presumably because the CJK font above
# lacks the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
class KronosBacktester:
    """Backtest a threshold strategy driven by Kronos model price predictions.

    The workflow is: load historical prices and predicted prices from CSV
    files, derive a target direction (+1 long, -1 short, 0 flat) from the
    predicted day-over-day return, simulate all-in trades on that target,
    then compute summary metrics and draw a chart.
    """

    def __init__(self, data_dir, model_dir, initial_capital=100000):
        """Store the configuration of the backtester.

        Args:
            data_dir: Directory holding ``<stock_code>_stock_data.csv``.
            model_dir: Directory holding the model prediction CSV files.
            initial_capital: Starting cash of the simulated account.
        """
        self.data_dir = data_dir
        self.model_dir = model_dir
        self.initial_capital = initial_capital
        self.results = {}

    @staticmethod
    def _standardize_columns(df, column_mapping):
        """Rename columns per ``column_mapping`` without creating duplicates.

        Mappings are applied in dictionary order. A source column is renamed
        only if the target name is not already present, so when several
        source names map to one target the first match wins and later ones
        keep their original name.
        """
        present = set(df.columns)
        renames = {}
        for source_name, target_name in column_mapping.items():
            if source_name in present and target_name not in present:
                renames[source_name] = target_name
                present.add(target_name)
        return df.rename(columns=renames)

    def load_historical_data(self, stock_code):
        """Load the historical price CSV for ``stock_code``.

        Chinese column headers are mapped to English names, and the frame is
        indexed by the parsed ``date`` column in ascending order.

        Returns:
            DataFrame indexed by date.

        Raises:
            FileNotFoundError: If the CSV file does not exist.
            KeyError: If no date column is present after renaming.
        """
        csv_file = os.path.join(self.data_dir, f"{stock_code}_stock_data.csv")
        if not os.path.exists(csv_file):
            raise FileNotFoundError(f"数据文件不存在: {csv_file}")

        df = pd.read_csv(csv_file, encoding='utf-8-sig')
        column_mapping = {
            '日期': 'date',
            '开盘价': 'open',
            '最高价': 'high',
            '最低价': 'low',
            '收盘价': 'close',
            '成交量': 'volume',
            '成交额': 'amount',
        }
        df = self._standardize_columns(df, column_mapping)

        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date').sort_index()

        print(f"✅ 加载历史数据: {len(df)} 条记录")
        print(f"时间范围: {df.index.min()} ~ {df.index.max()}")
        return df

    def load_predictions(self, stock_code):
        """Load the model prediction CSV for ``stock_code``.

        Several file-name conventions are tried in order and the first
        existing file is used.

        Returns:
            DataFrame indexed by date, sorted ascending.

        Raises:
            FileNotFoundError: If none of the candidate files exists.
            KeyError: If no date column is present after renaming.
        """
        candidate_files = [
            os.path.join(self.model_dir, f"{stock_code}_kronos_predictions.csv"),
            os.path.join(self.model_dir, f"{stock_code}_detailed_predictions.csv"),
            os.path.join(self.model_dir, f"{stock_code}_predictions.csv"),
        ]
        pred_df = None
        for pred_file in candidate_files:
            if os.path.exists(pred_file):
                pred_df = pd.read_csv(pred_file, encoding='utf-8-sig')
                print(f"✅ 找到预测文件: {pred_file}")
                break
        if pred_df is None:
            raise FileNotFoundError(f"未找到预测文件,请检查目录: {self.model_dir}")

        # The explicit "predicted ..." headers come first so that, when a file
        # carries both them and the plain headers, the explicit one is used
        # and no duplicate column name is produced.
        column_mapping = {
            '日期': 'date',
            '预测收盘价': 'predicted_close',
            '收盘价': 'predicted_close',
            '预测成交量': 'predicted_volume',
            '成交量': 'predicted_volume',
        }
        pred_df = self._standardize_columns(pred_df, column_mapping)

        pred_df['date'] = pd.to_datetime(pred_df['date'])
        pred_df = pred_df.set_index('date').sort_index()

        print(f"✅ 加载预测数据: {len(pred_df)} 条记录")
        print(f"预测时间范围: {pred_df.index.min()} ~ {pred_df.index.max()}")
        return pred_df

    def align_data(self, hist_df, pred_df):
        """Restrict predictions to dates after the last historical date.

        If no prediction lies after the historical range, all predictions
        are returned (as a copy) and a warning is printed.
        """
        last_hist_date = hist_df.index.max()
        pred_df_aligned = pred_df[pred_df.index > last_hist_date]
        if len(pred_df_aligned) == 0:
            pred_df_aligned = pred_df.copy()
            print("⚠️ 警告:预测数据没有未来的日期,使用所有预测数据")
        print(f"✅ 数据对齐: 历史数据结束于 {last_hist_date}, 预测数据从 {pred_df_aligned.index.min()} 开始")
        return pred_df_aligned

    def calculate_trading_signals(self, hist_df, pred_df, threshold=0.02):
        """Build the combined price frame and derive trading signals.

        Args:
            hist_df: Historical data with a ``close`` column.
            pred_df: Prediction data with a ``predicted_close`` column.
            threshold: Minimum absolute predicted return that triggers a
                signal.

        Returns:
            DataFrame with columns ``actual``, ``predicted``, ``pred_return``,
            ``signal`` (+1 / -1 / 0 per row) and ``position`` (the last
            non-zero signal carried forward, 0 before the first signal).
        """
        pred_df = self.align_data(hist_df, pred_df)

        combined = pd.concat(
            [
                hist_df[['close']].rename(columns={'close': 'actual'}),
                pred_df[['predicted_close']].rename(columns={'predicted_close': 'predicted'}),
            ],
            axis=1,
        )

        # Forward-fill explicitly and disable the implicit fill so the result
        # does not depend on the pandas version's pct_change default.
        combined['pred_return'] = combined['predicted'].ffill().pct_change(fill_method=None)

        pred_return = combined['pred_return']
        combined['signal'] = np.where(
            pred_return > threshold, 1,
            np.where(pred_return < -threshold, -1, 0),
        )

        # Hold the most recent non-zero signal to avoid flipping in and out on
        # every neutral day. Series.replace(method='ffill') is no longer
        # available in current pandas, hence mask + ffill.
        signal = combined['signal']
        combined['position'] = signal.mask(signal == 0).ffill().fillna(0).astype(int)
        return combined

    def run_backtest(self, combined_df):
        """Simulate all-in trading on the ``position`` column.

        Rules, applied from the second row onwards:
          * If a position is held and the target direction differs from it
            (flat or opposite), the position is closed at the current price.
          * If flat and the target is non-zero, as many whole shares as cash
            allows are opened in the target direction. After a close, the new
            position is therefore opened on a later row, not the same one.

        Cash accounting is signed: opening a position subtracts
        ``position * price`` from cash and closing adds it back, so a short
        position credits cash on open and debits it on close.

        Rows without a usable price (NaN or non-positive) carry the previous
        portfolio value and position forward.

        Args:
            combined_df: Frame with ``actual``, ``predicted`` and ``position``.

        Returns:
            Tuple ``(backtest_results, trades)``. ``backtest_results`` has
            columns ``capital`` (total portfolio value), ``position`` (signed
            share count), ``returns`` and ``price``. ``trades`` is a list of
            dicts with keys ``date``, ``action`` ('BUY' opens, 'SELL'
            closes), ``price``, ``shares`` (signed) and ``capital`` (cash
            after the trade).
        """
        cash = float(self.initial_capital)
        position = 0
        trades = []

        prices = combined_df['actual'].combine_first(combined_df['predicted'])

        capital_curve = []
        position_curve = []
        return_curve = []
        prev_value = cash

        rows = zip(combined_df.index, prices, combined_df['position'])
        for i, (date, price, signal) in enumerate(rows):
            if pd.isna(price) or price <= 0:
                capital_curve.append(prev_value)
                position_curve.append(position)
                return_curve.append(0.0)
                continue

            if i > 0:
                target = 0 if pd.isna(signal) else int(np.sign(signal))
                held = int(np.sign(position))

                if held != 0 and target != held:
                    # Close: proceeds are added to the cash left over from
                    # the opening trade.
                    cash += position * price
                    trades.append({
                        'date': date,
                        'action': 'SELL',
                        'price': price,
                        'shares': position,
                        'capital': cash,
                    })
                    position = 0
                elif held == 0 and target != 0:
                    shares = int(cash / price)
                    if shares > 0:
                        position = shares * target
                        cash -= position * price
                        trades.append({
                            'date': date,
                            'action': 'BUY',
                            'price': price,
                            'shares': position,
                            'capital': cash,
                        })

            portfolio_value = cash + position * price
            if i > 0 and prev_value > 0:
                daily_return = (portfolio_value - prev_value) / prev_value
            else:
                daily_return = 0.0

            capital_curve.append(portfolio_value)
            position_curve.append(position)
            return_curve.append(daily_return)
            prev_value = portfolio_value

        backtest_results = pd.DataFrame(
            {
                'capital': capital_curve,
                'position': position_curve,
                'returns': return_curve,
                'price': prices.to_numpy(),
            },
            index=combined_df.index,
        )
        return backtest_results, trades

    def calculate_metrics(self, backtest_results, trades):
        """Compute summary statistics of a backtest.

        Annualisation assumes 252 periods per year and the Sharpe ratio uses
        a 3% risk-free rate. Trade statistics pair the n-th 'BUY' with the
        n-th 'SELL'; the return of a trade opened with a negative share
        count (short) is sign-flipped.

        Returns:
            Dict keyed by the Chinese metric names used in the report.
        """
        returns = backtest_results['returns'].replace([np.inf, -np.inf], np.nan).dropna()
        if len(returns) == 0:
            return {
                '总收益率': 0,
                '年化收益率': 0,
                '波动率': 0,
                '夏普比率': 0,
                '最大回撤': 0,
                '胜率': 0,
                '平均交易收益': 0,
                '交易次数': 0,
                '最终资金': self.initial_capital,
            }

        final_capital = backtest_results['capital'].iloc[-1]
        total_return = (final_capital - self.initial_capital) / self.initial_capital

        # A non-positive growth factor cannot be raised to a fractional
        # power; report a total loss instead of NaN.
        growth = 1 + total_return
        annual_return = growth ** (252 / len(returns)) - 1 if growth > 0 else -1.0

        volatility = returns.std() * np.sqrt(252)
        if pd.isna(volatility):
            # std() of a single observation is NaN.
            volatility = 0.0

        risk_free_rate = 0.03
        sharpe_ratio = (annual_return - risk_free_rate) / volatility if volatility > 0 else 0

        cumulative_returns = (1 + returns).cumprod()
        peak = cumulative_returns.expanding().max()
        max_drawdown = ((cumulative_returns - peak) / peak).min()

        opens = [t for t in trades if t['action'] == 'BUY']
        closes = [t for t in trades if t['action'] == 'SELL']
        trade_returns = []
        for entry, exit_trade in zip(opens, closes):
            direction = -1 if entry['shares'] < 0 else 1
            trade_returns.append(
                direction * (exit_trade['price'] - entry['price']) / entry['price']
            )

        if trade_returns:
            win_rate = sum(1 for r in trade_returns if r > 0) / len(trade_returns)
            avg_trade_return = np.mean(trade_returns)
        else:
            win_rate = 0
            avg_trade_return = 0

        return {
            '总收益率': total_return,
            '年化收益率': annual_return,
            '波动率': volatility,
            '夏普比率': sharpe_ratio,
            '最大回撤': max_drawdown,
            '胜率': win_rate,
            '平均交易收益': avg_trade_return,
            '交易次数': len(trades),
            '最终资金': final_capital,
        }

    def plot_backtest_results(self, backtest_results, metrics, stock_code, output_dir):
        """Draw equity, cumulative-return and drawdown charts and save them.

        The figure is written to
        ``<output_dir>/<stock_code>_backtest_results.png`` (the directory is
        created if needed), shown, and then closed.
        """
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 12))

        # 1. Equity curve
        ax1.plot(backtest_results.index, backtest_results['capital'],
                 linewidth=2, label='策略资金曲线', color='#1f77b4')
        ax1.axhline(y=self.initial_capital, color='red', linestyle='--',
                    label=f'初始资金 ({self.initial_capital:,.0f}元)')
        ax1.set_ylabel('资金 (元)', fontsize=12)
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        ax1.set_title(f'{stock_code} Kronos模型回测结果', fontsize=14, fontweight='bold')

        # 2. Cumulative strategy return versus buy-and-hold
        cumulative_returns = (1 + backtest_results['returns'].fillna(0)).cumprod()
        ax2.plot(backtest_results.index, cumulative_returns,
                 linewidth=2, label='策略累计收益', color='#2ca02c')
        price_returns = (
            backtest_results['price'].ffill().pct_change(fill_method=None).fillna(0)
        )
        benchmark_returns = (1 + price_returns).cumprod()
        ax2.plot(backtest_results.index, benchmark_returns,
                 linewidth=2, label='基准收益(买入持有)', color='#ff7f0e', alpha=0.7)
        ax2.set_ylabel('累计收益', fontsize=12)
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # 3. Drawdown
        peak = cumulative_returns.expanding().max()
        drawdown = (cumulative_returns - peak) / peak
        ax3.fill_between(backtest_results.index, drawdown, 0,
                         alpha=0.3, color='red', label='回撤')
        ax3.set_ylabel('回撤', fontsize=12)
        ax3.set_xlabel('日期', fontsize=12)
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Metrics summary box on the equity chart
        metrics_text = (
            f"总收益率: {metrics['总收益率']:.2%}\n"
            f"年化收益率: {metrics['年化收益率']:.2%}\n"
            f"夏普比率: {metrics['夏普比率']:.2f}\n"
            f"最大回撤: {metrics['最大回撤']:.2%}\n"
            f"胜率: {metrics['胜率']:.2%}\n"
            f"交易次数: {metrics['交易次数']}\n"
            f"最终资金: {metrics['最终资金']:,.0f}"
        )
        ax1.text(0.02, 0.98, metrics_text, transform=ax1.transAxes, fontsize=10,
                 verticalalignment='top',
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow", alpha=0.8))

        plt.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        chart_file = os.path.join(output_dir, f'{stock_code}_backtest_results.png')
        plt.savefig(chart_file, dpi=300, bbox_inches='tight')
        print(f"📊 回测图表已保存: {chart_file}")
        plt.show()
        # Release the figure so repeated backtests do not accumulate figures.
        plt.close(fig)

    def run_complete_backtest(self, stock_code, output_dir, threshold=0.02):
        """Run the full pipeline for one stock and print a report.

        Returns:
            ``(metrics, backtest_results, trades)`` on success. Any exception
            is caught, printed with its traceback, and reported as
            ``(None, None, None)``.
        """
        print(f"🎯 开始 {stock_code} 回测分析")
        print("=" * 50)
        try:
            print("步骤1: 加载历史数据和预测数据...")
            hist_df = self.load_historical_data(stock_code)
            pred_df = self.load_predictions(stock_code)

            print("步骤2: 计算交易信号...")
            combined_df = self.calculate_trading_signals(hist_df, pred_df, threshold)

            print("步骤3: 运行回测...")
            backtest_results, trades = self.run_backtest(combined_df)

            print("步骤4: 计算回测指标...")
            metrics = self.calculate_metrics(backtest_results, trades)

            print("步骤5: 生成回测图表...")
            self.plot_backtest_results(backtest_results, metrics, stock_code, output_dir)

            print("\n" + "=" * 70)
            print(f"📊 {stock_code} 回测报告")
            print("=" * 70)
            # Only these metrics are ratios meant to be shown as percentages;
            # the Sharpe ratio and the final capital are plain numbers.
            percent_keys = {'总收益率', '年化收益率', '波动率', '最大回撤', '胜率', '平均交易收益'}
            for key, value in metrics.items():
                if isinstance(value, float):
                    if key in percent_keys:
                        print(f" {key}: {value:.2%}")
                    else:
                        print(f" {key}: {value:.2f}")
                else:
                    print(f" {key}: {value}")

            print(f"\n交易记录 (共{len(trades)}次交易):")
            # Only the last ten trades are listed.
            for i, trade in enumerate(trades[-10:], 1):
                print(f" 交易{i}: {trade['date'].strftime('%Y-%m-%d')} "
                      f"{trade['action']} {abs(trade['shares'])}股 @ {trade['price']:.2f}")

            return metrics, backtest_results, trades
        except Exception as e:
            # Top-level boundary: report the failure and let the caller test
            # for None instead of crashing the script.
            print(f"❌ 回测过程中出现错误: {e}")
            import traceback
            traceback.print_exc()
            return None, None, None
def main():
    """Entry point: run the Kronos backtest with the settings defined below."""
    # Backtest settings: stock to test, input/output directories, starting
    # cash and the predicted-return threshold (2%) that triggers a trade.
    settings = dict(
        stock_code="000831",
        data_dir=r"D:\lianghuajiaoyi\Kronos\examples\data",
        model_dir=r"D:\lianghuajiaoyi\Kronos\examples\yuce",
        output_dir=r"D:\lianghuajiaoyi\Kronos\examples\backtest",
        initial_capital=100000,
        threshold=0.02,
    )
    stock_code = settings["stock_code"]
    output_dir = settings["output_dir"]

    # Banner with the active configuration
    print("🤖 Kronos模型回测系统")
    print("=" * 50)
    print(f"回测股票: {stock_code}")
    print(f"初始资金: {settings['initial_capital']:,.0f}")
    print(f"交易阈值: {settings['threshold']:.1%}")
    print()

    # Build the backtester and run the whole pipeline
    backtester = KronosBacktester(
        data_dir=settings["data_dir"],
        model_dir=settings["model_dir"],
        initial_capital=settings["initial_capital"],
    )
    metrics, results, trades = backtester.run_complete_backtest(
        stock_code=stock_code,
        output_dir=output_dir,
        threshold=settings["threshold"],
    )

    # Nothing more to report when the run failed or produced no metrics
    if not metrics:
        return
    print(f"\n{stock_code} 回测完成!")
    print(f"📁 结果保存在: {output_dir}")
# Run the backtest only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()