diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index 9a304788..acf368ab 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -280,6 +280,7 @@ "dynamic_allowed": true, "levels": [ "low", + "medium", "high" ] } @@ -554,6 +555,7 @@ "dynamic_allowed": true, "levels": [ "low", + "medium", "high" ] } @@ -610,6 +612,8 @@ "dynamic_allowed": true, "levels": [ "minimal", + "low", + "medium", "high" ] } @@ -838,6 +842,7 @@ "dynamic_allowed": true, "levels": [ "low", + "medium", "high" ] } @@ -896,6 +901,8 @@ "dynamic_allowed": true, "levels": [ "minimal", + "low", + "medium", "high" ] } @@ -1070,6 +1077,8 @@ "dynamic_allowed": true, "levels": [ "minimal", + "low", + "medium", "high" ] } @@ -1371,6 +1380,75 @@ "xhigh" ] } + }, + { + "id": "gpt-5.3-codex", + "object": "model", + "created": 1770307200, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.3 Codex", + "version": "gpt-5.3", + "description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.4", + "object": "model", + "created": 1772668800, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.4", + "version": "gpt-5.4", + "description": "Stable version of GPT 5.4", + "context_length": 1050000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.4-mini", + "object": "model", + "created": 1773705600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.4 Mini", + "version": "gpt-5.4-mini", + "description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-team": [ @@ -1623,6 +1701,29 @@ "xhigh" ] } + }, + { + "id": "gpt-5.4-mini", + "object": "model", + "created": 1773705600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.4 Mini", + "version": "gpt-5.4-mini", + "description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-plus": [ @@ -1898,6 +1999,29 @@ "xhigh" ] } + }, + { + "id": "gpt-5.4-mini", + "object": "model", + "created": 1773705600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.4 Mini", + "version": "gpt-5.4-mini", + "description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-pro": [ @@ -2173,55 +2297,40 @@ "xhigh" ] } + }, + { + "id": "gpt-5.4-mini", + "object": "model", + "created": 1773705600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.4 Mini", + "version": "gpt-5.4-mini", + "description": "GPT-5.4 mini brings the strengths of GPT-5.4 to a faster, more efficient model designed for high-volume workloads.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "qwen": [ - { - "id": "qwen3-coder-plus", - "object": "model", - "created": 1753228800, - "owned_by": "qwen", - "type": "qwen", - "display_name": "Qwen3 Coder Plus", - "version": "3.0", - "description": "Advanced code generation and understanding model", - "context_length": 32768, - "max_completion_tokens": 8192, - "supported_parameters": [ - "temperature", - "top_p", - "max_tokens", - "stream", - "stop" - ] - }, - { - "id": "qwen3-coder-flash", - "object": "model", - "created": 1753228800, - "owned_by": "qwen", - "type": "qwen", - "display_name": "Qwen3 Coder Flash", - "version": "3.0", - "description": "Fast code generation model", - "context_length": 8192, - "max_completion_tokens": 2048, - "supported_parameters": [ - "temperature", - "top_p", - "max_tokens", - "stream", - "stop" - ] - }, { "id": "coder-model", "object": "model", "created": 1771171200, "owned_by": "qwen", "type": "qwen", - "display_name": "Qwen 3.5 Plus", - "version": "3.5", + "display_name": "Qwen 3.6 Plus", + "version": "3.6", "description": "efficient hybrid model with leading coding performance", "context_length": 1048576, "max_completion_tokens": 65536, @@ -2232,25 +2341,6 @@ "stream", "stop" ] - }, - { - "id": "vision-model", - "object": "model", - "created": 1758672000, - "owned_by": "qwen", - "type": "qwen", - "display_name": "Qwen3 Vision Model", - "version": "3.0", - "description": "Vision model model", - "context_length": 32768, - "max_completion_tokens": 2048, - "supported_parameters": [ - "temperature", - "top_p", - "max_tokens", - "stream", - "stop" - ] } ], "iflow": [ @@ -2639,11 +2729,12 @@ "context_length": 1048576, "max_completion_tokens": 65535, "thinking": { - "min": 128, - "max": 32768, + "min": 1, + "max": 65535, "dynamic_allowed": true, "levels": [ "low", + "medium", "high" ] } @@ -2659,11 +2750,12 @@ "context_length": 1048576, "max_completion_tokens": 65535, "thinking": { - "min": 128, - "max": 32768, + "min": 1, + "max": 65535, "dynamic_allowed": true, "levels": [ "low", + "medium", "high" ] } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 56c2c540..7b2e5d8d 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -137,6 +137,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) + body = normalizeClaudeTemperatureForThinking(body) // Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support) if countCacheControls(body) == 0 { @@ -307,6 +308,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) + body = normalizeClaudeTemperatureForThinking(body) // Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support) if countCacheControls(body) == 0 { @@ -651,6 +653,25 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte { return body } +// normalizeClaudeTemperatureForThinking keeps Anthropic message requests valid when +// thinking is enabled. Anthropic rejects temperatures other than 1 when +// thinking.type is enabled/adaptive/auto. +func normalizeClaudeTemperatureForThinking(body []byte) []byte { + if !gjson.GetBytes(body, "temperature").Exists() { + return body + } + + thinkingType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "thinking.type").String())) + switch thinkingType { + case "enabled", "adaptive", "auto": + if temp := gjson.GetBytes(body, "temperature"); temp.Exists() && temp.Type == gjson.Number && temp.Float() == 1 { + return body + } + body, _ = sjson.SetBytes(body, "temperature", 1) + } + return body +} + type compositeReadCloser struct { io.Reader closers []func() error diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 89bab2aa..74cec0a3 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -1833,3 +1833,43 @@ func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmi t.Fatalf("expected configured sensitive word obfuscation to apply, got %q", got) } } + +func TestNormalizeClaudeTemperatureForThinking_AdaptiveCoercesToOne(t *testing.T) { + payload := []byte(`{"temperature":0,"thinking":{"type":"adaptive"},"output_config":{"effort":"max"}}`) + out := normalizeClaudeTemperatureForThinking(payload) + + if got := gjson.GetBytes(out, "temperature").Float(); got != 1 { + t.Fatalf("temperature = %v, want 1", got) + } +} + +func TestNormalizeClaudeTemperatureForThinking_EnabledCoercesToOne(t *testing.T) { + payload := []byte(`{"temperature":0.2,"thinking":{"type":"enabled","budget_tokens":2048}}`) + out := normalizeClaudeTemperatureForThinking(payload) + + if got := gjson.GetBytes(out, "temperature").Float(); got != 1 { + t.Fatalf("temperature = %v, want 1", got) + } +} + +func TestNormalizeClaudeTemperatureForThinking_NoThinkingLeavesTemperatureAlone(t *testing.T) { + payload := []byte(`{"temperature":0,"messages":[{"role":"user","content":"hi"}]}`) + out := normalizeClaudeTemperatureForThinking(payload) + + if got := gjson.GetBytes(out, "temperature").Float(); got != 0 { + t.Fatalf("temperature = %v, want 0", got) + } +} + +func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOriginalTemperature(t *testing.T) { + payload := []byte(`{"temperature":0,"thinking":{"type":"adaptive"},"output_config":{"effort":"max"},"tool_choice":{"type":"any"}}`) + out := disableThinkingIfToolChoiceForced(payload) + out = normalizeClaudeTemperatureForThinking(out) + + if gjson.GetBytes(out, "thinking").Exists() { + t.Fatalf("thinking should be removed when tool_choice forces tool use") + } + if got := gjson.GetBytes(out, "temperature").Float(); got != 0 { + t.Fatalf("temperature = %v, want 0", got) + } +} diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 7b9fffc5..f771099c 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -176,12 +176,6 @@ func timeUntilNextDay() time.Duration { func ensureQwenSystemMessage(payload []byte) ([]byte, error) { messages := gjson.GetBytes(payload, "messages") if messages.Exists() && messages.IsArray() { - for _, msg := range messages.Array() { - if strings.EqualFold(msg.Get("role").String(), "system") { - return payload, nil - } - } - var buf bytes.Buffer buf.WriteByte('[') buf.Write(qwenDefaultSystemMessage) diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 1080f5cd..2f6b14a7 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -379,7 +379,7 @@ func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bo for _, item := range nextInput.Array() { switch strings.TrimSpace(item.Get("type").String()) { - case "function_call": + case "function_call", "custom_tool_call": return true case "message": role := strings.TrimSpace(item.Get("role").String()) @@ -431,7 +431,7 @@ func dedupeFunctionCallsByCallID(rawArray string) (string, error) { continue } itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) - if itemType == "function_call" { + if isResponsesToolCallType(itemType) { callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) if callID != "" { if _, ok := seenCallIDs[callID]; ok { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 6fce1bf1..ecfc90b3 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -520,6 +520,92 @@ func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *te } } +func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolOutput(t *testing.T) { + cache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"}]}`) + warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm) + if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" { + t.Fatalf("expected warmup output to remain") + } + + raw := []byte(`{"input":[{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 3 { + t.Fatalf("repaired input len = %d, want 3", len(input)) + } + if input[0].Get("type").String() != "custom_tool_call" || input[0].Get("call_id").String() != "call-1" { + t.Fatalf("unexpected first item: %s", input[0].Raw) + } + if input[1].Get("type").String() != "custom_tool_call_output" || input[1].Get("call_id").String() != "call-1" { + t.Fatalf("missing inserted output: %s", input[1].Raw) + } + if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" { + t.Fatalf("unexpected trailing item: %s", input[2].Raw) + } +} + +func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolCall(t *testing.T) { + cache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + raw := []byte(`{"input":[{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 1 { + t.Fatalf("repaired input len = %d, want 1", len(input)) + } + if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" { + t.Fatalf("unexpected remaining item: %s", input[0].Raw) + } +} + +func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolCallForOrphanOutput(t *testing.T) { + outputCache := newWebsocketToolOutputCache(time.Minute, 10) + callCache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + callCache.record(sessionKey, "call-1", []byte(`{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"}`)) + + raw := []byte(`{"input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 3 { + t.Fatalf("repaired input len = %d, want 3", len(input)) + } + if input[0].Get("type").String() != "custom_tool_call" || input[0].Get("call_id").String() != "call-1" { + t.Fatalf("missing inserted call: %s", input[0].Raw) + } + if input[1].Get("type").String() != "custom_tool_call_output" || input[1].Get("call_id").String() != "call-1" { + t.Fatalf("unexpected output item: %s", input[1].Raw) + } + if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" { + t.Fatalf("unexpected trailing item: %s", input[2].Raw) + } +} + +func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolOutputWhenCallMissing(t *testing.T) { + outputCache := newWebsocketToolOutputCache(time.Minute, 10) + callCache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + raw := []byte(`{"input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`) + repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw) + + input := gjson.GetBytes(repaired, "input").Array() + if len(input) != 1 { + t.Fatalf("repaired input len = %d, want 1", len(input)) + } + if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" { + t.Fatalf("unexpected remaining item: %s", input[0].Raw) + } +} + func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) { cache := newWebsocketToolOutputCache(time.Minute, 10) sessionKey := "session-1" @@ -536,6 +622,38 @@ func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) { } } +func TestRecordResponsesWebsocketCustomToolCallsFromCompletedPayloadWithCache(t *testing.T) { + cache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch","input":"*** Begin Patch"}]}}`) + recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload) + + cached, ok := cache.get(sessionKey, "call-1") + if !ok { + t.Fatalf("expected cached custom tool call") + } + if gjson.GetBytes(cached, "type").String() != "custom_tool_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" { + t.Fatalf("unexpected cached custom tool call: %s", cached) + } +} + +func TestRecordResponsesWebsocketCustomToolCallsFromOutputItemDoneWithCache(t *testing.T) { + cache := newWebsocketToolOutputCache(time.Minute, 10) + sessionKey := "session-1" + + payload := []byte(`{"type":"response.output_item.done","item":{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch","input":"*** Begin Patch"}}`) + recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload) + + cached, ok := cache.get(sessionKey, "call-1") + if !ok { + t.Fatalf("expected cached custom tool call") + } + if gjson.GetBytes(cached, "type").String() != "custom_tool_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" { + t.Fatalf("unexpected cached custom tool call: %s", cached) + } +} + func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { gin.SetMode(gin.TestMode) @@ -1023,6 +1141,161 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t } } +func TestNormalizeResponsesWebsocketRequestTreatsCustomToolTranscriptReplacementAsReset(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`) + lastResponseOutput := []byte(`[ + {"type":"message","id":"assistant-1","role":"assistant"} + ]`) + raw := []byte(`{"type":"response.create","input":[{"type":"custom_tool_call","id":"ctc-compact","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-compact","call_id":"call-1"},{"type":"message","id":"msg-2"}]}`) + + normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + if gjson.GetBytes(normalized, "previous_response_id").Exists() { + t.Fatalf("previous_response_id must not exist in transcript replacement mode") + } + items := gjson.GetBytes(normalized, "input").Array() + if len(items) != 3 { + t.Fatalf("replacement input len = %d, want 3: %s", len(items), normalized) + } + if items[0].Get("id").String() != "ctc-compact" || + items[1].Get("id").String() != "tool-out-compact" || + items[2].Get("id").String() != "msg-2" { + t.Fatalf("replacement transcript was not preserved as-is: %s", normalized) + } + if !bytes.Equal(next, normalized) { + t.Fatalf("next request snapshot should match replacement request") + } +} + +func TestNormalizeResponsesWebsocketRequestDropsDuplicateCustomToolCallsByCallID(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"}]}`) + lastResponseOutput := []byte(`[ + {"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"} + ]`) + raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`) + + normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + + items := gjson.GetBytes(normalized, "input").Array() + if len(items) != 3 { + t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized) + } + if items[0].Get("id").String() != "ctc-1" || + items[1].Get("id").String() != "tool-out-1" || + items[2].Get("id").String() != "msg-2" { + t.Fatalf("unexpected merged input order: %s", normalized) + } +} + +func TestResponsesWebsocketCompactionResetsTurnStateOnCustomToolTranscriptReplacement(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &websocketCompactionCaptureExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive} + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + router.POST("/v1/responses/compact", h.Compact) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + requests := []string{ + `{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`, + `{"type":"response.create","input":[{"type":"custom_tool_call_output","call_id":"call-1","id":"tool-out-1"}]}`, + } + for i := range requests { + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil { + t.Fatalf("write websocket message %d: %v", i+1, errWrite) + } + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read websocket message %d: %v", i+1, errReadMessage) + } + if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted) + } + } + + compactResp, errPost := server.Client().Post( + server.URL+"/v1/responses/compact", + "application/json", + strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"summary-1"}]}`), + ) + if errPost != nil { + t.Fatalf("compact request failed: %v", errPost) + } + if errClose := compactResp.Body.Close(); errClose != nil { + t.Fatalf("close compact response body: %v", errClose) + } + if compactResp.StatusCode != http.StatusOK { + t.Fatalf("compact status = %d, want %d", compactResp.StatusCode, http.StatusOK) + } + + postCompact := `{"type":"response.create","input":[{"type":"custom_tool_call","id":"ctc-compact","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-compact","call_id":"call-1"},{"type":"message","id":"msg-2"}]}` + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(postCompact)); errWrite != nil { + t.Fatalf("write post-compact websocket message: %v", errWrite) + } + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read post-compact websocket message: %v", errReadMessage) + } + if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("post-compact payload type = %s, want %s", got, wsEventTypeCompleted) + } + + executor.mu.Lock() + defer executor.mu.Unlock() + + if executor.compactPayload == nil { + t.Fatalf("compact payload was not captured") + } + if len(executor.streamPayloads) != 3 { + t.Fatalf("stream payload count = %d, want 3", len(executor.streamPayloads)) + } + + merged := executor.streamPayloads[2] + items := gjson.GetBytes(merged, "input").Array() + if len(items) != 3 { + t.Fatalf("merged input len = %d, want 3: %s", len(items), merged) + } + if items[0].Get("id").String() != "ctc-compact" || + items[1].Get("id").String() != "tool-out-compact" || + items[2].Get("id").String() != "msg-2" { + t.Fatalf("unexpected post-compact input order: %s", merged) + } + if items[0].Get("call_id").String() != "call-1" { + t.Fatalf("post-compact custom tool call id = %s, want call-1", items[0].Get("call_id").String()) + } +} + func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go index 530aca96..1a5772ec 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go @@ -266,15 +266,15 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa continue } itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) - switch itemType { - case "function_call_output": + switch { + case isResponsesToolCallOutputType(itemType): callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) if callID == "" { continue } outputPresent[callID] = struct{}{} outputCache.record(sessionKey, callID, item) - case "function_call": + case isResponsesToolCallType(itemType): callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) if callID == "" { continue @@ -293,7 +293,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa continue } itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) - if itemType == "function_call_output" { + if isResponsesToolCallOutputType(itemType) { callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) if callID == "" { // Upstream rejects tool outputs without a call_id; drop it. @@ -325,7 +325,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa // Drop orphaned function_call_output items; upstream rejects transcripts with missing calls. continue } - if itemType != "function_call" { + if !isResponsesToolCallType(itemType) { filtered = append(filtered, item) continue } @@ -376,7 +376,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO return } for _, item := range output.Array() { - if strings.TrimSpace(item.Get("type").String()) != "function_call" { + if !isResponsesToolCallType(item.Get("type").String()) { continue } callID := strings.TrimSpace(item.Get("call_id").String()) @@ -390,7 +390,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO if !item.Exists() || !item.IsObject() { return } - if strings.TrimSpace(item.Get("type").String()) != "function_call" { + if !isResponsesToolCallType(item.Get("type").String()) { return } callID := strings.TrimSpace(item.Get("call_id").String()) @@ -400,3 +400,21 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO cache.record(sessionKey, callID, json.RawMessage(item.Raw)) } } + +func isResponsesToolCallType(itemType string) bool { + switch strings.TrimSpace(itemType) { + case "function_call", "custom_tool_call": + return true + default: + return false + } +} + +func isResponsesToolCallOutputType(itemType string) bool { + switch strings.TrimSpace(itemType) { + case "function_call_output", "custom_tool_call_output": + return true + default: + return false + } +}