mirror of
https://github.com/shiyu-coder/Kronos.git
synced 2026-06-05 22:30:55 +08:00
Add regression tests
This commit is contained in:
90
tests/data/generate_regression_output.py
Normal file
90
tests/data/generate_regression_output.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from model import Kronos, KronosPredictor, KronosTokenizer
|
||||
|
||||
|
||||
# Fixture locations. This script sits next to the CSV files it reads and
# writes, so every path is resolved relative to the script itself.
TEST_DATA_ROOT = Path(__file__).parent
INPUT_DATA_PATH = TEST_DATA_ROOT / "regression_input.csv"
OUTPUT_DATA_DIR = TEST_DATA_ROOT

# MAX_CTX_LEN is handed to KronosPredictor as ``max_context``; every entry of
# TEST_CTX_LEN produces one ``regression_output_<ctx_len>.csv`` fixture.
MAX_CTX_LEN = 512
TEST_CTX_LEN = [512, 256]
PRED_LEN = 8
FEATURE_NAMES = ["open", "high", "low", "close", "volume", "amount"]

# Revision identifiers passed to ``from_pretrained`` so the fixtures are always
# generated from the same model and tokenizer snapshots.
MODEL_REVISION = "901c26c1332695a2a8f243eb2f37243a37bea320"
TOKENIZER_REVISION = "0e0117387f39004a9016484a186a908917e22426"
SEED = 123

