diff --git a/internal/api/server.go b/internal/api/server.go index 812724c27..492061a47 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -409,7 +409,7 @@ func (s *Server) setupRoutes() { { v1beta.GET("/models", s.geminiModelsHandler(geminiHandlers)) v1beta.POST("/models/*action", geminiHandlers.GeminiHandler) - v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler) + v1beta.GET("/models/*action", s.geminiGetHandler(geminiHandlers)) } // Root endpoint @@ -851,6 +851,17 @@ func (s *Server) geminiModelsHandler(geminiHandler *gemini.GeminiAPIHandler) gin } } +func (s *Server) geminiGetHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc { + return func(c *gin.Context) { + if s != nil && s.cfg != nil && s.cfg.Home.Enabled { + s.handleHomeGeminiModel(c) + return + } + + geminiHandler.GeminiGetHandler(c) + } +} + type homeModelEntry struct { id string created int64 @@ -933,6 +944,29 @@ func (s *Server) handleHomeGeminiModels(c *gin.Context) { }) } +func (s *Server) handleHomeGeminiModel(c *gin.Context) { + entries, ok := s.loadHomeModelEntries(c) + if !ok { + return + } + + action := strings.TrimPrefix(c.Param("action"), "/") + action = strings.TrimSpace(action) + for _, entry := range entries { + if homeGeminiModelMatches(entry, action) { + c.JSON(http.StatusOK, formatHomeGeminiModel(entry)) + return + } + } + + c.JSON(http.StatusNotFound, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Not Found", + Type: "not_found", + }, + }) +} + func (s *Server) loadHomeModelEntries(c *gin.Context) ([]homeModelEntry, bool) { if s == nil || c == nil || c.Request == nil { return nil, false @@ -976,24 +1010,38 @@ func (s *Server) loadHomeModelEntries(c *gin.Context) ([]homeModelEntry, bool) { func formatHomeGeminiModels(entries []homeModelEntry) []map[string]any { out := make([]map[string]any, 0, len(entries)) for _, entry := range entries { - name := entry.id - if !strings.HasPrefix(name, "models/") { - name = "models/" + name - } - displayName := entry.displayName - if displayName == "" { - displayName = entry.id - } - out = append(out, map[string]any{ - "name": name, - "displayName": displayName, - "description": displayName, - "supportedGenerationMethods": []string{"generateContent"}, - }) + out = append(out, formatHomeGeminiModel(entry)) } return out } +func formatHomeGeminiModel(entry homeModelEntry) map[string]any { + name := entry.id + if !strings.HasPrefix(name, "models/") { + name = "models/" + name + } + displayName := entry.displayName + if displayName == "" { + displayName = entry.id + } + return map[string]any{ + "name": name, + "displayName": displayName, + "description": displayName, + "supportedGenerationMethods": []string{"generateContent"}, + } +} + +func homeGeminiModelMatches(entry homeModelEntry, action string) bool { + id := strings.TrimSpace(entry.id) + if id == "" || action == "" { + return false + } + normalizedAction := strings.TrimPrefix(action, "models/") + normalizedID := strings.TrimPrefix(id, "models/") + return action == id || action == "models/"+id || normalizedAction == normalizedID +} + func decodeHomeModels(raw []byte) ([]homeModelEntry, error) { if len(raw) == 0 { return nil, fmt.Errorf("home models payload is empty") diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index d44809b0c..fca26a9c2 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -3231,6 +3231,79 @@ func setHomeUserAPIKeyOnGinContext(ctx context.Context, apiKey string) { ginCtx.Set("userApiKey", apiKey) } +func homeDispatchHeaders(ctx context.Context, headers http.Header) http.Header { + apiKey, ok := homeQueryCredentialFromContext(ctx) + if !ok { + return headers + } + out := headers.Clone() + if out == nil { + out = http.Header{} + } + if out.Get("Authorization") != "" || out.Get("X-Goog-Api-Key") != "" || out.Get("X-Api-Key") != "" { + return out + } + out.Set("X-Goog-Api-Key", apiKey) + return out +} + +func homeQueryCredentialFromContext(ctx context.Context) (string, bool) { + if ctx == nil { + return "", false + } + if queryCtx, ok := ctx.Value("gin").(interface{ Query(string) string }); ok && queryCtx != nil { + if apiKey := strings.TrimSpace(queryCtx.Query("key")); apiKey != "" { + return apiKey, true + } + if apiKey := strings.TrimSpace(queryCtx.Query("auth_token")); apiKey != "" { + return apiKey, true + } + } + ginCtx, ok := ctx.Value("gin").(interface{ Get(string) (any, bool) }) + if !ok || ginCtx == nil { + return "", false + } + rawMetadata, ok := ginCtx.Get("accessMetadata") + if !ok { + return "", false + } + source := accessMetadataSource(rawMetadata) + if source != "query-key" && source != "query-auth-token" { + return "", false + } + rawAPIKey, ok := ginCtx.Get("userApiKey") + if !ok { + return "", false + } + apiKey := contextStringValue(rawAPIKey) + if apiKey == "" { + return "", false + } + return apiKey, true +} + +func accessMetadataSource(raw any) string { + switch v := raw.(type) { + case map[string]string: + return strings.TrimSpace(v["source"]) + case map[string]any: + return contextStringValue(v["source"]) + default: + return "" + } +} + +func contextStringValue(raw any) string { + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + func homeExecutionSessionIDFromMetadata(meta map[string]any) string { if len(meta) == 0 { return "" @@ -3352,8 +3425,9 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro requestedModel := requestedModelFromMetadata(opts.Metadata, model) sessionID := ExtractSessionID(opts.Headers, opts.OriginalRequest, opts.Metadata) + dispatchHeaders := homeDispatchHeaders(ctx, opts.Headers) - raw, err := client.RPopAuth(ctx, requestedModel, sessionID, opts.Headers, count) + raw, err := client.RPopAuth(ctx, requestedModel, sessionID, dispatchHeaders, count) if err != nil { return nil, nil, "", &Error{Code: "auth_not_found", Message: err.Error(), HTTPStatus: http.StatusServiceUnavailable} } diff --git a/sdk/cliproxy/auth/home_dispatch_headers_test.go b/sdk/cliproxy/auth/home_dispatch_headers_test.go new file mode 100644 index 000000000..b4aef310d --- /dev/null +++ b/sdk/cliproxy/auth/home_dispatch_headers_test.go @@ -0,0 +1,87 @@ +package auth + +import ( + "context" + "net/http" + "testing" +) + +type homeDispatchTestGinContext struct { + values map[string]any + query map[string]string +} + +func (c homeDispatchTestGinContext) Get(key string) (any, bool) { + v, ok := c.values[key] + return v, ok +} + +func (c homeDispatchTestGinContext) Query(key string) string { + if c.query == nil { + return "" + } + return c.query[key] +} + +func TestHomeDispatchHeadersAddsQueryKeyCredential(t *testing.T) { + ginCtx := homeDispatchTestGinContext{query: map[string]string{"key": "12345"}} + ctx := context.WithValue(context.Background(), "gin", ginCtx) + headers := http.Header{"User-Agent": {"client"}} + + got := homeDispatchHeaders(ctx, headers) + + if got.Get("X-Goog-Api-Key") != "12345" { + t.Fatalf("X-Goog-Api-Key = %q, want %q", got.Get("X-Goog-Api-Key"), "12345") + } + if headers.Get("X-Goog-Api-Key") != "" { + t.Fatalf("original headers were mutated: %v", headers) + } +} + +func TestHomeDispatchHeadersAddsQueryCredentialFromAccessMetadata(t *testing.T) { + ginCtx := homeDispatchTestGinContext{values: map[string]any{ + "accessMetadata": map[string]string{"source": "query-key"}, + "userApiKey": "12345", + }} + ctx := context.WithValue(context.Background(), "gin", ginCtx) + headers := http.Header{"User-Agent": {"client"}} + + got := homeDispatchHeaders(ctx, headers) + + if got.Get("X-Goog-Api-Key") != "12345" { + t.Fatalf("X-Goog-Api-Key = %q, want %q", got.Get("X-Goog-Api-Key"), "12345") + } + if headers.Get("X-Goog-Api-Key") != "" { + t.Fatalf("original headers were mutated: %v", headers) + } +} + +func TestHomeDispatchHeadersKeepsExistingCredentialHeader(t *testing.T) { + ginCtx := homeDispatchTestGinContext{query: map[string]string{"key": "query-key"}} + ctx := context.WithValue(context.Background(), "gin", ginCtx) + headers := http.Header{"X-Goog-Api-Key": {"header-key"}} + + got := homeDispatchHeaders(ctx, headers) + + if got.Get("X-Goog-Api-Key") != "header-key" { + t.Fatalf("X-Goog-Api-Key = %q, want %q", got.Get("X-Goog-Api-Key"), "header-key") + } +} + +func TestHomeDispatchHeadersIgnoresHeaderCredentialSource(t *testing.T) { + ginCtx := homeDispatchTestGinContext{values: map[string]any{ + "accessMetadata": map[string]string{"source": "authorization"}, + "userApiKey": "12345", + }} + ctx := context.WithValue(context.Background(), "gin", ginCtx) + headers := http.Header{"Authorization": {"Bearer 12345"}} + + got := homeDispatchHeaders(ctx, headers) + + if got.Get("X-Goog-Api-Key") != "" { + t.Fatalf("X-Goog-Api-Key = %q, want empty", got.Get("X-Goog-Api-Key")) + } + if got.Get("Authorization") != "Bearer 12345" { + t.Fatalf("Authorization = %q, want %q", got.Get("Authorization"), "Bearer 12345") + } +}