diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs
index 86871a82..0ef663f1 100644
--- a/rust/crates/api/src/providers/mod.rs
+++ b/rust/crates/api/src/providers/mod.rs
@@ -252,17 +252,16 @@ pub fn detect_provider_kind(model: &str) -> ProviderKind {
 
 #[must_use]
 pub fn max_tokens_for_model(model: &str) -> u32 {
-    model_token_limit(model).map_or_else(
-        || {
-            let canonical = resolve_model_alias(model);
-            if canonical.contains("opus") {
-                32_000
-            } else {
-                64_000
-            }
-        },
-        |limit| limit.max_output_tokens,
-    )
+    let canonical = resolve_model_alias(model);
+    let heuristic = if canonical.contains("opus") {
+        32_000
+    } else {
+        64_000
+    };
+
+    model_token_limit(model)
+        .map(|limit| heuristic.min(limit.max_output_tokens))
+        .unwrap_or(heuristic)
 }
 
 /// Returns the effective max output tokens for a model, preferring a plugin
@@ -276,7 +275,8 @@ pub fn max_tokens_for_model_with_override(model: &str, plugin_override: Option<u32>) -> u32 {
 pub fn model_token_limit(model: &str) -> Option<ModelTokenLimit> {
     let canonical = resolve_model_alias(model);
-    match canonical.as_str() {
+    let base_model = canonical.rsplit('/').next().unwrap_or(canonical.as_str());
+    match base_model {
         "claude-opus-4-6" => Some(ModelTokenLimit {
             max_output_tokens: 32_000,
             context_window_tokens: 200_000,
         }),
@@ -289,6 +289,20 @@ pub fn model_token_limit(model: &str) -> Option<ModelTokenLimit> {
             max_output_tokens: 64_000,
             context_window_tokens: 131_072,
         }),
+        // GPT-4.1 family via the OpenAI API.
+        "gpt-4.1" | "gpt-4.1-mini" | "gpt-4.1-nano" => Some(ModelTokenLimit {
+            max_output_tokens: 32_768,
+            context_window_tokens: 1_047_576,
+        }),
+        // GPT-5.4 family via the OpenAI API.
+        "gpt-5.4" => Some(ModelTokenLimit {
+            max_output_tokens: 128_000,
+            context_window_tokens: 1_000_000,
+        }),
+        "gpt-5.4-mini" | "gpt-5.4-nano" => Some(ModelTokenLimit {
+            max_output_tokens: 128_000,
+            context_window_tokens: 400_000,
+        }),
         // Kimi models via DashScope (Moonshot AI)
         // Source: https://platform.moonshot.cn/docs/intro
         "kimi-k2.5" | "kimi-k1.5" => Some(ModelTokenLimit {
@@ -614,6 +628,15 @@ mod tests {
     fn keeps_existing_max_token_heuristic() {
         assert_eq!(max_tokens_for_model("opus"), 32_000);
         assert_eq!(max_tokens_for_model("grok-3"), 64_000);
+        assert_eq!(max_tokens_for_model("gpt-5.4"), 64_000);
+    }
+
+    #[test]
+    fn caps_default_max_tokens_to_openai_model_limits() {
+        assert_eq!(max_tokens_for_model("gpt-4.1-mini"), 32_768);
+        assert_eq!(max_tokens_for_model("openai/gpt-4.1-mini"), 32_768);
+        assert_eq!(max_tokens_for_model("gpt-5.4"), 64_000);
+        assert_eq!(max_tokens_for_model("openai/gpt-5.4"), 64_000);
     }
 
     #[test]
@@ -680,6 +703,18 @@ mod tests {
                 .context_window_tokens,
             131_072
         );
+        assert_eq!(
+            model_token_limit("openai/gpt-4.1-mini")
+                .expect("openai/gpt-4.1-mini should be registered")
+                .context_window_tokens,
+            1_047_576
+        );
+        assert_eq!(
+            model_token_limit("gpt-5.4")
+                .expect("gpt-5.4 should be registered")
+                .context_window_tokens,
+            1_000_000
+        );
     }
 
     #[test]
@@ -728,6 +763,42 @@ mod tests {
         }
     }
 
+    #[test]
+    fn preflight_blocks_oversized_requests_for_gpt_5_4() {
+        let request = MessageRequest {
+            model: "gpt-5.4".to_string(),
+            max_tokens: 64_000,
+            messages: vec![InputMessage {
+                role: "user".to_string(),
+                content: vec![InputContentBlock::Text {
+                    text: "x".repeat(3_900_000),
+                }],
+            }],
+            system: Some("Keep the answer short.".to_string()),
+            tools: None,
+            tool_choice: None,
+            stream: true,
+            ..Default::default()
+        };
+
+        let error = preflight_message_request(&request)
+            .expect_err("oversized gpt-5.4 request should be rejected before the provider call");
+
+        match error {
+            ApiError::ContextWindowExceeded {
+                model,
+                requested_output_tokens,
+                context_window_tokens,
+                ..
+            } => {
+                assert_eq!(model, "gpt-5.4");
+                assert_eq!(requested_output_tokens, 64_000);
+                assert_eq!(context_window_tokens, 1_000_000);
+            }
+            other => panic!("expected context-window preflight failure, got {other:?}"),
+        }
+    }
+
     #[test]
     fn preflight_skips_unknown_models() {
         let request = MessageRequest {
diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs
index e2c889e9..e46ec753 100644
--- a/rust/crates/rusty-claude-cli/src/main.rs
+++ b/rust/crates/rusty-claude-cli/src/main.rs
@@ -148,11 +148,7 @@ impl ModelProvenance {
 }
 
 fn max_tokens_for_model(model: &str) -> u32 {
-    if model.contains("opus") {
-        32_000
-    } else {
-        64_000
-    }
+    api::max_tokens_for_model(model)
 }
 
 // Build-time constants injected by build.rs (fall back to static values when
// build.rs hasn't run, e.g. in doc-test or unusual toolchain environments).