DEVICE = "cpu"
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When cuDNN is available it is additionally switched to deterministic mode
    with benchmarking disabled.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def generate_output(ctx_len: int, data: "pd.DataFrame | None" = None) -> None:
    """Generate and save the regression fixture for one context length.

    The first ``ctx_len`` rows of the input data are used as context, the next
    ``PRED_LEN`` timestamps are predicted, and the result is written to
    ``regression_output_<ctx_len>.csv`` inside ``OUTPUT_DATA_DIR``.

    Args:
        ctx_len: Number of leading input rows used as context. Must not exceed
            ``MAX_CTX_LEN``.
        data: Input frame with a ``timestamps`` column and the
            ``FEATURE_NAMES`` columns. When omitted, the module-level ``df``
            loaded by the script entry point is used; if that does not exist
            (e.g. the module was imported), the frame is read from
            ``INPUT_DATA_PATH``.

    Raises:
        ValueError: If ``ctx_len`` exceeds ``MAX_CTX_LEN``, if the data has
            fewer than ``ctx_len + PRED_LEN`` rows, or if the prediction does
            not have shape ``(PRED_LEN, len(FEATURE_NAMES))``.
    """
    if ctx_len > MAX_CTX_LEN:
        raise ValueError(
            f"Context length for output generation ({ctx_len}) "
            f"cannot exceed maximum context length ({MAX_CTX_LEN})."
        )

    if data is None:
        # Script mode: reuse the frame loaded under the ``__main__`` guard.
        data = globals().get("df")
    if data is None:
        # Imported mode: no module-level frame exists, so load it here instead
        # of failing with a NameError.
        data = pd.read_csv(INPUT_DATA_PATH, parse_dates=["timestamps"])

    if data.shape[0] < ctx_len + PRED_LEN:
        raise ValueError(
            f"Input data must have at least {ctx_len + PRED_LEN} rows for "
            f"context length {ctx_len}, found {data.shape[0]} instead."
        )

    # Reseed for every fixture so each one is produced from the same RNG state
    # as the regression test, which seeds before loading the model for each
    # context length.
    set_seed(SEED)

    context_df = data.iloc[:ctx_len].copy()
    future_timestamps = data["timestamps"].iloc[
        ctx_len : ctx_len + PRED_LEN
    ].reset_index(drop=True)

    tokenizer = KronosTokenizer.from_pretrained(
        "NeoQuasar/Kronos-Tokenizer-base", revision=TOKENIZER_REVISION
    )
    model = Kronos.from_pretrained("NeoQuasar/Kronos-small", revision=MODEL_REVISION)
    tokenizer.eval()
    model.eval()

    predictor = KronosPredictor(
        model, tokenizer, device=DEVICE, max_context=MAX_CTX_LEN
    )

    with torch.no_grad():
        pred_df = predictor.predict(
            df=context_df[FEATURE_NAMES].reset_index(drop=True),
            x_timestamp=context_df["timestamps"].reset_index(drop=True),
            y_timestamp=future_timestamps,
            pred_len=PRED_LEN,
            T=1.0,
            top_k=1,
            top_p=1.0,
            verbose=False,
        )

    if pred_df.shape != (PRED_LEN, len(FEATURE_NAMES)):
        raise ValueError(f"Unexpected prediction shape: {pred_df.shape}")

    # Put the timestamps first so the fixture has the same column layout as
    # the input CSV consumers expect: timestamps followed by the features.
    output_df = pred_df.reset_index(drop=True)
    output_df["timestamps"] = future_timestamps
    output_df = output_df[["timestamps"] + FEATURE_NAMES]

    output_path = OUTPUT_DATA_DIR / f"regression_output_{ctx_len}.csv"
    output_df.to_csv(output_path, index=False)
    print(f"Saved {ctx_len} fixture to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    set_seed(SEED)

    # Deliberately bound at module scope: generate_output() looks this frame
    # up as the global ``df``.
    df = pd.read_csv(INPUT_DATA_PATH, parse_dates=["timestamps"])
    # The largest context plus the prediction horizon must fit in the input.
    if df.shape[0] < MAX_CTX_LEN + PRED_LEN:
        raise ValueError(
            f"Input data must have at least {MAX_CTX_LEN + PRED_LEN} rows, "
            f"found {df.shape[0]} instead."
        )

    # One fixture file per configured context length.
    for ctx_len in TEST_CTX_LEN:
        generate_output(ctx_len)
|
||||
2501
tests/data/regression_input.csv
Normal file
2501
tests/data/regression_input.csv
Normal file
File diff suppressed because it is too large
Load Diff
9
tests/data/regression_output_256.csv
Normal file
9
tests/data/regression_output_256.csv
Normal file
@@ -0,0 +1,9 @@
|
||||
timestamps,open,high,low,close,volume,amount
|
||||
2024-06-25 14:05:00,10.766402,10.778437,10.755835,10.769899,463.83276,479256.62
|
||||
2024-06-25 14:10:00,10.769842,10.7804785,10.75896,10.771648,415.90912,434510.62
|
||||
2024-06-25 14:15:00,10.771282,10.781633,10.760545,10.773098,396.62488,416206.88
|
||||
2024-06-25 14:20:00,10.772831,10.782868,10.761984,10.77445,389.24976,409554.56
|
||||
2024-06-25 14:25:00,10.774201,10.783865,10.763183,10.775418,386.3412,407075.44
|
||||
2024-06-25 14:30:00,10.774968,10.78441,10.763903,10.776,383.4024,404050.8
|
||||
2024-06-25 14:35:00,10.775348,10.7847595,10.764308,10.776471,377.25995,397440.12
|
||||
2024-06-25 14:40:00,10.775859,10.78527,10.764823,10.77709,369.78687,389529.8
|
||||
|
9
tests/data/regression_output_512.csv
Normal file
9
tests/data/regression_output_512.csv
Normal file
@@ -0,0 +1,9 @@
|
||||
timestamps,open,high,low,close,volume,amount
|
||||
2024-07-03 09:55:00,10.897451,10.931036,10.800024,10.917972,1545.1384,1665960.5
|
||||
2024-07-03 10:00:00,10.900613,10.907957,10.871778,10.884289,719.92456,792042.5
|
||||
2024-07-03 10:05:00,10.882399,10.890674,10.864932,10.87375,659.0906,716546.6
|
||||
2024-07-03 10:10:00,10.871227,10.881202,10.857913,10.867245,629.60645,681494.7
|
||||
2024-07-03 10:15:00,10.864513,10.875556,10.85328,10.863286,607.7948,656790.1
|
||||
2024-07-03 10:20:00,10.861447,10.872749,10.851165,10.86135,591.31,638401.9
|
||||
2024-07-03 10:25:00,10.860088,10.871324,10.850175,10.860315,580.63446,626458.9
|
||||
2024-07-03 10:30:00,10.858802,10.869965,10.849048,10.859139,572.116,616889.56
|
||||
|
138
tests/test_kronos_regression.py
Normal file
138
tests/test_kronos_regression.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from model import Kronos, KronosPredictor, KronosTokenizer
|
||||
|
||||
# Fixtures live in the ``data`` directory next to this test module.
TEST_DATA_ROOT = Path(__file__).parent / "data"
INPUT_DATA_PATH = TEST_DATA_ROOT / "regression_input.csv"

# Regression test configuration
OUTPUT_DATA_DIR = TEST_DATA_ROOT
MAX_CTX_LEN = 512
TEST_CTX_LEN = [512, 256]
PRED_LEN = 8
REL_TOLERANCE = 1e-5
FEATURE_NAMES = ["open", "high", "low", "close", "volume", "amount"]

# MSE regression test configuration
MSE_SAMPLE_SIZE = 8
MSE_SAMPLE_CTX_LEN = 512
MSE_PRED_LEN = 30
MSE_EXPECTED = 0.00559805
MSE_TOLERANCE = 0.000001
MSE_FEATURE_NAMES = ["open", "high", "low", "close"]

# Revision identifiers passed to ``from_pretrained`` so the tests always run
# against the same model and tokenizer snapshots.
MODEL_REVISION = "901c26c1332695a2a8f243eb2f37243a37bea320"
TOKENIZER_REVISION = "0e0117387f39004a9016484a186a908917e22426"
SEED = 123
DEVICE = "cpu"
|
||||
|
||||
def set_seed(seed: int) -> None:
    """Seed every random number generator the tests rely on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When cuDNN is available it is additionally switched to deterministic mode
    with benchmarking disabled.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("context_len", TEST_CTX_LEN)
def test_kronos_predictor_regression(context_len):
    """Check predictions against the stored fixture for one context length.

    The first ``context_len`` rows of the input CSV are used as context and
    the predictor is asked for as many steps as the fixture
    ``regression_output_<context_len>.csv`` has rows. Every feature value must
    match the fixture within ``REL_TOLERANCE`` (relative).
    """
    set_seed(SEED)

    expected_output_path = OUTPUT_DATA_DIR / f"regression_output_{context_len}.csv"
    df = pd.read_csv(INPUT_DATA_PATH, parse_dates=["timestamps"])
    expected_df = pd.read_csv(expected_output_path, parse_dates=["timestamps"])

    if df.shape[0] < context_len + len(expected_df):
        raise ValueError("Example data does not contain enough rows for the regression test.")

    context_df = df.iloc[:context_len].copy()
    context_features = context_df[FEATURE_NAMES].reset_index(drop=True)
    x_timestamp = context_df["timestamps"].reset_index(drop=True)
    future_timestamp = df["timestamps"].iloc[
        context_len : context_len + len(expected_df)
    ].reset_index(drop=True)
    # Same conversion as for the obtained values below, so both sides of the
    # comparison are float32 arrays built the same way.
    expected = expected_df[FEATURE_NAMES].to_numpy(dtype=np.float32)

    tokenizer = KronosTokenizer.from_pretrained(
        "NeoQuasar/Kronos-Tokenizer-base", revision=TOKENIZER_REVISION
    )
    model = Kronos.from_pretrained("NeoQuasar/Kronos-small", revision=MODEL_REVISION)
    tokenizer.eval()
    model.eval()

    predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=MAX_CTX_LEN)

    with torch.no_grad():
        pred_df = predictor.predict(
            df=context_features,
            x_timestamp=x_timestamp,
            y_timestamp=future_timestamp,
            pred_len=expected.shape[0],
            T=1.0,
            top_k=1,
            top_p=1.0,
            verbose=False,
        )

    obtained = pred_df[FEATURE_NAMES].to_numpy(dtype=np.float32)

    abs_diff = np.abs(obtained - expected)
    # The small constant keeps the ratio finite if an expected value is zero.
    rel_diff = abs_diff / (np.abs(expected) + 1e-9)
    print(f"Abs diff: {np.max(abs_diff)}, Rel diff: {np.max(rel_diff)}")

    # Repeat the diagnostics in the failure message: the print above is only
    # visible when pytest runs without output capturing.
    np.testing.assert_allclose(
        obtained,
        expected,
        rtol=REL_TOLERANCE,
        err_msg=(
            f"context_len={context_len}: max abs diff {np.max(abs_diff)}, "
            f"max rel diff {np.max(rel_diff)}"
        ),
    )
|
||||
|
||||
|
||||
def test_kronos_predictor_mse():
    """Check the average forecast MSE on a fixed random sample of windows.

    ``MSE_SAMPLE_SIZE`` anchor rows are drawn with a fixed ``random_state``
    from the part of the input that has ``MSE_SAMPLE_CTX_LEN`` rows before it
    and ``MSE_PRED_LEN`` rows from it onwards. For each anchor the predictor
    forecasts ``MSE_PRED_LEN`` steps, and the mean squared error against the
    actual ``MSE_FEATURE_NAMES`` values is recorded. The mean of those errors
    must lie within ``MSE_TOLERANCE`` of ``MSE_EXPECTED``.
    """
    set_seed(SEED)

    df = pd.read_csv(INPUT_DATA_PATH, parse_dates=["timestamps"])
    if df.shape[0] <= MSE_SAMPLE_CTX_LEN + MSE_PRED_LEN:
        raise ValueError("Example data does not contain enough rows for the random sample regression test.")

    tokenizer = KronosTokenizer.from_pretrained(
        "NeoQuasar/Kronos-Tokenizer-base", revision=TOKENIZER_REVISION
    )
    model = Kronos.from_pretrained("NeoQuasar/Kronos-small", revision=MODEL_REVISION)
    tokenizer.eval()
    model.eval()

    predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=MAX_CTX_LEN)

    # Anchors need a full context window before them and a full prediction
    # horizon starting at them.
    valid_region = df.iloc[MSE_SAMPLE_CTX_LEN : df.shape[0] - MSE_PRED_LEN]
    if valid_region.shape[0] < MSE_SAMPLE_SIZE:
        raise ValueError("Not enough data points to draw the requested random samples.")

    sampled_rows = valid_region.sample(n=MSE_SAMPLE_SIZE, random_state=SEED).sort_index()

    mse_values = []
    # NOTE: the index labels are used as positions with ``iloc`` below. That
    # is valid because ``read_csv`` is called without ``index_col``, so the
    # frame has the default 0..n-1 index.
    sample_indices = sampled_rows.index.to_list()
    with torch.no_grad():
        for row_idx in tqdm(sample_indices):
            context_slice = df.iloc[row_idx - MSE_SAMPLE_CTX_LEN : row_idx].copy()
            future_slice = df.iloc[row_idx : row_idx + MSE_PRED_LEN].copy()

            pred_df = predictor.predict(
                df=context_slice[FEATURE_NAMES].reset_index(drop=True),
                x_timestamp=context_slice["timestamps"].reset_index(drop=True),
                y_timestamp=future_slice["timestamps"].reset_index(drop=True),
                pred_len=MSE_PRED_LEN,
                T=1.0,
                top_k=1,
                top_p=1.0,
                verbose=False,
            )

            obtained = pred_df[MSE_FEATURE_NAMES].to_numpy(dtype=np.float32)
            expected = future_slice[MSE_FEATURE_NAMES].to_numpy(dtype=np.float32)
            mse_values.append(float(np.mean((obtained - expected) ** 2)))

    assert len(mse_values) == MSE_SAMPLE_SIZE, f"Expected {MSE_SAMPLE_SIZE} MSE values, got {len(mse_values)}."

    mse = np.mean(mse_values).item()
    mse_diff = mse - MSE_EXPECTED
    print(f"Average MSE: {mse} (Diff vs expected: {mse_diff:+})")

    assert abs(mse_diff) <= MSE_TOLERANCE, f"MSE {mse} differs from expected {MSE_EXPECTED}"
|
||||
Reference in New Issue
Block a user