mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-21 18:12:54 +08:00
Merge pull request #3847 from router-for-me/home
Fix credential handling in home models
This commit is contained in:
@@ -1157,7 +1157,7 @@ func (s *Server) loadHomeModelEntries(c *gin.Context) ([]homeModelEntry, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
raw, errGet := client.GetModels(c.Request.Context())
|
||||
raw, errGet := client.GetModels(c.Request.Context(), c.Request.Header, c.Request.URL.Query())
|
||||
if errGet != nil {
|
||||
c.JSON(http.StatusBadGateway, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
@@ -1168,6 +1168,16 @@ func (s *Server) loadHomeModelEntries(c *gin.Context) ([]homeModelEntry, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if statusCode, ok := homeModelsAuthStatus(raw); ok {
|
||||
c.JSON(statusCode, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: homeModelsErrorMessage(raw),
|
||||
Type: "authentication_error",
|
||||
},
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
|
||||
entries, errDecode := decodeHomeModels(raw)
|
||||
if errDecode != nil {
|
||||
c.JSON(http.StatusBadGateway, handlers.ErrorResponse{
|
||||
@@ -1217,6 +1227,70 @@ func homeGeminiModelMatches(entry homeModelEntry, action string) bool {
|
||||
return action == id || action == "models/"+id || normalizedAction == normalizedID
|
||||
}
|
||||
|
||||
// homeModelsAuthStatus inspects a home models response for an authentication/error envelope.
|
||||
// It returns the HTTP status code to surface (401 for credential issues, 502 otherwise)
|
||||
// and true when the payload is an error response rather than model data.
|
||||
func homeModelsAuthStatus(raw []byte) (int, bool) {
|
||||
errType := homeModelsErrorType(raw)
|
||||
if errType == "" {
|
||||
return 0, false
|
||||
}
|
||||
if errType == "no_credentials" || errType == "invalid_credential" {
|
||||
return http.StatusUnauthorized, true
|
||||
}
|
||||
return http.StatusBadGateway, true
|
||||
}
|
||||
|
||||
func homeModelsErrorType(raw []byte) string {
|
||||
top, ok := unmarshalHomeModelsTopLevel(raw)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
rawErr, exists := top["error"]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
var errObj struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(rawErr, &errObj); errUnmarshal != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(errObj.Type)
|
||||
}
|
||||
|
||||
func homeModelsErrorMessage(raw []byte) string {
|
||||
top, ok := unmarshalHomeModelsTopLevel(raw)
|
||||
if !ok {
|
||||
return "home models request failed"
|
||||
}
|
||||
rawErr, exists := top["error"]
|
||||
if !exists {
|
||||
return "home models request failed"
|
||||
}
|
||||
var errObj struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(rawErr, &errObj); errUnmarshal != nil {
|
||||
return "home models request failed"
|
||||
}
|
||||
if msg := strings.TrimSpace(errObj.Message); msg != "" {
|
||||
return msg
|
||||
}
|
||||
return "home models request failed"
|
||||
}
|
||||
|
||||
func unmarshalHomeModelsTopLevel(raw []byte) (map[string]json.RawMessage, bool) {
|
||||
if len(raw) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
var top map[string]json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal(raw, &top); errUnmarshal != nil {
|
||||
return nil, false
|
||||
}
|
||||
return top, true
|
||||
}
|
||||
|
||||
func decodeHomeModels(raw []byte) ([]homeModelEntry, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, fmt.Errorf("home models payload is empty")
|
||||
|
||||
@@ -551,3 +551,39 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHomeModelsAuthStatus(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw string
|
||||
wantStatus int
|
||||
wantHandled bool
|
||||
}{
|
||||
{"no credentials", `{"error":{"type":"no_credentials","message":"Missing API key"}}`, http.StatusUnauthorized, true},
|
||||
{"invalid credential", `{"error":{"type":"invalid_credential","message":"Invalid API key"}}`, http.StatusUnauthorized, true},
|
||||
{"internal error maps to bad gateway", `{"error":{"type":"internal_error","message":"boom"}}`, http.StatusBadGateway, true},
|
||||
{"models payload not an error", `{"openai":[{"id":"gpt-5.5"}]}`, 0, false},
|
||||
{"empty payload not an error", `{}`, 0, false},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
status, handled := homeModelsAuthStatus([]byte(tc.raw))
|
||||
if handled != tc.wantHandled {
|
||||
t.Fatalf("handled = %v, want %v (status=%d)", handled, tc.wantHandled, status)
|
||||
}
|
||||
if handled && status != tc.wantStatus {
|
||||
t.Fatalf("status = %d, want %d", status, tc.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHomeModelsErrorMessage(t *testing.T) {
|
||||
if msg := homeModelsErrorMessage([]byte(`{"error":{"type":"invalid_credential","message":"Invalid API key"}}`)); msg != "Invalid API key" {
|
||||
t.Fatalf("message = %q, want %q", msg, "Invalid API key")
|
||||
}
|
||||
if msg := homeModelsErrorMessage([]byte(`{"openai":[]}`)); msg != "home models request failed" {
|
||||
t.Fatalf("default message = %q, want fallback", msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -25,7 +26,6 @@ import (
|
||||
const (
|
||||
redisKeyConfig = "config"
|
||||
redisChannelConfig = "config"
|
||||
redisKeyModels = "models"
|
||||
redisKeyUsage = "usage"
|
||||
redisKeyRequestLog = "request-log"
|
||||
redisKeyAppLog = "app-log"
|
||||
@@ -520,12 +520,21 @@ func (c *Client) GetConfig(ctx context.Context) ([]byte, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetModels(ctx context.Context) ([]byte, error) {
|
||||
func (c *Client) GetModels(ctx context.Context, headers http.Header, query url.Values) ([]byte, error) {
|
||||
cmd, errClient := c.commandClient()
|
||||
if errClient != nil {
|
||||
return nil, errClient
|
||||
}
|
||||
raw, err := cmd.Get(ctx, redisKeyModels).Bytes()
|
||||
req := modelsRequest{
|
||||
Type: "models",
|
||||
Headers: headersToLowerMap(headers),
|
||||
Query: queryToLowerMap(query),
|
||||
}
|
||||
keyBytes, err := json.Marshal(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, err := cmd.Get(ctx, string(keyBytes)).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrModelsNotFound
|
||||
}
|
||||
@@ -745,6 +754,32 @@ func headersToLowerMap(headers http.Header) map[string]string {
|
||||
return out
|
||||
}
|
||||
|
||||
func queryToLowerMap(query url.Values) map[string]string {
|
||||
if len(query) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(query))
|
||||
for key, values := range query {
|
||||
k := strings.ToLower(strings.TrimSpace(key))
|
||||
if k == "" {
|
||||
continue
|
||||
}
|
||||
if len(values) == 0 {
|
||||
out[k] = ""
|
||||
continue
|
||||
}
|
||||
trimmed := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
trimmed = append(trimmed, strings.TrimSpace(v))
|
||||
}
|
||||
out[k] = strings.Join(trimmed, ", ")
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func newAuthDispatchRequest(requestedModel string, sessionID string, headers http.Header, count int) authDispatchRequest {
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -399,3 +400,75 @@ func readRedisCommand(reader *bufio.Reader) ([]string, error) {
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func TestModelsRequestSerializationCarriesCredentials(t *testing.T) {
|
||||
req := modelsRequest{
|
||||
Type: "models",
|
||||
Headers: headersToLowerMap(http.Header{"Authorization": {"Bearer test-key"}}),
|
||||
Query: queryToLowerMap(url.Values{"key": {"gemini-key"}}),
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(&req)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal models request: %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
t.Fatalf("unmarshal models request: %v", err)
|
||||
}
|
||||
if payload["type"] != "models" {
|
||||
t.Fatalf("type = %v, want models", payload["type"])
|
||||
}
|
||||
headers, ok := payload["headers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("headers missing or wrong type: %v", payload["headers"])
|
||||
}
|
||||
if headers["authorization"] != "Bearer test-key" {
|
||||
t.Fatalf("headers.authorization = %v, want Bearer test-key", headers["authorization"])
|
||||
}
|
||||
query, ok := payload["query"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("query missing or wrong type: %v", payload["query"])
|
||||
}
|
||||
if query["key"] != "gemini-key" {
|
||||
t.Fatalf("query.key = %v, want gemini-key", query["key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelsRequestOmitsEmptyCredentials(t *testing.T) {
|
||||
req := modelsRequest{Type: "models"}
|
||||
|
||||
raw, err := json.Marshal(&req)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal models request: %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
t.Fatalf("unmarshal models request: %v", err)
|
||||
}
|
||||
if _, exists := payload["headers"]; exists {
|
||||
t.Fatalf("headers should be omitted when empty, got %v", payload["headers"])
|
||||
}
|
||||
if _, exists := payload["query"]; exists {
|
||||
t.Fatalf("query should be omitted when empty, got %v", payload["query"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryToLowerMap(t *testing.T) {
|
||||
got := queryToLowerMap(url.Values{
|
||||
"Key": {"v1", "v2"},
|
||||
"Token": {"abc"},
|
||||
})
|
||||
if got["key"] != "v1, v2" {
|
||||
t.Fatalf("key = %q, want %q", got["key"], "v1, v2")
|
||||
}
|
||||
if got["token"] != "abc" {
|
||||
t.Fatalf("token = %q, want %q", got["token"], "abc")
|
||||
}
|
||||
|
||||
if nilMap := queryToLowerMap(nil); nilMap != nil {
|
||||
t.Fatalf("queryToLowerMap(nil) = %v, want nil", nilMap)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,12 @@ type authDispatchRequest struct {
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
type modelsRequest struct {
|
||||
Type string `json:"type"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
Query map[string]string `json:"query,omitempty"`
|
||||
}
|
||||
|
||||
type refreshRequest struct {
|
||||
Type string `json:"type"`
|
||||
AuthIndex string `json:"auth_index"`
|
||||
|
||||
@@ -4139,7 +4139,7 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro
|
||||
switch strings.ToLower(code) {
|
||||
case "model_not_found":
|
||||
status = http.StatusNotFound
|
||||
case "authentication_error", "unauthorized":
|
||||
case "authentication_error", "unauthorized", "no_credentials", "invalid_credential":
|
||||
status = http.StatusUnauthorized
|
||||
}
|
||||
return nil, nil, "", &Error{Code: code, Message: msg, HTTPStatus: status}
|
||||
|
||||
Reference in New Issue
Block a user