mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-06 19:39:48 +08:00
feat: add validation for Claude streaming responses
- Implemented `validateClaudeStreamingResponse` to ensure upstream streaming data integrity. - Added new tests to verify response validation, including empty streams, error events, incomplete streams, and valid streams. - Integrated validation logic into the Claude executor's streaming handler, returning detailed errors for malformed upstream data. Fixed: #2193
This commit is contained in:
@@ -285,6 +285,10 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if stream {
|
||||
if errValidate := validateClaudeStreamingResponse(data); errValidate != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errValidate)
|
||||
return resp, errValidate
|
||||
}
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
for _, line := range lines {
|
||||
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
|
||||
@@ -533,6 +537,64 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func validateClaudeStreamingResponse(data []byte) error {
|
||||
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
|
||||
hasData := false
|
||||
hasMessageStart := false
|
||||
hasMessageDelta := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := bytes.TrimSpace(scanner.Bytes())
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
payload := bytes.TrimSpace(line[len("data:"):])
|
||||
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
hasData = true
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned malformed stream data"}
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
switch root.Get("type").String() {
|
||||
case "error":
|
||||
message := strings.TrimSpace(root.Get("error.message").String())
|
||||
if message == "" {
|
||||
message = strings.TrimSpace(root.Get("error.type").String())
|
||||
}
|
||||
if message == "" {
|
||||
message = "unknown upstream error"
|
||||
}
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned error event: " + message}
|
||||
case "message_start":
|
||||
message := root.Get("message")
|
||||
if strings.TrimSpace(message.Get("id").String()) == "" || strings.TrimSpace(message.Get("model").String()) == "" {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream message_start is missing id or model"}
|
||||
}
|
||||
hasMessageStart = true
|
||||
case "message_delta":
|
||||
hasMessageDelta = true
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
return errScan
|
||||
}
|
||||
if !hasData {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned empty stream response"}
|
||||
}
|
||||
if !hasMessageStart {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response is missing message_start"}
|
||||
}
|
||||
if !hasMessageDelta {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response ended before message completion"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
|
||||
@@ -936,6 +936,113 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsEmptyClaudeStream(t *testing.T) {
|
||||
_, err := executeOpenAIChatCompletionThroughClaude(t, "")
|
||||
if err == nil {
|
||||
t.Fatal("Execute error = nil, want empty stream error")
|
||||
}
|
||||
assertStatusErr(t, err, http.StatusBadGateway)
|
||||
if !strings.Contains(err.Error(), "empty stream response") {
|
||||
t.Fatalf("Execute error = %q, want empty stream response", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsClaudeErrorEvent(t *testing.T) {
|
||||
body := `data: {"type":"error","error":{"type":"overloaded_error","message":"upstream overloaded"}}` + "\n"
|
||||
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
|
||||
if err == nil {
|
||||
t.Fatal("Execute error = nil, want upstream error event")
|
||||
}
|
||||
assertStatusErr(t, err, http.StatusBadGateway)
|
||||
if !strings.Contains(err.Error(), "upstream overloaded") {
|
||||
t.Fatalf("Execute error = %q, want upstream overloaded", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsIncompleteClaudeStream(t *testing.T) {
|
||||
body := strings.Join([]string{
|
||||
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
|
||||
`data: {"type":"message_stop"}`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
|
||||
if err == nil {
|
||||
t.Fatal("Execute error = nil, want incomplete stream error")
|
||||
}
|
||||
assertStatusErr(t, err, http.StatusBadGateway)
|
||||
if !strings.Contains(err.Error(), "ended before message completion") {
|
||||
t.Fatalf("Execute error = %q, want incomplete stream error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamConvertsValidClaudeStream(t *testing.T) {
|
||||
body := strings.Join([]string{
|
||||
`event: message_start`,
|
||||
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
|
||||
`event: content_block_delta`,
|
||||
`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ok"}}`,
|
||||
`event: message_delta`,
|
||||
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":2,"output_tokens":1}}`,
|
||||
`event: message_stop`,
|
||||
`data: {"type":"message_stop"}`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
resp, err := executeOpenAIChatCompletionThroughClaude(t, body)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "id").String(); got != "msg_123" {
|
||||
t.Fatalf("response id = %q, want msg_123; payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-3-5-sonnet-20241022" {
|
||||
t.Fatalf("response model = %q, want claude-3-5-sonnet-20241022", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "ok" {
|
||||
t.Fatalf("response content = %q, want ok", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "usage.total_tokens").Int(); got != 3 {
|
||||
t.Fatalf("usage.total_tokens = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func executeOpenAIChatCompletionThroughClaude(t *testing.T, upstreamBody string) (cliproxyexecutor.Response, error) {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte(upstreamBody))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
return executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
}
|
||||
|
||||
func assertStatusErr(t *testing.T, err error, want int) {
|
||||
t.Helper()
|
||||
|
||||
status, ok := err.(interface{ StatusCode() int })
|
||||
if !ok {
|
||||
t.Fatalf("error %T does not expose StatusCode", err)
|
||||
}
|
||||
if got := status.StatusCode(); got != want {
|
||||
t.Fatalf("StatusCode() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
Reference in New Issue
Block a user