mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-05 10:41:36 +08:00
fix(openai): repair empty responses stream output
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -45,7 +46,9 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
|
||||
}
|
||||
|
||||
type responsesSSEFramer struct {
|
||||
pending []byte
|
||||
pending []byte
|
||||
outputItems map[int][]byte
|
||||
outputOrder []int
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
|
||||
@@ -61,7 +64,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
|
||||
if frameLen == 0 {
|
||||
break
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending[:frameLen])
|
||||
f.writeFrame(w, f.pending[:frameLen])
|
||||
copy(f.pending, f.pending[frameLen:])
|
||||
f.pending = f.pending[:len(f.pending)-frameLen]
|
||||
}
|
||||
@@ -72,7 +75,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
|
||||
if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) {
|
||||
return
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending)
|
||||
f.writeFrame(w, f.pending)
|
||||
f.pending = f.pending[:0]
|
||||
}
|
||||
|
||||
@@ -88,10 +91,121 @@ func (f *responsesSSEFramer) Flush(w io.Writer) {
|
||||
f.pending = f.pending[:0]
|
||||
return
|
||||
}
|
||||
writeResponsesSSEChunk(w, f.pending)
|
||||
f.writeFrame(w, f.pending)
|
||||
f.pending = f.pending[:0]
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) writeFrame(w io.Writer, frame []byte) {
|
||||
writeResponsesSSEChunk(w, f.repairFrame(frame))
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) repairFrame(frame []byte) []byte {
|
||||
payload, ok := responsesSSEDataPayload(frame)
|
||||
if !ok || len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) || !json.Valid(payload) {
|
||||
return frame
|
||||
}
|
||||
|
||||
switch gjson.GetBytes(payload, "type").String() {
|
||||
case "response.output_item.done":
|
||||
f.recordOutputItem(payload)
|
||||
case "response.completed":
|
||||
repaired := f.repairCompletedPayload(payload)
|
||||
if !bytes.Equal(repaired, payload) {
|
||||
return responsesSSEFrameWithData(frame, repaired)
|
||||
}
|
||||
}
|
||||
return frame
|
||||
}
|
||||
|
||||
func responsesSSEDataPayload(frame []byte) ([]byte, bool) {
|
||||
var payload []byte
|
||||
found := false
|
||||
for _, line := range bytes.Split(frame, []byte("\n")) {
|
||||
line = bytes.TrimRight(line, "\r")
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if !bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(trimmed[len("data:"):])
|
||||
if found {
|
||||
payload = append(payload, '\n')
|
||||
}
|
||||
payload = append(payload, data...)
|
||||
found = true
|
||||
}
|
||||
return payload, found
|
||||
}
|
||||
|
||||
func responsesSSEFrameWithData(frame, payload []byte) []byte {
|
||||
var out bytes.Buffer
|
||||
for _, line := range bytes.Split(frame, []byte("\n")) {
|
||||
line = bytes.TrimRight(line, "\r")
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 || bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
out.Write(line)
|
||||
out.WriteByte('\n')
|
||||
}
|
||||
out.WriteString("data: ")
|
||||
out.Write(payload)
|
||||
out.WriteString("\n\n")
|
||||
return out.Bytes()
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) recordOutputItem(payload []byte) {
|
||||
item := gjson.GetBytes(payload, "item")
|
||||
if !item.Exists() || !item.IsObject() || item.Get("type").String() == "" {
|
||||
return
|
||||
}
|
||||
|
||||
index := len(f.outputOrder)
|
||||
if outputIndex := gjson.GetBytes(payload, "output_index"); outputIndex.Exists() {
|
||||
index = int(outputIndex.Int())
|
||||
}
|
||||
if f.outputItems == nil {
|
||||
f.outputItems = make(map[int][]byte)
|
||||
}
|
||||
if _, exists := f.outputItems[index]; !exists {
|
||||
f.outputOrder = append(f.outputOrder, index)
|
||||
}
|
||||
f.outputItems[index] = append([]byte(nil), item.Raw...)
|
||||
}
|
||||
|
||||
func (f *responsesSSEFramer) repairCompletedPayload(payload []byte) []byte {
|
||||
if len(f.outputOrder) == 0 {
|
||||
return payload
|
||||
}
|
||||
output := gjson.GetBytes(payload, "response.output")
|
||||
if output.Exists() && (!output.IsArray() || len(output.Array()) > 0) {
|
||||
return payload
|
||||
}
|
||||
|
||||
var outputJSON bytes.Buffer
|
||||
outputJSON.WriteByte('[')
|
||||
indexes := append([]int(nil), f.outputOrder...)
|
||||
sort.Ints(indexes)
|
||||
written := 0
|
||||
for _, index := range indexes {
|
||||
item, ok := f.outputItems[index]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if written > 0 {
|
||||
outputJSON.WriteByte(',')
|
||||
}
|
||||
outputJSON.Write(item)
|
||||
written++
|
||||
}
|
||||
outputJSON.WriteByte(']')
|
||||
|
||||
repaired, err := sjson.SetRawBytes(payload, "response.output", outputJSON.Bytes())
|
||||
if err != nil {
|
||||
return payload
|
||||
}
|
||||
return repaired
|
||||
}
|
||||
|
||||
func responsesSSEFrameLen(chunk []byte) int {
|
||||
if len(chunk) == 0 {
|
||||
return 0
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) {
|
||||
@@ -53,12 +54,43 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
|
||||
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
|
||||
}
|
||||
|
||||
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
|
||||
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"function_call\",\"arguments\":\"{}\"}]}}"
|
||||
if parts[1] != expectedPart2 {
|
||||
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamRepairsEmptyCompletedOutputFromDoneItems(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
data := make(chan []byte, 3)
|
||||
errs := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte(`data: {"type":"response.output_item.done","output_index":0,"item":{"type":"reasoning","id":"rs-1","summary":[]}}`)
|
||||
data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{\"cmd\":\"pwd\"}","status":"completed"}}`)
|
||||
data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`)
|
||||
close(data)
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
|
||||
|
||||
parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(parts[2], "data: ")
|
||||
output := gjson.Get(payload, "response.output")
|
||||
if !output.IsArray() || len(output.Array()) != 2 {
|
||||
t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw)
|
||||
}
|
||||
if got := gjson.Get(payload, "response.output.1.name").String(); got != "shell" {
|
||||
t.Fatalf("expected function_call name to be preserved, got %q in %s", got, payload)
|
||||
}
|
||||
if got := gjson.Get(payload, "response.output.1.arguments").String(); got != `{"cmd":"pwd"}` {
|
||||
t.Fatalf("expected function_call arguments to be preserved, got %q in %s", got, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) {
|
||||
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user