Align Codex websocket protocol semantics

This commit is contained in:
Kenny
2026-05-03 15:56:39 -07:00
parent 672fdd14ed
commit c19ae1d5be
3 changed files with 310 additions and 34 deletions

View File

@@ -188,7 +188,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body = normalizeCodexInstructions(body)
@@ -776,6 +775,11 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
parsed.Scheme = "ws"
case "https":
parsed.Scheme = "wss"
default:
return "", fmt.Errorf("codex websockets executor: unsupported responses websocket URL scheme %q", parsed.Scheme)
}
if strings.TrimSpace(parsed.Host) == "" {
return "", fmt.Errorf("codex websockets executor: responses websocket URL host is empty")
}
return parsed.String(), nil
}
@@ -809,6 +813,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
setHeaderCasePreserved(headers, "session_id", cache.ID)
headers.Set("Conversation_id", cache.ID)
}
@@ -828,13 +833,19 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ginHeaders = ginCtx.Request.Header.Clone()
}
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
isAPIKey := codexAuthUsesAPIKey(auth)
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
misc.EnsureHeader(headers, ginHeaders, "Version", "")
if isAPIKey {
ensureHeaderWithPriority(headers, ginHeaders, "User-Agent", "", "")
} else {
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
}
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" && ginHeaders != nil {
@@ -845,16 +856,9 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
headers.Set("OpenAI-Beta", betaHeader)
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
}
headers.Del("User-Agent")
isAPIKey := false
if auth != nil && auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
isAPIKey = true
}
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", uuid.NewString())
}
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", "")
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
headers.Set("Originator", originator)
} else if !isAPIKey {
@@ -864,7 +868,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
headers.Set("Chatgpt-Account-Id", trimmed)
headers.Set("ChatGPT-Account-ID", trimmed)
}
}
}
@@ -879,6 +883,77 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
return headers
}
func codexAuthUsesAPIKey(auth *cliproxyauth.Auth) bool {
if auth == nil || auth.Attributes == nil {
return false
}
return strings.TrimSpace(auth.Attributes["api_key"]) != ""
}
func ensureHeaderCasePreserved(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(headerValueCaseInsensitive(target, key)) != "" {
return
}
if source != nil {
if val := strings.TrimSpace(headerValueCaseInsensitive(source, key)); val != "" {
setHeaderCasePreserved(target, key, val)
return
}
}
if val := strings.TrimSpace(configValue); val != "" {
setHeaderCasePreserved(target, key, val)
return
}
if val := strings.TrimSpace(fallbackValue); val != "" {
setHeaderCasePreserved(target, key, val)
}
}
func setHeaderCasePreserved(headers http.Header, key string, value string) {
if headers == nil {
return
}
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
return
}
deleteHeaderCaseInsensitive(headers, key)
headers[key] = []string{value}
}
func headerValueCaseInsensitive(headers http.Header, key string) string {
key = strings.TrimSpace(key)
if headers == nil || key == "" {
return ""
}
if val := strings.TrimSpace(headers.Get(key)); val != "" {
return val
}
for existingKey, values := range headers {
if !strings.EqualFold(existingKey, key) {
continue
}
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
}
return ""
}
func deleteHeaderCaseInsensitive(headers http.Header, key string) {
for existingKey := range headers {
if strings.EqualFold(existingKey, key) {
delete(headers, existingKey)
}
}
}
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
if cfg == nil || auth == nil {
return "", ""
@@ -962,25 +1037,53 @@ func parseCodexWebsocketError(payload []byte) (error, bool) {
return nil, false
}
out := []byte(`{}`)
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
raw := errNode.Raw
if errNode.Type == gjson.String {
raw = errNode.Raw
}
out, _ = sjson.SetRawBytes(out, "error", []byte(raw))
} else {
out, _ = sjson.SetBytes(out, "error.type", "server_error")
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
}
out := buildCodexWebsocketErrorPayload(payload, status)
headers := parseCodexWebsocketErrorHeaders(payload)
statusError := statusErr{code: status, msg: string(out)}
if isCodexWebsocketConnectionLimitError(payload) {
retryAfter := time.Duration(0)
statusError.retryAfter = &retryAfter
}
return statusErrWithHeaders{
statusErr: statusErr{code: status, msg: string(out)},
statusErr: statusError,
headers: headers,
}, true
}
func buildCodexWebsocketErrorPayload(payload []byte, status int) []byte {
out := []byte(`{}`)
out, _ = sjson.SetBytes(out, "status", status)
if bodyNode := gjson.GetBytes(payload, "body"); bodyNode.Exists() {
out, _ = sjson.SetRawBytes(out, "body", []byte(bodyNode.Raw))
if bodyErrorNode := bodyNode.Get("error"); bodyErrorNode.Exists() {
out, _ = sjson.SetRawBytes(out, "error", []byte(bodyErrorNode.Raw))
return out
}
}
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw))
return out
}
out, _ = sjson.SetBytes(out, "error.type", "server_error")
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
return out
}
func isCodexWebsocketConnectionLimitError(payload []byte) bool {
if len(payload) == 0 {
return false
}
for _, path := range []string{"error.code", "error.type", "body.error.code", "body.error.type", "code", "error"} {
if strings.TrimSpace(gjson.GetBytes(payload, path).String()) == "websocket_connection_limit_reached" {
return true
}
}
return false
}
func parseCodexWebsocketErrorHeaders(payload []byte) http.Header {
headersNode := gjson.GetBytes(payload, "headers")
if !headersNode.Exists() || !headersNode.IsObject() {