mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-05-23 05:00:13 +08:00
Merge branch 'main' into plus
This commit is contained in:
@@ -190,10 +190,11 @@ type BaseAPIHandler struct {
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
h := &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// UpdateClients updates the handlers' client list and configuration.
|
||||
|
||||
37
sdk/api/handlers/openai/endpoint_compat.go
Normal file
37
sdk/api/handlers/openai/endpoint_compat.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package openai
|
||||
|
||||
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
|
||||
const (
|
||||
openAIChatEndpoint = "/chat/completions"
|
||||
openAIResponsesEndpoint = "/responses"
|
||||
)
|
||||
|
||||
func resolveEndpointOverride(modelName, requestedEndpoint string) (string, bool) {
|
||||
if modelName == "" {
|
||||
return "", false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(modelName)
|
||||
if info == nil || len(info.SupportedEndpoints) == 0 {
|
||||
return "", false
|
||||
}
|
||||
if endpointListContains(info.SupportedEndpoints, requestedEndpoint) {
|
||||
return "", false
|
||||
}
|
||||
if requestedEndpoint == openAIChatEndpoint && endpointListContains(info.SupportedEndpoints, openAIResponsesEndpoint) {
|
||||
return openAIResponsesEndpoint, true
|
||||
}
|
||||
if requestedEndpoint == openAIResponsesEndpoint && endpointListContains(info.SupportedEndpoints, openAIChatEndpoint) {
|
||||
return openAIChatEndpoint, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func endpointListContains(items []string, value string) bool {
|
||||
for _, item := range items {
|
||||
if item == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
codexconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/openai/chat-completions"
|
||||
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -112,6 +113,23 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
stream := streamResult.Type == gjson.True
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIChatEndpoint); ok && overrideEndpoint == openAIResponsesEndpoint {
|
||||
originalChat := rawJSON
|
||||
if shouldTreatAsResponsesFormat(rawJSON) {
|
||||
// Already responses-style payload; no conversion needed.
|
||||
} else {
|
||||
rawJSON = codexconverter.ConvertOpenAIRequestToCodex(modelName, rawJSON, stream)
|
||||
}
|
||||
stream = gjson.GetBytes(rawJSON, "stream").Bool()
|
||||
if stream {
|
||||
h.handleStreamingResponseViaResponses(c, rawJSON, originalChat)
|
||||
} else {
|
||||
h.handleNonStreamingResponseViaResponses(c, rawJSON, originalChat)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Some clients send OpenAI Responses-format payloads to /v1/chat/completions.
|
||||
// Convert them to Chat Completions so downstream translators preserve tool metadata.
|
||||
if shouldTreatAsResponsesFormat(rawJSON) {
|
||||
@@ -245,6 +263,76 @@ func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
func convertResponsesObjectToChatCompletion(ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, responsesPayload []byte) []byte {
|
||||
if len(responsesPayload) == 0 {
|
||||
return nil
|
||||
}
|
||||
wrapped := wrapResponsesPayloadAsCompleted(responsesPayload)
|
||||
if len(wrapped) == 0 {
|
||||
return nil
|
||||
}
|
||||
var param any
|
||||
converted := codexconverter.ConvertCodexResponseToOpenAINonStream(ctx, modelName, originalChatJSON, responsesRequestJSON, wrapped, ¶m)
|
||||
if converted == "" {
|
||||
return nil
|
||||
}
|
||||
return []byte(converted)
|
||||
}
|
||||
|
||||
func wrapResponsesPayloadAsCompleted(payload []byte) []byte {
|
||||
if gjson.GetBytes(payload, "type").Exists() {
|
||||
return payload
|
||||
}
|
||||
if gjson.GetBytes(payload, "object").String() != "response" {
|
||||
return payload
|
||||
}
|
||||
wrapped := `{"type":"response.completed","response":{}}`
|
||||
wrapped, _ = sjson.SetRaw(wrapped, "response", string(payload))
|
||||
return []byte(wrapped)
|
||||
}
|
||||
|
||||
func writeConvertedResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, chunk []byte, param *any) {
|
||||
outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) forwardResponsesAsChatStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON []byte, param *any) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
|
||||
}
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
|
||||
// This ensures the completions endpoint returns data in the expected format.
|
||||
//
|
||||
@@ -435,6 +523,30 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) handleNonStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
|
||||
if converted == nil {
|
||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to convert response to chat completion format"),
|
||||
})
|
||||
cliCancel(fmt.Errorf("response conversion failed"))
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(converted)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||
// It establishes a streaming connection with the backend service and forwards
|
||||
// the response chunks to the client in real-time using Server-Sent Events.
|
||||
@@ -509,6 +621,67 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||
var param any
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek for first usable chunk
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
setSSEHeaders()
|
||||
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
|
||||
flusher.Flush()
|
||||
|
||||
h.forwardResponsesAsChatStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalChatJSON, rawJSON, ¶m)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
||||
// It converts completions request to chat completions format, sends to backend,
|
||||
// then converts the response back to completions format before sending to client.
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -83,7 +84,21 @@ func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
stream := streamResult.Type == gjson.True
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIResponsesEndpoint); ok && overrideEndpoint == openAIChatEndpoint {
|
||||
chatJSON := responsesconverter.ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName, rawJSON, stream)
|
||||
stream = gjson.GetBytes(chatJSON, "stream").Bool()
|
||||
if stream {
|
||||
h.handleStreamingResponseViaChat(c, rawJSON, chatJSON)
|
||||
} else {
|
||||
h.handleNonStreamingResponseViaChat(c, rawJSON, chatJSON)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
h.handleStreamingResponse(c, rawJSON)
|
||||
} else {
|
||||
h.handleNonStreamingResponse(c, rawJSON)
|
||||
@@ -116,6 +131,31 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
var param any
|
||||
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
|
||||
if converted == "" {
|
||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to convert chat completion response to responses format"),
|
||||
})
|
||||
cliCancel(fmt.Errorf("response conversion failed"))
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte(converted))
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||
// It establishes a streaming connection with the backend service and forwards
|
||||
// the response chunks to the client in real-time using Server-Sent Events.
|
||||
@@ -196,6 +236,116 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||
var param any
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
setSSEHeaders()
|
||||
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
|
||||
flusher.Flush()
|
||||
|
||||
h.forwardChatAsResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalResponsesJSON, ¶m)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeChatAsResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalResponsesJSON, chunk []byte, param *any) {
|
||||
outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix([]byte(out), []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte(out))
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalResponsesJSON []byte, param *any) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if out == "" {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix([]byte(out), []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte(out))
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
|
||||
129
sdk/auth/github_copilot.go
Normal file
129
sdk/auth/github_copilot.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot.
|
||||
type GitHubCopilotAuthenticator struct{}
|
||||
|
||||
// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator.
|
||||
func NewGitHubCopilotAuthenticator() Authenticator {
|
||||
return &GitHubCopilotAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for github-copilot.
|
||||
func (GitHubCopilotAuthenticator) Provider() string {
|
||||
return "github-copilot"
|
||||
}
|
||||
|
||||
// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense.
|
||||
// The token remains valid until the user revokes it or the Copilot subscription expires.
|
||||
func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login initiates the GitHub device flow authentication for Copilot access.
|
||||
func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
authSvc := copilot.NewCopilotAuth(cfg)
|
||||
|
||||
// Start the device flow
|
||||
fmt.Println("Starting GitHub Copilot authentication...")
|
||||
deviceCode, err := authSvc.StartDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err)
|
||||
}
|
||||
|
||||
// Display the user code and verification URL
|
||||
fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI)
|
||||
fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode)
|
||||
|
||||
// Try to open the browser automatically
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for GitHub authorization...")
|
||||
fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
|
||||
|
||||
// Wait for user authorization
|
||||
authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
|
||||
if err != nil {
|
||||
errMsg := copilot.GetUserFriendlyMessage(err)
|
||||
return nil, fmt.Errorf("github-copilot: %s", errMsg)
|
||||
}
|
||||
|
||||
// Verify the token can get a Copilot API token
|
||||
fmt.Println("Verifying Copilot access...")
|
||||
apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err)
|
||||
}
|
||||
|
||||
// Create the token storage
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
// Build metadata with token information for the executor
|
||||
metadata := map[string]any{
|
||||
"type": "github-copilot",
|
||||
"username": authBundle.Username,
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if apiToken.ExpiresAt > 0 {
|
||||
metadata["api_token_expires_at"] = apiToken.ExpiresAt
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
||||
|
||||
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: authBundle.Username,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshGitHubCopilotToken validates and returns the current token status.
|
||||
// GitHub OAuth tokens don't need traditional refresh - we just validate they still work.
|
||||
func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error {
|
||||
if storage == nil || storage.AccessToken == "" {
|
||||
return fmt.Errorf("no token available")
|
||||
}
|
||||
|
||||
authSvc := copilot.NewCopilotAuth(cfg)
|
||||
|
||||
// Validate the token can still get a Copilot API token
|
||||
_, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("token validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
470
sdk/auth/kiro.go
Normal file
470
sdk/auth/kiro.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// extractKiroIdentifier extracts a meaningful identifier for file naming.
|
||||
// Returns account name if provided, otherwise profile ARN ID.
|
||||
// All extracted values are sanitized to prevent path injection attacks.
|
||||
func extractKiroIdentifier(accountName, profileArn string) string {
|
||||
// Priority 1: Use account name if provided
|
||||
if accountName != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(accountName)
|
||||
}
|
||||
|
||||
// Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
|
||||
if profileArn != "" {
|
||||
parts := strings.Split(profileArn, "/")
|
||||
if len(parts) >= 2 {
|
||||
// Sanitize the ARN component to prevent path traversal
|
||||
return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: timestamp
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
|
||||
type KiroAuthenticator struct{}
|
||||
|
||||
// NewKiroAuthenticator constructs a Kiro authenticator.
|
||||
func NewKiroAuthenticator() *KiroAuthenticator {
|
||||
return &KiroAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for the authenticator.
|
||||
func (a *KiroAuthenticator) Provider() string {
|
||||
return "kiro"
|
||||
}
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh.
|
||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 5 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
// createAuthRecord creates an auth record from token data.
|
||||
func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
// Determine label based on auth method
|
||||
label := fmt.Sprintf("kiro-%s", source)
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
label = "kiro-idc"
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("%s-%s.json", label, idPart)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific fields if present
|
||||
if tokenData.StartURL != "" {
|
||||
metadata["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
metadata["region"] = tokenData.Region
|
||||
}
|
||||
|
||||
attributes := map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": source,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific attributes if present
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
attributes["source"] = "aws-idc"
|
||||
if tokenData.StartURL != "" {
|
||||
attributes["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
attributes["region"] = tokenData.Region
|
||||
}
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: metadata,
|
||||
Attributes: attributes,
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
|
||||
// This shows a method selection prompt and handles both flows.
|
||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
// Use the unified method selection flow (Builder ID or IDC)
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
tokenData, err := ssoClient.LoginWithMethodSelection(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
return a.createAuthRecord(tokenData, "aws")
|
||||
}
|
||||
|
||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use AWS Builder ID authorization code flow
|
||||
tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-aws",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "aws-builder-id-authcode",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use Google OAuth flow with protocol handler
|
||||
tokenData, err := oauth.LoginWithGoogle(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-google-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-google",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "google-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro Google authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use GitHub OAuth flow with protocol handler
|
||||
tokenData, err := oauth.LoginWithGitHub(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-github-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-github",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "github-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro GitHub authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// ImportFromKiroIDE imports token from Kiro IDE's token file.
|
||||
func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
|
||||
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract email from JWT if not already set (for imported tokens)
|
||||
if tokenData.Email == "" {
|
||||
tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
// Sanitize provider to prevent path traversal (defense-in-depth)
|
||||
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
|
||||
if provider == "" {
|
||||
provider = "imported" // Fallback for legacy tokens without provider
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: fmt.Sprintf("kiro-%s", provider),
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
// Display the email if extracted
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
|
||||
} else {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
|
||||
func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil, fmt.Errorf("invalid auth record")
|
||||
}
|
||||
|
||||
refreshToken, ok := auth.Metadata["refresh_token"].(string)
|
||||
if !ok || refreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh token not found")
|
||||
}
|
||||
|
||||
clientID, _ := auth.Metadata["client_id"].(string)
|
||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||
startURL, _ := auth.Metadata["start_url"].(string)
|
||||
region, _ := auth.Metadata["region"].(string)
|
||||
|
||||
var tokenData *kiroauth.KiroTokenData
|
||||
var err error
|
||||
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
|
||||
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||
switch {
|
||||
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||
// Builder ID refresh with default endpoint
|
||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||
default:
|
||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Clone auth to avoid mutating the input parameter
|
||||
updated := auth.Clone()
|
||||
now := time.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.LastRefreshedAt = now
|
||||
updated.Metadata["access_token"] = tokenData.AccessToken
|
||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
@@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config
|
||||
}
|
||||
return record, savedPath, nil
|
||||
}
|
||||
|
||||
// SaveAuth persists an auth record directly without going through the login flow.
|
||||
func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) {
|
||||
if m.store == nil {
|
||||
return "", fmt.Errorf("no store configured")
|
||||
}
|
||||
if cfg != nil {
|
||||
if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
|
||||
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
}
|
||||
return m.store.Save(context.Background(), record)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ func init() {
|
||||
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -49,7 +49,7 @@ type RefreshEvaluator interface {
|
||||
const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
refreshFailureBackoff = 1 * time.Minute
|
||||
quotaBackoffBase = time.Second
|
||||
quotaBackoffMax = 30 * time.Minute
|
||||
)
|
||||
@@ -2157,7 +2157,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
updated.Runtime = auth.Runtime
|
||||
}
|
||||
updated.LastRefreshedAt = now
|
||||
updated.NextRefreshAfter = time.Time{}
|
||||
// Preserve NextRefreshAfter set by the Authenticator
|
||||
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||
updated.LastError = nil
|
||||
updated.UpdatedAt = now
|
||||
_, _ = m.Update(ctx, updated)
|
||||
|
||||
@@ -221,7 +221,7 @@ func modelAliasChannel(auth *Auth) string {
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model alias (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
@@ -245,7 +245,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
|
||||
@@ -43,6 +43,15 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
||||
input: "gemini-2.5-pro",
|
||||
want: "gemini-2.5-pro-exp-03-25",
|
||||
},
|
||||
{
|
||||
name: "kiro alias resolves",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"kiro": {{Name: "kiro-claude-sonnet-4-5", Alias: "sonnet"}},
|
||||
},
|
||||
channel: "kiro",
|
||||
input: "sonnet",
|
||||
want: "kiro-claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "config suffix takes priority",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
@@ -152,6 +161,8 @@ func createAuthForChannel(channel string) *Auth {
|
||||
return &Auth{Provider: "qwen"}
|
||||
case "iflow":
|
||||
return &Auth{Provider: "iflow"}
|
||||
case "kiro":
|
||||
return &Auth{Provider: "kiro"}
|
||||
default:
|
||||
return &Auth{Provider: channel}
|
||||
}
|
||||
|
||||
@@ -227,6 +227,18 @@ func (a *Auth) AccountInfo() (string, string) {
|
||||
}
|
||||
}
|
||||
|
||||
// For GitHub provider, return username
|
||||
if strings.ToLower(a.Provider) == "github" {
|
||||
if a.Metadata != nil {
|
||||
if username, ok := a.Metadata["username"].(string); ok {
|
||||
username = strings.TrimSpace(username)
|
||||
if username != "" {
|
||||
return "oauth", username
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check metadata for email first (OAuth-style auth)
|
||||
if a.Metadata != nil {
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
|
||||
@@ -379,6 +379,10 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||
case "iflow":
|
||||
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
||||
case "kiro":
|
||||
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
||||
case "github-copilot":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||
default:
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
if providerKey == "" {
|
||||
@@ -767,6 +771,12 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "iflow":
|
||||
models = registry.GetIFlowModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "github-copilot":
|
||||
models = registry.GetGitHubCopilotModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "kiro":
|
||||
models = registry.GetKiroModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
|
||||
Reference in New Issue
Block a user