diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 41a4fde40..490424316 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -36,6 +36,33 @@ import ( "golang.org/x/oauth2/google" ) +var ( + oauthStatus = make(map[string]string) + oauthStatusMutex sync.RWMutex +) + +// getOAuthStatus safely retrieves an OAuth status +func getOAuthStatus(key string) (string, bool) { + oauthStatusMutex.RLock() + defer oauthStatusMutex.RUnlock() + status, ok := oauthStatus[key] + return status, ok +} + +// setOAuthStatus safely sets an OAuth status +func setOAuthStatus(key string, status string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + oauthStatus[key] = status +} + +// deleteOAuthStatus safely deletes an OAuth status +func deleteOAuthStatus(key string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + delete(oauthStatus, key) +} + var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( @@ -197,19 +224,6 @@ func stopCallbackForwarder(port int) { stopForwarderInstance(port, forwarder) } -func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) { - if forwarder == nil { - return - } - callbackForwardersMu.Lock() - if current := callbackForwarders[port]; current == forwarder { - delete(callbackForwarders, port) - } - callbackForwardersMu.Unlock() - - stopForwarderInstance(port, forwarder) -} - func stopForwarderInstance(port int, forwarder *callbackForwarder) { if forwarder == nil || forwarder.server == nil { return @@ -795,10 +809,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { return } - RegisterOAuthSession(state, "anthropic") - isWebUI := isWebUIRequest(c) - var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") if errTarget != nil { @@ -806,8 +817,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - var errStart error - if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { + if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start anthropic callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -816,39 +826,40 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder) + defer stopCallbackForwarder(anthropicCallbackPort) } // Helper: wait for callback file waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state)) - waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { - deadline := time.Now().Add(timeout) + waitForFile := func(ctx context.Context, path string, timeout time.Duration) (map[string]string, error) { + timer := time.NewTimer(timeout) + ticker := time.NewTicker(500 * time.Millisecond) + defer timer.Stop() + defer ticker.Stop() for { - if !IsOAuthSessionPending(state, "anthropic") { - return nil, errOAuthSessionNotPending - } - if time.Now().After(deadline) { - SetOAuthSessionError(state, "Timeout waiting for OAuth callback") + select { + case <-ctx.Done(): + setOAuthStatus(state, "OAuth flow canceled") + return nil, ctx.Err() + case <-timer.C: + setOAuthStatus(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") + case <-ticker.C: + data, errRead := os.ReadFile(path) + if errRead == nil { + var m map[string]string + _ = json.Unmarshal(data, &m) + _ = os.Remove(path) + return m, nil + } } - data, errRead := os.ReadFile(path) - if errRead == nil { - var m map[string]string - _ = json.Unmarshal(data, &m) - _ = os.Remove(path) - return m, nil - } - time.Sleep(500 * time.Millisecond) } } fmt.Println("Waiting for authentication callback...") // Wait up to 5 minutes - resultMap, errWait := waitForFile(waitFile, 5*time.Minute) + resultMap, errWait := waitForFile(ctx, waitFile, 5*time.Minute) if errWait != nil { - if errors.Is(errWait, errOAuthSessionNotPending) { - return - } authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait) log.Error(claude.GetUserFriendlyMessage(authErr)) return @@ -856,13 +867,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errStr := resultMap["error"]; errStr != "" { oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(claude.GetUserFriendlyMessage(oauthErr)) - SetOAuthSessionError(state, "Bad request") + setOAuthStatus(state, "Bad request") return } if resultMap["state"] != state { authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) log.Error(claude.GetUserFriendlyMessage(authErr)) - SetOAuthSessionError(state, "State code error") + setOAuthStatus(state, "State code error") return } @@ -895,7 +906,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errDo != nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) - SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") + setOAuthStatus(state, "Failed to exchange authorization code for tokens") return } defer func() { @@ -906,7 +917,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) + setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) return } var tResp struct { @@ -919,7 +930,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } if errU := json.Unmarshal(respBody, &tResp); errU != nil { log.Errorf("failed to parse token response: %v", errU) - SetOAuthSessionError(state, "Failed to parse token response") + setOAuthStatus(state, "Failed to parse token response") return } bundle := &claude.ClaudeAuthBundle{ @@ -944,7 +955,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - SetOAuthSessionError(state, "Failed to save authentication tokens") + setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -953,10 +964,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Claude services through this CLI") - CompleteOAuthSession(state) - CompleteOAuthSessionsByProvider("anthropic") + deleteOAuthStatus(state) }() + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -987,10 +998,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - RegisterOAuthSession(state, "gemini") - isWebUI := isWebUIRequest(c) - var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") if errTarget != nil { @@ -998,8 +1006,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - var errStart error - if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { + if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start gemini callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1008,48 +1015,54 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder) + defer stopCallbackForwarder(geminiCallbackPort) } // Wait for callback file written by server route waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state)) fmt.Println("Waiting for authentication callback...") - deadline := time.Now().Add(5 * time.Minute) + timer := time.NewTimer(5 * time.Minute) + ticker := time.NewTicker(500 * time.Millisecond) + defer timer.Stop() + defer ticker.Stop() var authCode string + waitForCallback: for { - if !IsOAuthSessionPending(state, "gemini") { + select { + case <-ctx.Done(): + log.Error("oauth flow canceled") + setOAuthStatus(state, "OAuth flow canceled") return - } - if time.Now().After(deadline) { + case <-timer.C: log.Error("oauth flow timed out") - SetOAuthSessionError(state, "OAuth flow timed out") + setOAuthStatus(state, "OAuth flow timed out") return - } - if data, errR := os.ReadFile(waitFile); errR == nil { - var m map[string]string - _ = json.Unmarshal(data, &m) - _ = os.Remove(waitFile) - if errStr := m["error"]; errStr != "" { - log.Errorf("Authentication failed: %s", errStr) - SetOAuthSessionError(state, "Authentication failed") - return - } - authCode = m["code"] - if authCode == "" { - log.Errorf("Authentication failed: code not found") - SetOAuthSessionError(state, "Authentication failed: code not found") - return + case <-ticker.C: + if data, errR := os.ReadFile(waitFile); errR == nil { + var m map[string]string + _ = json.Unmarshal(data, &m) + _ = os.Remove(waitFile) + if errStr := m["error"]; errStr != "" { + log.Errorf("Authentication failed: %s", errStr) + setOAuthStatus(state, "Authentication failed") + return + } + authCode = m["code"] + if authCode == "" { + log.Errorf("Authentication failed: code not found") + setOAuthStatus(state, "Authentication failed: code not found") + return + } + break waitForCallback } - break } - time.Sleep(500 * time.Millisecond) } // Exchange authorization code for token token, err := conf.Exchange(ctx, authCode) if err != nil { log.Errorf("Failed to exchange token: %v", err) - SetOAuthSessionError(state, "Failed to exchange token") + setOAuthStatus(state, "Failed to exchange token") return } @@ -1060,7 +1073,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errNewRequest != nil { log.Errorf("Could not get user info: %v", errNewRequest) - SetOAuthSessionError(state, "Could not get user info") + setOAuthStatus(state, "Could not get user info") return } req.Header.Set("Content-Type", "application/json") @@ -1069,7 +1082,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { resp, errDo := authHTTPClient.Do(req) if errDo != nil { log.Errorf("Failed to execute request: %v", errDo) - SetOAuthSessionError(state, "Failed to execute request") + setOAuthStatus(state, "Failed to execute request") return } defer func() { @@ -1081,7 +1094,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) + setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) return } @@ -1090,6 +1103,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Printf("Authenticated user email: %s\n", email) } else { fmt.Println("Failed to get user email from token") + setOAuthStatus(state, "Failed to get user email from token") } // Marshal/unmarshal oauth2.Token to generic map and enrich fields @@ -1097,7 +1111,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { jsonData, _ := json.Marshal(token) if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - SetOAuthSessionError(state, "Failed to unmarshal token") + setOAuthStatus(state, "Failed to unmarshal token") return } @@ -1125,7 +1139,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { }) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) - SetOAuthSessionError(state, "Failed to get authenticated client") + setOAuthStatus(state, "Failed to get authenticated client") return } fmt.Println("Authentication successful.") @@ -1135,12 +1149,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - SetOAuthSessionError(state, "Failed to verify Cloud AI API status") + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.ProjectID = strings.Join(projects, ",") @@ -1148,26 +1162,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if strings.TrimSpace(ts.ProjectID) == "" { log.Error("Onboarding did not return a project ID") - SetOAuthSessionError(state, "Failed to resolve project ID") + setOAuthStatus(state, "Failed to resolve project ID") return } isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - SetOAuthSessionError(state, "Failed to verify Cloud AI API status") + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - SetOAuthSessionError(state, "Cloud AI API not enabled") + setOAuthStatus(state, "Cloud AI API not enabled") return } } @@ -1190,15 +1204,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - SetOAuthSessionError(state, "Failed to save token to file") + setOAuthStatus(state, "Failed to save token to file") return } - CompleteOAuthSession(state) - CompleteOAuthSessionsByProvider("gemini") + deleteOAuthStatus(state) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1234,10 +1248,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { return } - RegisterOAuthSession(state, "codex") - isWebUI := isWebUIRequest(c) - var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") if errTarget != nil { @@ -1245,8 +1256,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - var errStart error - if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { + if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start codex callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1255,43 +1265,49 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarderInstance(codexCallbackPort, forwarder) + defer stopCallbackForwarder(codexCallbackPort) } // Wait for callback file waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state)) - deadline := time.Now().Add(5 * time.Minute) + timer := time.NewTimer(5 * time.Minute) + ticker := time.NewTicker(500 * time.Millisecond) + defer timer.Stop() + defer ticker.Stop() var code string + waitForCallback: for { - if !IsOAuthSessionPending(state, "codex") { + select { + case <-ctx.Done(): + log.WithError(ctx.Err()).Error("oauth flow canceled") + setOAuthStatus(state, "OAuth flow canceled") return - } - if time.Now().After(deadline) { + case <-timer.C: authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) - SetOAuthSessionError(state, "Timeout waiting for OAuth callback") + setOAuthStatus(state, "Timeout waiting for OAuth callback") return - } - if data, errR := os.ReadFile(waitFile); errR == nil { - var m map[string]string - _ = json.Unmarshal(data, &m) - _ = os.Remove(waitFile) - if errStr := m["error"]; errStr != "" { - oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) - log.Error(codex.GetUserFriendlyMessage(oauthErr)) - SetOAuthSessionError(state, "Bad Request") - return - } - if m["state"] != state { - authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) - SetOAuthSessionError(state, "State code error") - log.Error(codex.GetUserFriendlyMessage(authErr)) - return + case <-ticker.C: + if data, errR := os.ReadFile(waitFile); errR == nil { + var m map[string]string + _ = json.Unmarshal(data, &m) + _ = os.Remove(waitFile) + if errStr := m["error"]; errStr != "" { + oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) + log.Error(codex.GetUserFriendlyMessage(oauthErr)) + setOAuthStatus(state, "Bad Request") + return + } + if m["state"] != state { + authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) + setOAuthStatus(state, "State code error") + log.Error(codex.GetUserFriendlyMessage(authErr)) + return + } + code = m["code"] + break waitForCallback } - code = m["code"] - break } - time.Sleep(500 * time.Millisecond) } log.Debug("Authorization code received, exchanging for tokens...") @@ -1315,14 +1331,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) - SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") + setOAuthStatus(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) return } @@ -1333,7 +1349,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { ExpiresIn int `json:"expires_in"` } if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { - SetOAuthSessionError(state, "Failed to parse token response") + setOAuthStatus(state, "Failed to parse token response") log.Errorf("failed to parse token response: %v", errU) return } @@ -1371,8 +1387,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) @@ -1380,10 +1396,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Codex services through this CLI") - CompleteOAuthSession(state) - CompleteOAuthSessionsByProvider("codex") + deleteOAuthStatus(state) }() + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1424,10 +1440,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { params.Set("state", state) authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() - RegisterOAuthSession(state, "antigravity") - isWebUI := isWebUIRequest(c) - var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") if errTarget != nil { @@ -1435,8 +1448,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - var errStart error - if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { + if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start antigravity callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1445,44 +1457,50 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder) + defer stopCallbackForwarder(antigravityCallbackPort) } waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) - deadline := time.Now().Add(5 * time.Minute) + timer := time.NewTimer(5 * time.Minute) + ticker := time.NewTicker(500 * time.Millisecond) + defer timer.Stop() + defer ticker.Stop() var authCode string + waitForCallback: for { - if !IsOAuthSessionPending(state, "antigravity") { + select { + case <-ctx.Done(): + log.WithError(ctx.Err()).Error("oauth flow canceled") + setOAuthStatus(state, "OAuth flow canceled") return - } - if time.Now().After(deadline) { + case <-timer.C: log.Error("oauth flow timed out") - SetOAuthSessionError(state, "OAuth flow timed out") + setOAuthStatus(state, "OAuth flow timed out") return - } - if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { - var payload map[string]string - _ = json.Unmarshal(data, &payload) - _ = os.Remove(waitFile) - if errStr := strings.TrimSpace(payload["error"]); errStr != "" { - log.Errorf("Authentication failed: %s", errStr) - SetOAuthSessionError(state, "Authentication failed") - return - } - if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { - log.Errorf("Authentication failed: state mismatch") - SetOAuthSessionError(state, "Authentication failed: state mismatch") - return - } - authCode = strings.TrimSpace(payload["code"]) - if authCode == "" { - log.Error("Authentication failed: code not found") - SetOAuthSessionError(state, "Authentication failed: code not found") - return + case <-ticker.C: + if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { + var payload map[string]string + _ = json.Unmarshal(data, &payload) + _ = os.Remove(waitFile) + if errStr := strings.TrimSpace(payload["error"]); errStr != "" { + log.Errorf("Authentication failed: %s", errStr) + setOAuthStatus(state, "Authentication failed") + return + } + if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { + log.Errorf("Authentication failed: state mismatch") + setOAuthStatus(state, "Authentication failed: state mismatch") + return + } + authCode = strings.TrimSpace(payload["code"]) + if authCode == "" { + log.Error("Authentication failed: code not found") + setOAuthStatus(state, "Authentication failed: code not found") + return + } + break waitForCallback } - break } - time.Sleep(500 * time.Millisecond) } httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) @@ -1496,7 +1514,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) if errNewRequest != nil { log.Errorf("Failed to build token request: %v", errNewRequest) - SetOAuthSessionError(state, "Failed to build token request") + setOAuthStatus(state, "Failed to build token request") return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -1504,7 +1522,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { log.Errorf("Failed to execute token request: %v", errDo) - SetOAuthSessionError(state, "Failed to exchange token") + setOAuthStatus(state, "Failed to exchange token") return } defer func() { @@ -1516,7 +1534,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { bodyBytes, _ := io.ReadAll(resp.Body) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } @@ -1528,7 +1546,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { log.Errorf("Failed to parse token response: %v", errDecode) - SetOAuthSessionError(state, "Failed to parse token response") + setOAuthStatus(state, "Failed to parse token response") return } @@ -1537,7 +1555,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errInfoReq != nil { log.Errorf("Failed to build user info request: %v", errInfoReq) - SetOAuthSessionError(state, "Failed to build user info request") + setOAuthStatus(state, "Failed to build user info request") return } infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) @@ -1545,7 +1563,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoResp, errInfo := httpClient.Do(infoReq) if errInfo != nil { log.Errorf("Failed to execute user info request: %v", errInfo) - SetOAuthSessionError(state, "Failed to execute user info request") + setOAuthStatus(state, "Failed to execute user info request") return } defer func() { @@ -1564,7 +1582,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } else { bodyBytes, _ := io.ReadAll(infoResp.Body) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) + setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) return } } @@ -1612,12 +1630,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - SetOAuthSessionError(state, "Failed to save token to file") + setOAuthStatus(state, "Failed to save token to file") return } - CompleteOAuthSession(state) - CompleteOAuthSessionsByProvider("antigravity") + deleteOAuthStatus(state) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1625,6 +1642,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { fmt.Println("You can now use Antigravity services through this CLI") }() + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1646,13 +1664,11 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { } authURL := deviceFlow.VerificationURIComplete - RegisterOAuthSession(state, "qwen") - go func() { fmt.Println("Waiting for authentication...") - tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) + tokenData, errPollForToken := qwenAuth.PollForToken(ctx, deviceFlow) if errPollForToken != nil { - SetOAuthSessionError(state, "Authentication failed") + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errPollForToken) return } @@ -1671,15 +1687,16 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - SetOAuthSessionError(state, "Failed to save authentication tokens") + setOAuthStatus(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Println("You can now use Qwen services through this CLI") - CompleteOAuthSession(state) + deleteOAuthStatus(state) }() + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1692,10 +1709,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { authSvc := iflowauth.NewIFlowAuth(h.cfg) authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) - RegisterOAuthSession(state, "iflow") - isWebUI := isWebUIRequest(c) - var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") if errTarget != nil { @@ -1703,8 +1717,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) return } - var errStart error - if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { + if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start iflow callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) return @@ -1713,51 +1726,57 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder) + defer stopCallbackForwarder(iflowauth.CallbackPort) } fmt.Println("Waiting for authentication...") waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state)) - deadline := time.Now().Add(5 * time.Minute) + timer := time.NewTimer(5 * time.Minute) + ticker := time.NewTicker(500 * time.Millisecond) + defer timer.Stop() + defer ticker.Stop() var resultMap map[string]string + waitForCallback: for { - if !IsOAuthSessionPending(state, "iflow") { + select { + case <-ctx.Done(): + setOAuthStatus(state, "Authentication canceled") + fmt.Println("Authentication canceled") return - } - if time.Now().After(deadline) { - SetOAuthSessionError(state, "Authentication failed") + case <-timer.C: + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") return + case <-ticker.C: + if data, errR := os.ReadFile(waitFile); errR == nil { + _ = os.Remove(waitFile) + _ = json.Unmarshal(data, &resultMap) + break waitForCallback + } } - if data, errR := os.ReadFile(waitFile); errR == nil { - _ = os.Remove(waitFile) - _ = json.Unmarshal(data, &resultMap) - break - } - time.Sleep(500 * time.Millisecond) } if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { - SetOAuthSessionError(state, "Authentication failed") + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %s\n", errStr) return } if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { - SetOAuthSessionError(state, "Authentication failed") + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: state mismatch") return } code := strings.TrimSpace(resultMap["code"]) if code == "" { - SetOAuthSessionError(state, "Authentication failed") + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: code missing") return } tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) if errExchange != nil { - SetOAuthSessionError(state, "Authentication failed") + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errExchange) return } @@ -1779,8 +1798,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -1789,10 +1808,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use iFlow services through this CLI") - CompleteOAuthSession(state) - CompleteOAuthSessionsByProvider("iflow") + deleteOAuthStatus(state) }() + setOAuthStatus(state, "") c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -2083,7 +2102,11 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage } log.Println("Onboarding in progress, waiting 5 seconds...") - time.Sleep(5 * time.Second) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(5 * time.Second): + } } } @@ -2228,24 +2251,16 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec } func (h *Handler) GetAuthStatus(c *gin.Context) { - state := strings.TrimSpace(c.Query("state")) - if state == "" { - c.JSON(http.StatusOK, gin.H{"status": "ok"}) - return - } - if err := ValidateOAuthState(state); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"}) - return - } - - _, status, ok := GetOAuthSession(state) - if !ok { - c.JSON(http.StatusOK, gin.H{"status": "ok"}) - return - } - if status != "" { - c.JSON(http.StatusOK, gin.H{"status": "error", "error": status}) - return + state := c.Query("state") + if statusValue, ok := getOAuthStatus(state); ok { + if statusValue != "" { + c.JSON(200, gin.H{"status": "error", "error": statusValue}) + } else { + c.JSON(200, gin.H{"status": "wait"}) + return + } + } else { + c.JSON(200, gin.H{"status": "ok"}) } - c.JSON(http.StatusOK, gin.H{"status": "wait"}) + deleteOAuthStatus(state) } diff --git a/internal/api/server.go b/internal/api/server.go index 239dc641c..072ee3dcd 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -327,6 +327,16 @@ func (s *Server) setupRoutes() { v1.POST("/responses", openaiResponsesHandlers.Responses) } + // Anthropic-compatible API routes + anthropic := s.engine.Group("/anthropic") + anthropic.Use(AuthMiddleware(s.accessManager)) + { + v1Anthropic := anthropic.Group("/v1") + v1Anthropic.GET("/models", claudeCodeHandlers.ClaudeModels) + v1Anthropic.POST("/messages", claudeCodeHandlers.ClaudeMessages) + v1Anthropic.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) + } + // Gemini compatible API routes v1beta := s.engine.Group("/v1beta") v1beta.Use(AuthMiddleware(s.accessManager)) diff --git a/internal/auth/claude/anthropic_auth.go b/internal/auth/claude/anthropic_auth.go index 07bd5b429..160009226 100644 --- a/internal/auth/claude/anthropic_auth.go +++ b/internal/auth/claude/anthropic_auth.go @@ -4,25 +4,26 @@ package claude import ( + "bytes" "context" "encoding/json" "fmt" - "io" "net/http" "net/url" "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) const ( - anthropicAuthURL = "https://claude.ai/oauth/authorize" - anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token" - anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" - redirectURI = "http://localhost:54545/callback" + anthropicAuthURL = "https://claude.ai/oauth/authorize" + anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token" + anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + defaultRedirectURI = "http://localhost:54545/callback" ) // tokenResponse represents the response structure from Anthropic's OAuth token endpoint. @@ -58,8 +59,9 @@ type ClaudeAuth struct { // Returns: // - *ClaudeAuth: A new Claude authentication service instance func NewClaudeAuth(cfg *config.Config) *ClaudeAuth { + client := &http.Client{Timeout: 30 * time.Second} return &ClaudeAuth{ - httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}), + httpClient: util.SetOAuthProxy(&cfg.SDKConfig, client), } } @@ -76,9 +78,17 @@ func NewClaudeAuth(cfg *config.Config) *ClaudeAuth { // - string: The state parameter for verification // - error: An error if PKCE codes are missing or URL generation fails func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, string, error) { + return o.GenerateAuthURLWithRedirectURI(state, pkceCodes, defaultRedirectURI) +} + +func (o *ClaudeAuth) GenerateAuthURLWithRedirectURI(state string, pkceCodes *PKCECodes, redirectURI string) (string, string, error) { if pkceCodes == nil { return "", "", fmt.Errorf("PKCE codes are required") } + redirectURI = strings.TrimSpace(redirectURI) + if redirectURI == "" { + redirectURI = defaultRedirectURI + } params := url.Values{ "code": {"true"}, @@ -127,10 +137,18 @@ func (c *ClaudeAuth) parseCodeAndState(code string) (parsedCode, parsedState str // - *ClaudeAuthBundle: The complete authentication bundle with tokens // - error: An error if token exchange fails func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state string, pkceCodes *PKCECodes) (*ClaudeAuthBundle, error) { + return o.ExchangeCodeForTokensWithRedirectURI(ctx, code, state, pkceCodes, defaultRedirectURI) +} + +func (o *ClaudeAuth) ExchangeCodeForTokensWithRedirectURI(ctx context.Context, code, state string, pkceCodes *PKCECodes, redirectURI string) (*ClaudeAuthBundle, error) { if pkceCodes == nil { return nil, fmt.Errorf("PKCE codes are required for token exchange") } newCode, newState := o.parseCodeAndState(code) + redirectURI = strings.TrimSpace(redirectURI) + if redirectURI == "" { + redirectURI = defaultRedirectURI + } // Prepare token exchange request reqBody := map[string]interface{}{ @@ -152,35 +170,33 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri return nil, fmt.Errorf("failed to marshal request body: %w", err) } - // log.Debugf("Token exchange request: %s", string(jsonBody)) - - req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody))) - if err != nil { - return nil, fmt.Errorf("failed to create token request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - - resp, err := o.httpClient.Do(req) - if err != nil { + status, _, body, err := oauthhttp.Do( + ctx, + o.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, anthropicTokenURL, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("token exchange request failed: %w", err) } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("failed to close response body: %v", errClose) + if status != http.StatusOK { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("token exchange failed with status %d: %s: %w", status, msg, err) } - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read token response: %w", err) + return nil, fmt.Errorf("token exchange failed with status %d: %s", status, msg) } - // log.Debugf("Token response: %s", string(body)) - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) } - // log.Debugf("Token response: %s", string(body)) var tokenResp tokenResponse if err = json.Unmarshal(body, &tokenResp); err != nil { @@ -231,33 +247,34 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C return nil, fmt.Errorf("failed to marshal request body: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody))) - if err != nil { - return nil, fmt.Errorf("failed to create refresh request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - - resp, err := o.httpClient.Do(req) - if err != nil { + status, _, body, err := oauthhttp.Do( + ctx, + o.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, anthropicTokenURL, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create refresh request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("token refresh request failed: %w", err) } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read refresh response: %w", err) + if status != http.StatusOK { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("token refresh failed with status %d: %s: %w", status, msg, err) + } + return nil, fmt.Errorf("token refresh failed with status %d: %s", status, msg) } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body)) + if err != nil { + return nil, fmt.Errorf("token refresh request failed: %w", err) } - // log.Debugf("Token response: %s", string(body)) - var tokenResp tokenResponse if err = json.Unmarshal(body, &tokenResp); err != nil { return nil, fmt.Errorf("failed to parse token response: %w", err) diff --git a/internal/auth/claude/oauth_provider.go b/internal/auth/claude/oauth_provider.go new file mode 100644 index 000000000..0cbb3e21e --- /dev/null +++ b/internal/auth/claude/oauth_provider.go @@ -0,0 +1,103 @@ +package claude + +import ( + "context" + "fmt" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" +) + +// OAuthProvider adapts ClaudeAuth to the shared oauthflow.ProviderOAuth interface. +type OAuthProvider struct { + auth *ClaudeAuth +} + +func NewOAuthProvider(auth *ClaudeAuth) *OAuthProvider { + return &OAuthProvider{auth: auth} +} + +func (p *OAuthProvider) Provider() string { + return "claude" +} + +func (p *OAuthProvider) AuthorizeURL(session oauthflow.OAuthSession) (string, oauthflow.OAuthSession, error) { + if p == nil || p.auth == nil { + return "", session, fmt.Errorf("claude oauth provider: auth is nil") + } + pkce := &PKCECodes{ + CodeVerifier: session.CodeVerifier, + CodeChallenge: session.CodeChallenge, + } + authURL, returnedState, err := p.auth.GenerateAuthURLWithRedirectURI(session.State, pkce, session.RedirectURI) + if err != nil { + return "", session, err + } + session.State = returnedState + return authURL, session, nil +} + +func (p *OAuthProvider) ExchangeCode(ctx context.Context, session oauthflow.OAuthSession, code string) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("claude oauth provider: auth is nil") + } + pkce := &PKCECodes{ + CodeVerifier: session.CodeVerifier, + CodeChallenge: session.CodeChallenge, + } + bundle, err := p.auth.ExchangeCodeForTokensWithRedirectURI(ctx, code, session.State, pkce, session.RedirectURI) + if err != nil { + return nil, err + } + if bundle == nil { + return nil, fmt.Errorf("claude oauth provider: token bundle is nil") + } + + meta := map[string]any{} + if email := strings.TrimSpace(bundle.TokenData.Email); email != "" { + meta["email"] = email + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(bundle.TokenData.AccessToken), + RefreshToken: strings.TrimSpace(bundle.TokenData.RefreshToken), + ExpiresAt: strings.TrimSpace(bundle.TokenData.Expire), + TokenType: "Bearer", + Metadata: meta, + }, nil +} + +func (p *OAuthProvider) Refresh(ctx context.Context, refreshToken string) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("claude oauth provider: auth is nil") + } + data, err := p.auth.RefreshTokens(ctx, refreshToken) + if err != nil { + return nil, err + } + if data == nil { + return nil, fmt.Errorf("claude oauth provider: refresh result is nil") + } + + meta := map[string]any{} + if email := strings.TrimSpace(data.Email); email != "" { + meta["email"] = email + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(data.AccessToken), + RefreshToken: strings.TrimSpace(data.RefreshToken), + ExpiresAt: strings.TrimSpace(data.Expire), + TokenType: "Bearer", + Metadata: meta, + }, nil +} + +// Revoke invalidates the given token at Anthropic. +// Note: Anthropic does not currently provide a public token revocation endpoint, +// so this method returns ErrRevokeNotSupported. +func (p *OAuthProvider) Revoke(ctx context.Context, token string) error { + // Anthropic does not currently support OAuth token revocation via public API. + // Users should revoke tokens through the Anthropic console. + return oauthflow.ErrRevokeNotSupported +} diff --git a/internal/auth/claude/oauth_server.go b/internal/auth/claude/oauth_server.go index a6ebe2f7b..500711b72 100644 --- a/internal/auth/claude/oauth_server.go +++ b/internal/auth/claude/oauth_server.go @@ -11,6 +11,7 @@ import ( "net/http" "strings" "sync" + "syscall" "time" log "github.com/sirupsen/logrus" @@ -21,7 +22,8 @@ import ( // and captures the necessary parameters to complete the authentication flow. type OAuthServer struct { // server is the underlying HTTP server instance - server *http.Server + server *http.Server + listener net.Listener // port is the port number on which the server listens port int // resultChan is a channel for sending OAuth results @@ -63,6 +65,16 @@ func NewOAuthServer(port int) *OAuthServer { } } +// Port returns the actual bound port once the server has started. +func (s *OAuthServer) Port() int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.port +} + // Start starts the OAuth callback server. // It sets up the HTTP handlers for the callback and success endpoints, // and begins listening on the specified port. @@ -77,34 +89,39 @@ func (s *OAuthServer) Start() error { return fmt.Errorf("server is already running") } - // Check if port is available - if !s.isPortAvailable() { - return fmt.Errorf("port %d is already in use", s.port) - } - mux := http.NewServeMux() mux.HandleFunc("/callback", s.handleCallback) mux.HandleFunc("/success", s.handleSuccess) + addr := fmt.Sprintf("127.0.0.1:%d", s.port) + ln, err := net.Listen("tcp", addr) + if err != nil { + if errors.Is(err, syscall.EADDRINUSE) { + return fmt.Errorf("port %d is already in use", s.port) + } + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + s.server = &http.Server{ - Addr: fmt.Sprintf(":%d", s.port), - Handler: mux, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + s.listener = ln + if tcp, ok := ln.Addr().(*net.TCPAddr); ok { + s.port = tcp.Port } s.running = true // Start server in goroutine go func() { - if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := s.server.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { s.errorChan <- fmt.Errorf("server failed to start: %w", err) } }() - // Give server a moment to start - time.Sleep(100 * time.Millisecond) - return nil } @@ -131,6 +148,10 @@ func (s *OAuthServer) Stop(ctx context.Context) error { defer cancel() err := s.server.Shutdown(shutdownCtx) + if s.listener != nil { + _ = s.listener.Close() + s.listener = nil + } s.running = false s.server = nil @@ -242,6 +263,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://console.anthropic.com/" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://console.anthropic.com/" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -251,6 +277,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. @@ -292,23 +324,6 @@ func (s *OAuthServer) sendResult(result *OAuthResult) { } } -// isPortAvailable checks if the specified port is available. -// It attempts to listen on the port to determine availability. -// -// Returns: -// - bool: True if the port is available, false otherwise -func (s *OAuthServer) isPortAvailable() bool { - addr := fmt.Sprintf(":%d", s.port) - listener, err := net.Listen("tcp", addr) - if err != nil { - return false - } - defer func() { - _ = listener.Close() - }() - return true -} - // IsRunning returns whether the server is currently running. // // Returns: diff --git a/internal/auth/codex/oauth_provider.go b/internal/auth/codex/oauth_provider.go new file mode 100644 index 000000000..d924dc8c8 --- /dev/null +++ b/internal/auth/codex/oauth_provider.go @@ -0,0 +1,110 @@ +package codex + +import ( + "context" + "fmt" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" +) + +// OAuthProvider adapts CodexAuth to the shared oauthflow.ProviderOAuth interface. +type OAuthProvider struct { + auth *CodexAuth +} + +func NewOAuthProvider(auth *CodexAuth) *OAuthProvider { + return &OAuthProvider{auth: auth} +} + +func (p *OAuthProvider) Provider() string { + return "codex" +} + +func (p *OAuthProvider) AuthorizeURL(session oauthflow.OAuthSession) (string, oauthflow.OAuthSession, error) { + if p == nil || p.auth == nil { + return "", session, fmt.Errorf("codex oauth provider: auth is nil") + } + pkce := &PKCECodes{ + CodeVerifier: session.CodeVerifier, + CodeChallenge: session.CodeChallenge, + } + authURL, err := p.auth.GenerateAuthURLWithRedirectURI(session.State, pkce, session.RedirectURI) + if err != nil { + return "", session, err + } + return authURL, session, nil +} + +func (p *OAuthProvider) ExchangeCode(ctx context.Context, session oauthflow.OAuthSession, code string) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("codex oauth provider: auth is nil") + } + pkce := &PKCECodes{ + CodeVerifier: session.CodeVerifier, + CodeChallenge: session.CodeChallenge, + } + bundle, err := p.auth.ExchangeCodeForTokensWithRedirectURI(ctx, code, pkce, session.RedirectURI) + if err != nil { + return nil, err + } + if bundle == nil { + return nil, fmt.Errorf("codex oauth provider: token bundle is nil") + } + + meta := map[string]any{} + if email := strings.TrimSpace(bundle.TokenData.Email); email != "" { + meta["email"] = email + } + if accountID := strings.TrimSpace(bundle.TokenData.AccountID); accountID != "" { + meta["account_id"] = accountID + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(bundle.TokenData.AccessToken), + RefreshToken: strings.TrimSpace(bundle.TokenData.RefreshToken), + ExpiresAt: strings.TrimSpace(bundle.TokenData.Expire), + TokenType: "Bearer", + IDToken: strings.TrimSpace(bundle.TokenData.IDToken), + Metadata: meta, + }, nil +} + +func (p *OAuthProvider) Refresh(ctx context.Context, refreshToken string) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("codex oauth provider: auth is nil") + } + data, err := p.auth.RefreshTokens(ctx, refreshToken) + if err != nil { + return nil, err + } + if data == nil { + return nil, fmt.Errorf("codex oauth provider: refresh result is nil") + } + + meta := map[string]any{} + if email := strings.TrimSpace(data.Email); email != "" { + meta["email"] = email + } + if accountID := strings.TrimSpace(data.AccountID); accountID != "" { + meta["account_id"] = accountID + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(data.AccessToken), + RefreshToken: strings.TrimSpace(data.RefreshToken), + ExpiresAt: strings.TrimSpace(data.Expire), + TokenType: "Bearer", + IDToken: strings.TrimSpace(data.IDToken), + Metadata: meta, + }, nil +} + +// Revoke invalidates the given token at OpenAI. +// Note: OpenAI does not currently provide a public token revocation endpoint, +// so this method returns ErrRevokeNotSupported. +func (p *OAuthProvider) Revoke(ctx context.Context, token string) error { + // OpenAI does not currently support OAuth token revocation via public API. + // Users should revoke tokens through the OpenAI platform dashboard. + return oauthflow.ErrRevokeNotSupported +} diff --git a/internal/auth/codex/oauth_server.go b/internal/auth/codex/oauth_server.go index 9c6a6c5b7..71a717330 100644 --- a/internal/auth/codex/oauth_server.go +++ b/internal/auth/codex/oauth_server.go @@ -8,6 +8,7 @@ import ( "net/http" "strings" "sync" + "syscall" "time" log "github.com/sirupsen/logrus" @@ -18,7 +19,8 @@ import ( // and captures the necessary parameters to complete the authentication flow. type OAuthServer struct { // server is the underlying HTTP server instance - server *http.Server + server *http.Server + listener net.Listener // port is the port number on which the server listens port int // resultChan is a channel for sending OAuth results @@ -60,6 +62,16 @@ func NewOAuthServer(port int) *OAuthServer { } } +// Port returns the actual bound port once the server has started. +func (s *OAuthServer) Port() int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.port +} + // Start starts the OAuth callback server. // It sets up the HTTP handlers for the callback and success endpoints, // and begins listening on the specified port. @@ -74,34 +86,39 @@ func (s *OAuthServer) Start() error { return fmt.Errorf("server is already running") } - // Check if port is available - if !s.isPortAvailable() { - return fmt.Errorf("port %d is already in use", s.port) - } - mux := http.NewServeMux() mux.HandleFunc("/auth/callback", s.handleCallback) mux.HandleFunc("/success", s.handleSuccess) + addr := fmt.Sprintf("127.0.0.1:%d", s.port) + ln, err := net.Listen("tcp", addr) + if err != nil { + if errors.Is(err, syscall.EADDRINUSE) { + return fmt.Errorf("port %d is already in use", s.port) + } + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + s.server = &http.Server{ - Addr: fmt.Sprintf(":%d", s.port), - Handler: mux, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + s.listener = ln + if tcp, ok := ln.Addr().(*net.TCPAddr); ok { + s.port = tcp.Port } s.running = true // Start server in goroutine go func() { - if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := s.server.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { s.errorChan <- fmt.Errorf("server failed to start: %w", err) } }() - // Give server a moment to start - time.Sleep(100 * time.Millisecond) - return nil } @@ -128,6 +145,10 @@ func (s *OAuthServer) Stop(ctx context.Context) error { defer cancel() err := s.server.Shutdown(shutdownCtx) + if s.listener != nil { + _ = s.listener.Close() + s.listener = nil + } s.running = false s.server = nil @@ -239,6 +260,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://platform.openai.com" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://platform.openai.com" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -248,6 +274,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. @@ -289,23 +321,6 @@ func (s *OAuthServer) sendResult(result *OAuthResult) { } } -// isPortAvailable checks if the specified port is available. -// It attempts to listen on the port to determine availability. -// -// Returns: -// - bool: True if the port is available, false otherwise -func (s *OAuthServer) isPortAvailable() bool { - addr := fmt.Sprintf(":%d", s.port) - listener, err := net.Listen("tcp", addr) - if err != nil { - return false - } - defer func() { - _ = listener.Close() - }() - return true -} - // IsRunning returns whether the server is currently running. // // Returns: diff --git a/internal/auth/codex/openai_auth.go b/internal/auth/codex/openai_auth.go index c0299c3d9..d80c260cd 100644 --- a/internal/auth/codex/openai_auth.go +++ b/internal/auth/codex/openai_auth.go @@ -8,22 +8,22 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "net/url" "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) const ( - openaiAuthURL = "https://auth.openai.com/oauth/authorize" - openaiTokenURL = "https://auth.openai.com/oauth/token" - openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - redirectURI = "http://localhost:1455/auth/callback" + openaiAuthURL = "https://auth.openai.com/oauth/authorize" + openaiTokenURL = "https://auth.openai.com/oauth/token" + openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + defaultRedirectURI = "http://localhost:1455/auth/callback" ) // CodexAuth handles the OpenAI OAuth2 authentication flow. @@ -36,8 +36,9 @@ type CodexAuth struct { // NewCodexAuth creates a new CodexAuth service instance. // It initializes an HTTP client with proxy settings from the provided configuration. func NewCodexAuth(cfg *config.Config) *CodexAuth { + client := &http.Client{Timeout: 30 * time.Second} return &CodexAuth{ - httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}), + httpClient: util.SetOAuthProxy(&cfg.SDKConfig, client), } } @@ -45,9 +46,17 @@ func NewCodexAuth(cfg *config.Config) *CodexAuth { // It constructs the URL with the necessary parameters, including the client ID, // response type, redirect URI, scopes, and PKCE challenge. func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, error) { + return o.GenerateAuthURLWithRedirectURI(state, pkceCodes, defaultRedirectURI) +} + +func (o *CodexAuth) GenerateAuthURLWithRedirectURI(state string, pkceCodes *PKCECodes, redirectURI string) (string, error) { if pkceCodes == nil { return "", fmt.Errorf("PKCE codes are required") } + redirectURI = strings.TrimSpace(redirectURI) + if redirectURI == "" { + redirectURI = defaultRedirectURI + } params := url.Values{ "client_id": {openaiClientID}, @@ -70,9 +79,17 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, // It performs an HTTP POST request to the OpenAI token endpoint with the provided // authorization code and PKCE verifier. func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) { + return o.ExchangeCodeForTokensWithRedirectURI(ctx, code, pkceCodes, defaultRedirectURI) +} + +func (o *CodexAuth) ExchangeCodeForTokensWithRedirectURI(ctx context.Context, code string, pkceCodes *PKCECodes, redirectURI string) (*CodexAuthBundle, error) { if pkceCodes == nil { return nil, fmt.Errorf("PKCE codes are required for token exchange") } + redirectURI = strings.TrimSpace(redirectURI) + if redirectURI == "" { + redirectURI = defaultRedirectURI + } // Prepare token exchange request data := url.Values{ @@ -82,31 +99,34 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce "redirect_uri": {redirectURI}, "code_verifier": {pkceCodes.CodeVerifier}, } - - req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create token request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - resp, err := o.httpClient.Do(req) - if err != nil { + encoded := data.Encode() + + status, _, body, err := oauthhttp.Do( + ctx, + o.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, openaiTokenURL, strings.NewReader(encoded)) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("token exchange request failed: %w", err) } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read token response: %w", err) + if status != http.StatusOK { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("token exchange failed with status %d: %s: %w", status, msg, err) + } + return nil, fmt.Errorf("token exchange failed with status %d: %s", status, msg) } - // log.Debugf("Token response: %s", string(body)) - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) } // Parse token response @@ -168,30 +188,34 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co "refresh_token": {refreshToken}, "scope": {"openid profile email"}, } - - req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create refresh request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - resp, err := o.httpClient.Do(req) - if err != nil { + encoded := data.Encode() + + status, _, body, err := oauthhttp.Do( + ctx, + o.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, openaiTokenURL, strings.NewReader(encoded)) + if err != nil { + return nil, fmt.Errorf("failed to create refresh request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("token refresh request failed: %w", err) } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read refresh response: %w", err) + if status != http.StatusOK { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("token refresh failed with status %d: %s: %w", status, msg, err) + } + return nil, fmt.Errorf("token refresh failed with status %d: %s", status, msg) } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body)) + if err != nil { + return nil, fmt.Errorf("token refresh request failed: %w", err) } var tokenResp struct { diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index 7b18e7384..f60365c52 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -9,20 +9,18 @@ import ( "encoding/json" "errors" "fmt" - "io" - "net" "net/http" - "net/url" + "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "golang.org/x/net/proxy" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -47,12 +45,6 @@ var ( type GeminiAuth struct { } -// WebLoginOptions customizes the interactive OAuth flow. -type WebLoginOptions struct { - NoBrowser bool - Prompt func(string) (string, error) -} - // NewGeminiAuth creates a new instance of GeminiAuth. func NewGeminiAuth() *GeminiAuth { return &GeminiAuth{} @@ -66,41 +58,14 @@ func NewGeminiAuth() *GeminiAuth { // - ctx: The context for the HTTP client // - ts: The Gemini token storage containing authentication tokens // - cfg: The configuration containing proxy settings -// - opts: Optional parameters to customize browser and prompt behavior +// - opts: Optional web login configuration (e.g., disable browser opening) // // Returns: // - *http.Client: An HTTP client configured with authentication // - error: An error if the client configuration fails, nil otherwise func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { - // Configure proxy settings for the HTTP client if a proxy URL is provided. - proxyURL, err := url.Parse(cfg.ProxyURL) - if err == nil { - var transport *http.Transport - if proxyURL.Scheme == "socks5" { - // Handle SOCKS5 proxy. - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - auth := &proxy.Auth{User: username, Password: password} - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) - } - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Handle HTTP/HTTPS proxy. - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } - - if transport != nil { - proxyClient := &http.Client{Transport: transport} - ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) - } - } + oauthHTTPClient := util.SetOAuthProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}) + ctx = context.WithValue(ctx, oauth2.HTTPClient, oauthHTTPClient) // Configure the OAuth2 client. conf := &oauth2.Config{ @@ -112,11 +77,12 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken } var token *oauth2.Token + var err error // If no token is found in storage, initiate the web-based OAuth flow. if ts.Token == nil { fmt.Printf("Could not load token from file, starting OAuth flow.\n") - token, err = g.getTokenFromWeb(ctx, conf, opts) + token, err = g.getTokenFromWeb(ctx, oauthHTTPClient, opts) if err != nil { return nil, fmt.Errorf("failed to get token from web: %w", err) } @@ -153,26 +119,32 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken // - error: An error if the token storage creation fails, nil otherwise func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*GeminiTokenStorage, error) { httpClient := config.Client(ctx, token) - req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) - if err != nil { - return nil, fmt.Errorf("could not get user info: %v", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) - - resp, err := httpClient.Do(req) - if err != nil { + status, _, bodyBytes, err := oauthhttp.Do( + ctx, + httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) + if err != nil { + return nil, fmt.Errorf("could not get user info: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("failed to execute request: %w", err) } - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", err) + if status < http.StatusOK || status >= http.StatusMultipleChoices { + msg := strings.TrimSpace(string(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("get user info request failed with status %d: %s: %w", status, msg, err) } - }() - - bodyBytes, _ := io.ReadAll(resp.Body) - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf("get user info request failed with status %d: %s", status, msg) + } + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) } emailResult := gjson.GetBytes(bodyBytes, "email") @@ -212,161 +184,91 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // Parameters: // - ctx: The context for the HTTP client // - config: The OAuth2 configuration -// - opts: Optional parameters to customize browser and prompt behavior +// - opts: Optional web login configuration (e.g., disable browser opening) // // Returns: // - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - error: An error if the token acquisition fails, nil otherwise -func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { - // Use a channel to pass the authorization code from the HTTP handler to the main function. - codeChan := make(chan string, 1) - errChan := make(chan error, 1) - - // Create a new HTTP server with its own multiplexer. - mux := http.NewServeMux() - server := &http.Server{Addr: ":8085", Handler: mux} - config.RedirectURL = "http://localhost:8085/oauth2callback" - - mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { - if err := r.URL.Query().Get("error"); err != "" { - _, _ = fmt.Fprintf(w, "Authentication failed: %s", err) - select { - case errChan <- fmt.Errorf("authentication failed via callback: %s", err): - default: - } - return - } - code := r.URL.Query().Get("code") - if code == "" { - _, _ = fmt.Fprint(w, "Authentication failed: code not found.") - select { - case errChan <- fmt.Errorf("code not found in callback"): - default: - } - return - } - _, _ = fmt.Fprint(w, "

Authentication successful!

You can close this window.

") - select { - case codeChan <- code: - default: - } - }) - - // Start the server in a goroutine. - go func() { - if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - log.Errorf("ListenAndServe(): %v", err) - select { - case errChan <- err: - default: - } - } - }() - - // Open the authorization URL in the user's browser. - authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - - noBrowser := false +func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, httpClient *http.Client, opts *WebLoginOptions) (*oauth2.Token, error) { + if ctx == nil { + ctx = context.Background() + } + provider := NewOAuthProvider(httpClient) + desiredPort := 8085 + noBrowser := true if opts != nil { noBrowser = opts.NoBrowser } - if !noBrowser { - fmt.Println("Opening browser for authentication...") - - // Check if browser is available - if !browser.IsAvailable() { - log.Warn("No browser available on this system") - util.PrintSSHTunnelInstructions(8085) - fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) - } else { - if err := browser.OpenURL(authURL); err != nil { - authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err) - log.Warn(codex.GetUserFriendlyMessage(authErr)) - util.PrintSSHTunnelInstructions(8085) - fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) - - // Log platform info for debugging - platformInfo := browser.GetPlatformInfo() - log.Debugf("Browser platform info: %+v", platformInfo) - } else { - log.Debug("Browser opened successfully") + flow, err := oauthflow.RunAuthCodeFlow(ctx, provider, oauthflow.AuthCodeFlowOptions{ + DesiredPort: desiredPort, + CallbackPath: "/oauth2callback", + Timeout: 5 * time.Minute, + OnAuthURL: func(authURL string, callbackPort int, redirectURI string) { + if desiredPort != 0 && callbackPort != desiredPort { + log.Warnf("gemini oauth: default port %d is busy, falling back to dynamic port", desiredPort) } - } - } else { - util.PrintSSHTunnelInstructions(8085) - fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL) - } - - fmt.Println("Waiting for authentication callback...") - - // Wait for the authorization code or an error. - var authCode string - timeoutTimer := time.NewTimer(5 * time.Minute) - defer timeoutTimer.Stop() - - var manualPromptTimer *time.Timer - var manualPromptC <-chan time.Time - if opts != nil && opts.Prompt != nil { - manualPromptTimer = time.NewTimer(15 * time.Second) - manualPromptC = manualPromptTimer.C - defer manualPromptTimer.Stop() - } -waitForCallback: - for { - select { - case code := <-codeChan: - authCode = code - break waitForCallback - case err := <-errChan: - return nil, err - case <-manualPromptC: - manualPromptC = nil - if manualPromptTimer != nil { - manualPromptTimer.Stop() + opened := false + if !noBrowser { + fmt.Println("Opening browser for authentication...") + if !browser.IsAvailable() { + log.Warn("No browser available on this system") + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) + } else if errOpen := browser.OpenURL(authURL); errOpen != nil { + authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, errOpen) + log.Warn(codex.GetUserFriendlyMessage(authErr)) + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) + platformInfo := browser.GetPlatformInfo() + log.Debugf("Browser platform info: %+v", platformInfo) + } else { + log.Debug("Browser opened successfully") + opened = true + } } - select { - case code := <-codeChan: - authCode = code - break waitForCallback - case err := <-errChan: - return nil, err - default: - } - input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") - if err != nil { - return nil, err - } - parsed, err := misc.ParseOAuthCallback(input) - if err != nil { - return nil, err - } - if parsed == nil { - continue - } - if parsed.Error != "" { - return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error) + + if !opened { + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL) } - if parsed.Code == "" { - return nil, fmt.Errorf("code not found in callback") + fmt.Println("Waiting for authentication callback...") + }, + }) + if err != nil { + var flowErr *oauthflow.FlowError + if errors.As(err, &flowErr) && flowErr != nil { + switch flowErr.Kind { + case oauthflow.FlowErrorKindPortInUse: + return nil, fmt.Errorf("gemini oauth callback port in use: %w", err) + case oauthflow.FlowErrorKindServerStartFailed: + return nil, fmt.Errorf("gemini oauth callback server failed: %w", err) + case oauthflow.FlowErrorKindCallbackTimeout: + return nil, fmt.Errorf("oauth flow timed out") + case oauthflow.FlowErrorKindProviderError: + return nil, fmt.Errorf("authentication failed via callback: %w", flowErr.Err) + case oauthflow.FlowErrorKindInvalidState: + return nil, fmt.Errorf("state mismatch in callback") + case oauthflow.FlowErrorKindCodeExchangeFailed: + return nil, fmt.Errorf("failed to exchange token: %w", flowErr.Err) } - authCode = parsed.Code - break waitForCallback - case <-timeoutTimer.C: - return nil, fmt.Errorf("oauth flow timed out") } + return nil, err } - - // Shutdown the server. - if err := server.Shutdown(ctx); err != nil { - log.Errorf("Failed to shut down server: %v", err) + if flow == nil || flow.Token == nil { + return nil, fmt.Errorf("oauth flow failed: missing token result") } - // Exchange the authorization code for a token. - token, err := config.Exchange(ctx, authCode) - if err != nil { - return nil, fmt.Errorf("failed to exchange token: %w", err) + token := &oauth2.Token{ + AccessToken: flow.Token.AccessToken, + RefreshToken: flow.Token.RefreshToken, + TokenType: flow.Token.TokenType, + } + if strings.TrimSpace(flow.Token.ExpiresAt) != "" { + if expiry, errParse := time.Parse(time.RFC3339, strings.TrimSpace(flow.Token.ExpiresAt)); errParse == nil { + token.Expiry = expiry + } } fmt.Println("Authentication successful.") diff --git a/internal/auth/gemini/oauth_provider.go b/internal/auth/gemini/oauth_provider.go new file mode 100644 index 000000000..08169e4e5 --- /dev/null +++ b/internal/auth/gemini/oauth_provider.go @@ -0,0 +1,229 @@ +package gemini + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" +) + +// OAuthProvider adapts Gemini OAuth to the shared oauthflow.ProviderOAuth interface. +type OAuthProvider struct { + httpClient *http.Client +} + +func NewOAuthProvider(httpClient *http.Client) *OAuthProvider { + return &OAuthProvider{httpClient: httpClient} +} + +func (p *OAuthProvider) Provider() string { + return "gemini" +} + +func (p *OAuthProvider) AuthorizeURL(session oauthflow.OAuthSession) (string, oauthflow.OAuthSession, error) { + if p == nil { + return "", session, fmt.Errorf("gemini oauth provider: provider is nil") + } + redirectURI := strings.TrimSpace(session.RedirectURI) + if redirectURI == "" { + return "", session, fmt.Errorf("gemini oauth provider: redirect URI is empty") + } + + params := url.Values{} + params.Set("access_type", "offline") + params.Set("client_id", geminiOauthClientID) + params.Set("prompt", "consent") + params.Set("redirect_uri", redirectURI) + params.Set("response_type", "code") + params.Set("scope", strings.Join(geminiOauthScopes, " ")) + params.Set("state", session.State) + if strings.TrimSpace(session.CodeChallenge) != "" { + params.Set("code_challenge", strings.TrimSpace(session.CodeChallenge)) + params.Set("code_challenge_method", "S256") + } + return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode(), session, nil +} + +type googleTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + IDToken string `json:"id_token"` +} + +func (p *OAuthProvider) ExchangeCode(ctx context.Context, session oauthflow.OAuthSession, code string) (*oauthflow.TokenResult, error) { + if p == nil || p.httpClient == nil { + return nil, fmt.Errorf("gemini oauth provider: http client is nil") + } + if ctx == nil { + ctx = context.Background() + } + code = strings.TrimSpace(code) + if code == "" { + return nil, fmt.Errorf("gemini oauth provider: authorization code is empty") + } + + data := url.Values{} + data.Set("code", code) + data.Set("client_id", geminiOauthClientID) + data.Set("client_secret", geminiOauthClientSecret) + data.Set("redirect_uri", strings.TrimSpace(session.RedirectURI)) + data.Set("grant_type", "authorization_code") + if strings.TrimSpace(session.CodeVerifier) != "" { + data.Set("code_verifier", strings.TrimSpace(session.CodeVerifier)) + } + + encoded := data.Encode() + status, _, body, err := oauthhttp.Do( + ctx, + p.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(encoded)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { + return nil, err + } + if status < http.StatusOK || status >= http.StatusMultipleChoices { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("gemini oauth token exchange failed: status %d: %s: %w", status, msg, err) + } + return nil, fmt.Errorf("gemini oauth token exchange failed: status %d: %s", status, msg) + } + if err != nil { + return nil, err + } + + var token googleTokenResponse + if errDecode := json.Unmarshal(body, &token); errDecode != nil { + return nil, errDecode + } + if strings.TrimSpace(token.AccessToken) == "" { + return nil, fmt.Errorf("gemini oauth token exchange failed: empty access token") + } + + tokenType := strings.TrimSpace(token.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + expiresAt := "" + if token.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Format(time.RFC3339) + } + + meta := map[string]any{ + "expires_in": token.ExpiresIn, + } + if strings.TrimSpace(token.Scope) != "" { + meta["scope"] = strings.TrimSpace(token.Scope) + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(token.AccessToken), + RefreshToken: strings.TrimSpace(token.RefreshToken), + ExpiresAt: expiresAt, + TokenType: tokenType, + IDToken: strings.TrimSpace(token.IDToken), + Metadata: meta, + }, nil +} + +func (p *OAuthProvider) Refresh(ctx context.Context, refreshToken string) (*oauthflow.TokenResult, error) { + if p == nil || p.httpClient == nil { + return nil, fmt.Errorf("gemini oauth provider: http client is nil") + } + if ctx == nil { + ctx = context.Background() + } + refreshToken = strings.TrimSpace(refreshToken) + if refreshToken == "" { + return nil, fmt.Errorf("gemini oauth provider: refresh token is empty") + } + + data := url.Values{} + data.Set("refresh_token", refreshToken) + data.Set("client_id", geminiOauthClientID) + data.Set("client_secret", geminiOauthClientSecret) + data.Set("grant_type", "refresh_token") + + encoded := data.Encode() + status, _, body, err := oauthhttp.Do( + ctx, + p.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(encoded)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { + return nil, err + } + if status < http.StatusOK || status >= http.StatusMultipleChoices { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("gemini oauth token refresh failed: status %d: %s: %w", status, msg, err) + } + return nil, fmt.Errorf("gemini oauth token refresh failed: status %d: %s", status, msg) + } + if err != nil { + return nil, err + } + + var token googleTokenResponse + if errDecode := json.Unmarshal(body, &token); errDecode != nil { + return nil, errDecode + } + if strings.TrimSpace(token.AccessToken) == "" { + return nil, fmt.Errorf("gemini oauth token refresh failed: empty access token") + } + + tokenType := strings.TrimSpace(token.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + expiresAt := "" + if token.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Format(time.RFC3339) + } + + meta := map[string]any{ + "expires_in": token.ExpiresIn, + } + if strings.TrimSpace(token.Scope) != "" { + meta["scope"] = strings.TrimSpace(token.Scope) + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(token.AccessToken), + RefreshToken: refreshToken, + ExpiresAt: expiresAt, + TokenType: tokenType, + Metadata: meta, + }, nil +} + +func (p *OAuthProvider) Revoke(ctx context.Context, token string) error { + return oauthflow.ErrRevokeNotSupported +} diff --git a/internal/auth/gemini/options.go b/internal/auth/gemini/options.go new file mode 100644 index 000000000..e485b96ab --- /dev/null +++ b/internal/auth/gemini/options.go @@ -0,0 +1,9 @@ +package gemini + +// WebLoginOptions provides optional behavior for Gemini OAuth login flows. +type WebLoginOptions struct { + // NoBrowser disables automatic browser opening when true. + NoBrowser bool + // Prompt can be used by callers to customize interactive prompts. + Prompt func(prompt string) (string, error) +} diff --git a/internal/auth/generic/generic.go b/internal/auth/generic/generic.go new file mode 100644 index 000000000..6af44b90a --- /dev/null +++ b/internal/auth/generic/generic.go @@ -0,0 +1,193 @@ +package generic + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +func tokenFingerprint(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:16]) +} + +// CheckOAuth2Token validates a token against a generic OAuth2 introspection endpoint. +func CheckOAuth2Token(ctx context.Context, token string, cfg config.GenericAuth) (*coreauth.Auth, error) { + if ctx == nil { + ctx = context.Background() + } + // 1. Resolve configuration (env vars override if config is empty) + introspectionURL := cfg.IntrospectionURL + if introspectionURL == "" { + introspectionURL = os.Getenv("OAUTH_TOKEN_INFO_ENDPOINT") + } + clientID := cfg.ClientID + if clientID == "" { + clientID = os.Getenv("OAUTH_CLIENT_ID") + } + clientSecret := cfg.ClientSecret + if clientSecret == "" { + clientSecret = os.Getenv("OAUTH_CLIENT_SECRET") + } + + if introspectionURL == "" { + return nil, fmt.Errorf("introspection URL not configured (set introspection-url or OAUTH_TOKEN_INFO_ENDPOINT)") + } + + userIDField := cfg.UserIDField + if userIDField == "" { + userIDField = os.Getenv("OAUTH_USER_ID_FIELD_NAME") + if userIDField == "" { + userIDField = "sub" + } + } + + emailField := cfg.EmailField + if emailField == "" { + emailField = "email" + } + + providerID := strings.TrimSpace(cfg.ProviderID) + if providerID == "" { + providerID = "generic" + } + + // 2. Determine if it's an introspection endpoint (RFC 7662) or UserInfo (GET) + // Heuristic: treat URLs containing "introspect" with client credentials as RFC 7662 introspection. + isIntrospection := strings.Contains(introspectionURL, "introspect") && clientID != "" && clientSecret != "" + + var ( + method string + contentType string + accept string + authorization string + encodedForm string + useRequestBody bool + ) + if isIntrospection { + log.Debug("Using OAuth2 introspection endpoint (POST)") + data := url.Values{} + data.Set("token", token) + + method = http.MethodPost + encodedForm = data.Encode() + useRequestBody = true + contentType = "application/x-www-form-urlencoded" + accept = "application/json" + + // Basic Auth for client credentials + auth := clientID + ":" + clientSecret + basic := base64.StdEncoding.EncodeToString([]byte(auth)) + authorization = "Basic " + basic + + } else { + log.Debug("Using generic token info endpoint (GET)") + method = http.MethodGet + authorization = "Bearer " + token + contentType = "application/json" + accept = "application/json" + } + + // 3. Execute request (hardened retries + response size caps). + client := &http.Client{Timeout: 30 * time.Second} + status, _, body, err := oauthhttp.Do( + ctx, + client, + func() (*http.Request, error) { + var bodyReader io.Reader + if useRequestBody { + bodyReader = strings.NewReader(encodedForm) + } + req, errReq := http.NewRequestWithContext(ctx, method, introspectionURL, bodyReader) + if errReq != nil { + return nil, errReq + } + if strings.TrimSpace(authorization) != "" { + req.Header.Set("Authorization", authorization) + } + if strings.TrimSpace(contentType) != "" { + req.Header.Set("Content-Type", contentType) + } + if strings.TrimSpace(accept) != "" { + req.Header.Set("Accept", accept) + } + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { + return nil, fmt.Errorf("token validation request failed: %w", err) + } + if status >= http.StatusBadRequest { + msg := strings.TrimSpace(string(body)) + if msg == "" { + msg = fmt.Sprintf("status %d", status) + } + if err != nil { + return nil, fmt.Errorf("token validation failed: %s: %w", msg, err) + } + return nil, fmt.Errorf("token validation failed: %s", msg) + } + if err != nil { + return nil, fmt.Errorf("token validation request failed: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // 4. Validate response + if isIntrospection { + active, ok := result["active"].(bool) + if !ok || !active { + return nil, fmt.Errorf("token is not active") + } + } + + // 5. Map to Auth struct + auth := &coreauth.Auth{ + Provider: providerID, + Metadata: make(map[string]any), + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + Status: coreauth.StatusActive, + } + + // Extract ID + if id, ok := result[userIDField].(string); ok { + auth.ID = id + } else if id, ok := result["id"].(string); ok { + auth.ID = id // Fallback + } else { + // No stable user identifier available; fall back to a deterministic token fingerprint. + auth.ID = fmt.Sprintf("%s-%s", providerID, tokenFingerprint(token)) + } + + // Extract Email + if email, ok := result[emailField].(string); ok { + auth.Metadata["email"] = email + auth.Label = email + } + + // Store raw response in metadata for flexibility + for k, v := range result { + auth.Metadata[k] = v + } + + return auth, nil +} diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go index fa9f38c3e..ee2678051 100644 --- a/internal/auth/iflow/iflow_auth.go +++ b/internal/auth/iflow/iflow_auth.go @@ -9,10 +9,12 @@ import ( "io" "net/http" "net/url" + "os" "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) @@ -28,10 +30,21 @@ const ( iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey" // Client credentials provided by iFlow for the Code Assist integration. - iFlowOAuthClientID = "10009311001" - iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" + iFlowOAuthClientID = "10009311001" + // Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var) + defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" ) +// getIFlowClientSecret returns the iFlow OAuth client secret. +// It first checks the IFLOW_CLIENT_SECRET environment variable, +// falling back to the default value if not set. +func getIFlowClientSecret() string { + if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" { + return secret + } + return defaultIFlowClientSecret +} + // DefaultAPIBaseURL is the canonical chat completions endpoint. const DefaultAPIBaseURL = "https://apis.iflow.cn/v1" @@ -49,7 +62,7 @@ type IFlowAuth struct { // NewIFlowAuth constructs a new IFlowAuth with proxy-aware transport. func NewIFlowAuth(cfg *config.Config) *IFlowAuth { client := &http.Client{Timeout: 30 * time.Second} - return &IFlowAuth{httpClient: util.SetProxy(&cfg.SDKConfig, client)} + return &IFlowAuth{httpClient: util.SetOAuthProxy(&cfg.SDKConfig, client)} } // AuthorizationURL builds the authorization URL and matching redirect URI. @@ -72,14 +85,8 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR form.Set("code", code) form.Set("redirect_uri", redirectURI) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) - - req, err := ia.newTokenRequest(ctx, form) - if err != nil { - return nil, err - } - - return ia.doTokenRequest(ctx, req) + form.Set("client_secret", getIFlowClientSecret()) + return ia.doTokenRequest(ctx, form) } // RefreshTokens exchanges a refresh token for a new access token. @@ -88,44 +95,47 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I form.Set("grant_type", "refresh_token") form.Set("refresh_token", refreshToken) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) - - req, err := ia.newTokenRequest(ctx, form) - if err != nil { - return nil, err - } - - return ia.doTokenRequest(ctx, req) -} - -func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, iFlowOAuthTokenEndpoint, strings.NewReader(form.Encode())) - if err != nil { - return nil, fmt.Errorf("iflow token: create request failed: %w", err) - } - - basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret)) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", "Basic "+basic) - return req, nil + form.Set("client_secret", getIFlowClientSecret()) + return ia.doTokenRequest(ctx, form) } -func (ia *IFlowAuth) doTokenRequest(ctx context.Context, req *http.Request) (*IFlowTokenData, error) { - resp, err := ia.httpClient.Do(req) - if err != nil { +func (ia *IFlowAuth) doTokenRequest(ctx context.Context, form url.Values) (*IFlowTokenData, error) { + if ctx == nil { + ctx = context.Background() + } + + encoded := form.Encode() + basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret())) + + status, _, body, err := oauthhttp.Do( + ctx, + ia.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, iFlowOAuthTokenEndpoint, strings.NewReader(encoded)) + if err != nil { + return nil, fmt.Errorf("iflow token: create request failed: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Basic "+basic) + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("iflow token: request failed: %w", err) } - defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("iflow token: read response failed: %w", err) + if status != http.StatusOK { + log.Debugf("iflow token request failed: status=%d body=%s", status, string(body)) + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("iflow token: %d %s: %w", status, msg, err) + } + return nil, fmt.Errorf("iflow token: %d %s", status, msg) } - - if resp.StatusCode != http.StatusOK { - log.Debugf("iflow token request failed: status=%d body=%s", resp.StatusCode, string(body)) - return nil, fmt.Errorf("iflow token: %d %s", resp.StatusCode, strings.TrimSpace(string(body))) + if err != nil { + return nil, fmt.Errorf("iflow token: request failed: %w", err) } var tokenResp IFlowTokenResponse @@ -171,28 +181,38 @@ func (ia *IFlowAuth) FetchUserInfo(ctx context.Context, accessToken string) (*us if strings.TrimSpace(accessToken) == "" { return nil, fmt.Errorf("iflow api key: access token is empty") } - - endpoint := fmt.Sprintf("%s?accessToken=%s", iFlowUserInfoEndpoint, url.QueryEscape(accessToken)) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, fmt.Errorf("iflow api key: create request failed: %w", err) + if ctx == nil { + ctx = context.Background() } - req.Header.Set("Accept", "application/json") - resp, err := ia.httpClient.Do(req) - if err != nil { + endpoint := fmt.Sprintf("%s?accessToken=%s", iFlowUserInfoEndpoint, url.QueryEscape(accessToken)) + status, _, body, err := oauthhttp.Do( + ctx, + ia.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("iflow api key: create request failed: %w", err) + } + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("iflow api key: request failed: %w", err) } - defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("iflow api key: read response failed: %w", err) + if status != http.StatusOK { + log.Debugf("iflow api key failed: status=%d body=%s", status, string(body)) + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("iflow api key: %d %s: %w", status, msg, err) + } + return nil, fmt.Errorf("iflow api key: %d %s", status, msg) } - - if resp.StatusCode != http.StatusOK { - log.Debugf("iflow api key failed: status=%d body=%s", resp.StatusCode, string(body)) - return nil, fmt.Errorf("iflow api key: %d %s", resp.StatusCode, strings.TrimSpace(string(body))) + if err != nil { + return nil, fmt.Errorf("iflow api key: request failed: %w", err) } var result userInfoResponse diff --git a/internal/auth/iflow/oauth_provider.go b/internal/auth/iflow/oauth_provider.go new file mode 100644 index 000000000..5a76d318f --- /dev/null +++ b/internal/auth/iflow/oauth_provider.go @@ -0,0 +1,128 @@ +package iflow + +import ( + "context" + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" +) + +// OAuthProvider adapts IFlowAuth to the shared oauthflow.ProviderOAuth interface. +type OAuthProvider struct { + auth *IFlowAuth +} + +func NewOAuthProvider(auth *IFlowAuth) *OAuthProvider { + return &OAuthProvider{auth: auth} +} + +func (p *OAuthProvider) Provider() string { + return "iflow" +} + +func (p *OAuthProvider) AuthorizeURL(session oauthflow.OAuthSession) (string, oauthflow.OAuthSession, error) { + if p == nil || p.auth == nil { + return "", session, fmt.Errorf("iflow oauth provider: auth is nil") + } + redirectURI := strings.TrimSpace(session.RedirectURI) + if redirectURI == "" { + return "", session, fmt.Errorf("iflow oauth provider: redirect URI is empty") + } + parsed, err := url.Parse(redirectURI) + if err != nil { + return "", session, fmt.Errorf("iflow oauth provider: parse redirect URI: %w", err) + } + portStr := parsed.Port() + port, err := strconv.Atoi(portStr) + if err != nil || port <= 0 { + return "", session, fmt.Errorf("iflow oauth provider: invalid redirect URI port: %q", portStr) + } + + authURL, resolvedRedirectURI := p.auth.AuthorizationURL(session.State, port) + session.RedirectURI = resolvedRedirectURI + return authURL, session, nil +} + +func (p *OAuthProvider) ExchangeCode(ctx context.Context, session oauthflow.OAuthSession, code string) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("iflow oauth provider: auth is nil") + } + data, err := p.auth.ExchangeCodeForTokens(ctx, code, session.RedirectURI) + if err != nil { + return nil, err + } + if data == nil { + return nil, fmt.Errorf("iflow oauth provider: token result is nil") + } + + meta := map[string]any{} + if email := strings.TrimSpace(data.Email); email != "" { + meta["email"] = email + } + if apiKey := strings.TrimSpace(data.APIKey); apiKey != "" { + meta["api_key"] = apiKey + } + if scope := strings.TrimSpace(data.Scope); scope != "" { + meta["scope"] = scope + } + + tokenType := strings.TrimSpace(data.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(data.AccessToken), + RefreshToken: strings.TrimSpace(data.RefreshToken), + ExpiresAt: strings.TrimSpace(data.Expire), + TokenType: tokenType, + Metadata: meta, + }, nil +} + +func (p *OAuthProvider) Refresh(ctx context.Context, refreshToken string) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("iflow oauth provider: auth is nil") + } + data, err := p.auth.RefreshTokens(ctx, refreshToken) + if err != nil { + return nil, err + } + if data == nil { + return nil, fmt.Errorf("iflow oauth provider: refresh result is nil") + } + + meta := map[string]any{} + if email := strings.TrimSpace(data.Email); email != "" { + meta["email"] = email + } + if apiKey := strings.TrimSpace(data.APIKey); apiKey != "" { + meta["api_key"] = apiKey + } + if scope := strings.TrimSpace(data.Scope); scope != "" { + meta["scope"] = scope + } + + tokenType := strings.TrimSpace(data.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(data.AccessToken), + RefreshToken: strings.TrimSpace(data.RefreshToken), + ExpiresAt: strings.TrimSpace(data.Expire), + TokenType: tokenType, + Metadata: meta, + }, nil +} + +// Revoke invalidates the given token at the iFlow provider. +// Note: iFlow does not currently provide a public token revocation endpoint, +// so this method returns ErrRevokeNotSupported. +func (p *OAuthProvider) Revoke(ctx context.Context, token string) error { + return oauthflow.ErrRevokeNotSupported +} diff --git a/internal/auth/qwen/device_oauth_provider.go b/internal/auth/qwen/device_oauth_provider.go new file mode 100644 index 000000000..0c80d177a --- /dev/null +++ b/internal/auth/qwen/device_oauth_provider.go @@ -0,0 +1,127 @@ +package qwen + +import ( + "context" + "fmt" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" +) + +// DeviceOAuthProvider adapts Qwen device OAuth to the shared oauthflow.ProviderDeviceOAuth interface. +type DeviceOAuthProvider struct { + auth *QwenAuth +} + +func NewDeviceOAuthProvider(auth *QwenAuth) *DeviceOAuthProvider { + return &DeviceOAuthProvider{auth: auth} +} + +func (p *DeviceOAuthProvider) Provider() string { + return "qwen" +} + +func (p *DeviceOAuthProvider) DeviceAuthorize(ctx context.Context) (*oauthflow.DeviceCodeResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("qwen device oauth provider: auth is nil") + } + flow, err := p.auth.InitiateDeviceFlow(ctx) + if err != nil { + return nil, err + } + if flow == nil { + return nil, fmt.Errorf("qwen device oauth provider: device flow is nil") + } + return &oauthflow.DeviceCodeResult{ + DeviceCode: strings.TrimSpace(flow.DeviceCode), + UserCode: strings.TrimSpace(flow.UserCode), + VerificationURI: strings.TrimSpace(flow.VerificationURI), + VerificationURIComplete: strings.TrimSpace(flow.VerificationURIComplete), + ExpiresIn: flow.ExpiresIn, + Interval: flow.Interval, + CodeVerifier: strings.TrimSpace(flow.CodeVerifier), + }, nil +} + +func (p *DeviceOAuthProvider) DevicePoll(ctx context.Context, device *oauthflow.DeviceCodeResult) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("qwen device oauth provider: auth is nil") + } + if device == nil { + return nil, fmt.Errorf("qwen device oauth provider: device code is nil") + } + + flow := &DeviceFlow{ + DeviceCode: strings.TrimSpace(device.DeviceCode), + UserCode: strings.TrimSpace(device.UserCode), + VerificationURI: strings.TrimSpace(device.VerificationURI), + VerificationURIComplete: strings.TrimSpace(device.VerificationURIComplete), + ExpiresIn: device.ExpiresIn, + Interval: device.Interval, + CodeVerifier: strings.TrimSpace(device.CodeVerifier), + } + + tokenData, err := p.auth.PollForToken(ctx, flow) + if err != nil { + return nil, err + } + if tokenData == nil { + return nil, fmt.Errorf("qwen device oauth provider: token result is nil") + } + + meta := map[string]any{} + if strings.TrimSpace(tokenData.ResourceURL) != "" { + meta["resource_url"] = strings.TrimSpace(tokenData.ResourceURL) + } + + tokenType := strings.TrimSpace(tokenData.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(tokenData.AccessToken), + RefreshToken: strings.TrimSpace(tokenData.RefreshToken), + ExpiresAt: strings.TrimSpace(tokenData.Expire), + TokenType: tokenType, + Metadata: meta, + }, nil +} + +func (p *DeviceOAuthProvider) Refresh(ctx context.Context, refreshToken string) (*oauthflow.TokenResult, error) { + if p == nil || p.auth == nil { + return nil, fmt.Errorf("qwen device oauth provider: auth is nil") + } + refreshToken = strings.TrimSpace(refreshToken) + if refreshToken == "" { + return nil, fmt.Errorf("qwen device oauth provider: refresh token is empty") + } + tokenData, err := p.auth.RefreshTokens(ctx, refreshToken) + if err != nil { + return nil, err + } + if tokenData == nil { + return nil, fmt.Errorf("qwen device oauth provider: refresh result is nil") + } + + meta := map[string]any{} + if strings.TrimSpace(tokenData.ResourceURL) != "" { + meta["resource_url"] = strings.TrimSpace(tokenData.ResourceURL) + } + tokenType := strings.TrimSpace(tokenData.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(tokenData.AccessToken), + RefreshToken: refreshToken, + ExpiresAt: strings.TrimSpace(tokenData.Expire), + TokenType: tokenType, + Metadata: meta, + }, nil +} + +func (p *DeviceOAuthProvider) Revoke(ctx context.Context, token string) error { + return oauthflow.ErrRevokeNotSupported +} diff --git a/internal/auth/qwen/qwen_auth.go b/internal/auth/qwen/qwen_auth.go index cb58b86d3..d0054fe5d 100644 --- a/internal/auth/qwen/qwen_auth.go +++ b/internal/auth/qwen/qwen_auth.go @@ -7,13 +7,14 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io" "net/http" "net/url" "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) @@ -85,7 +86,7 @@ type QwenAuth struct { // NewQwenAuth creates a new QwenAuth instance with a proxy-configured HTTP client. func NewQwenAuth(cfg *config.Config) *QwenAuth { return &QwenAuth{ - httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}), + httpClient: util.SetOAuthProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), } } @@ -116,40 +117,46 @@ func (qa *QwenAuth) generatePKCEPair() (string, string, error) { // RefreshTokens exchanges a refresh token for a new access token. func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) { + if ctx == nil { + ctx = context.Background() + } data := url.Values{} data.Set("grant_type", "refresh_token") data.Set("refresh_token", refreshToken) data.Set("client_id", QwenOAuthClientID) - req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create token request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - resp, err := qa.httpClient.Do(req) - - // resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data) - if err != nil { + encoded := data.Encode() + status, _, body, err := oauthhttp.Do( + ctx, + qa.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, QwenOAuthTokenEndpoint, strings.NewReader(encoded)) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("token refresh request failed: %w", err) } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - if resp.StatusCode != http.StatusOK { + if status != http.StatusOK { var errorData map[string]interface{} if err = json.Unmarshal(body, &errorData); err == nil { return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"]) } - return nil, fmt.Errorf("token refresh failed: %s", string(body)) + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("token refresh failed: %s: %w", msg, err) + } + return nil, fmt.Errorf("token refresh failed: %s", msg) + } + if err != nil { + return nil, fmt.Errorf("token refresh request failed: %w", err) } var tokenData QwenTokenResponse @@ -168,6 +175,9 @@ func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Qw // InitiateDeviceFlow starts the OAuth 2.0 device authorization flow and returns the device flow details. func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) { + if ctx == nil { + ctx = context.Background() + } // Generate PKCE code verifier and challenge codeVerifier, codeChallenge, err := qa.generatePKCEPair() if err != nil { @@ -180,31 +190,34 @@ func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) data.Set("code_challenge", codeChallenge) data.Set("code_challenge_method", "S256") - req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create token request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - resp, err := qa.httpClient.Do(req) - - // resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data) - if err != nil { + encoded := data.Encode() + status, _, body, err := oauthhttp.Do( + ctx, + qa.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, QwenOAuthDeviceCodeEndpoint, strings.NewReader(encoded)) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { return nil, fmt.Errorf("device authorization request failed: %w", err) } - defer func() { - _ = resp.Body.Close() - }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + if status != http.StatusOK { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("device authorization failed: %d. Response: %s: %w", status, msg, err) + } + return nil, fmt.Errorf("device authorization failed: %d. Response: %s", status, msg) } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body)) + if err != nil { + return nil, fmt.Errorf("device authorization request failed: %w", err) } var result DeviceFlow @@ -224,90 +237,128 @@ func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) } // PollForToken polls the token endpoint with the device code to obtain an access token. -func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) { - pollInterval := 5 * time.Second - maxAttempts := 60 // 5 minutes max +func (qa *QwenAuth) PollForToken(ctx context.Context, deviceFlow *DeviceFlow) (*QwenTokenData, error) { + if deviceFlow == nil { + return nil, fmt.Errorf("device flow is nil") + } + deviceCode := strings.TrimSpace(deviceFlow.DeviceCode) + if deviceCode == "" { + return nil, fmt.Errorf("device code is empty") + } + codeVerifier := strings.TrimSpace(deviceFlow.CodeVerifier) + if codeVerifier == "" { + return nil, fmt.Errorf("code verifier is empty") + } + if ctx == nil { + ctx = context.Background() + } - for attempt := 0; attempt < maxAttempts; attempt++ { + device := &oauthflow.DeviceCodeResult{ + DeviceCode: deviceCode, + ExpiresIn: deviceFlow.ExpiresIn, + Interval: deviceFlow.Interval, + CodeVerifier: codeVerifier, + } + + token, err := oauthflow.PollDeviceToken(ctx, device, func(pollCtx context.Context) (*oauthflow.TokenResult, error) { data := url.Values{} data.Set("grant_type", QwenOAuthGrantType) data.Set("client_id", QwenOAuthClientID) data.Set("device_code", deviceCode) data.Set("code_verifier", codeVerifier) - resp, err := http.PostForm(QwenOAuthTokenEndpoint, data) + encoded := data.Encode() + status, _, body, err := oauthhttp.Do( + pollCtx, + qa.httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(pollCtx, http.MethodPost, QwenOAuthTokenEndpoint, strings.NewReader(encoded)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) if err != nil { - fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err) - time.Sleep(pollInterval) - continue + return nil, fmt.Errorf("%w: %w", oauthflow.ErrTransient, err) } - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err) - time.Sleep(pollInterval) - continue + if status == http.StatusOK { + var response QwenTokenResponse + if err = json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + meta := map[string]any{} + if strings.TrimSpace(response.ResourceURL) != "" { + meta["resource_url"] = response.ResourceURL + } + tokenType := strings.TrimSpace(response.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + return &oauthflow.TokenResult{ + AccessToken: response.AccessToken, + RefreshToken: response.RefreshToken, + TokenType: tokenType, + ExpiresAt: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339), + Metadata: meta, + }, nil } - if resp.StatusCode != http.StatusOK { - // Parse the response as JSON to check for OAuth RFC 8628 standard errors - var errorData map[string]interface{} - if err = json.Unmarshal(body, &errorData); err == nil { - // According to OAuth RFC 8628, handle standard polling responses - if resp.StatusCode == http.StatusBadRequest { - errorType, _ := errorData["error"].(string) - switch errorType { - case "authorization_pending": - // User has not yet approved the authorization request. Continue polling. - fmt.Printf("Polling attempt %d/%d...\n\n", attempt+1, maxAttempts) - time.Sleep(pollInterval) - continue - case "slow_down": - // Client is polling too frequently. Increase poll interval. - pollInterval = time.Duration(float64(pollInterval) * 1.5) - if pollInterval > 10*time.Second { - pollInterval = 10 * time.Second - } - fmt.Printf("Server requested to slow down, increasing poll interval to %v\n\n", pollInterval) - time.Sleep(pollInterval) - continue - case "expired_token": - return nil, fmt.Errorf("device code expired. Please restart the authentication process") - case "access_denied": - return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process") - } + // Parse the response as JSON to check for OAuth RFC 8628 standard errors. + var errorData struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + if err = json.Unmarshal(body, &errorData); err == nil { + if status == http.StatusBadRequest { + switch strings.TrimSpace(errorData.Error) { + case "authorization_pending": + return nil, oauthflow.ErrAuthorizationPending + case "slow_down": + return nil, oauthflow.ErrSlowDown + case "expired_token": + return nil, oauthflow.ErrDeviceCodeExpired + case "access_denied": + return nil, oauthflow.ErrAccessDenied } - - // For other errors, return with proper error information - errorType, _ := errorData["error"].(string) - errorDesc, _ := errorData["error_description"].(string) - return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc) } - - // If JSON parsing fails, fall back to text response - return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body)) - } - // log.Debugf("%s", string(body)) - // Success - parse token data - var response QwenTokenResponse - if err = json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("failed to parse token response: %w", err) + if strings.TrimSpace(errorData.Error) != "" { + return nil, fmt.Errorf("device token poll failed: %s - %s", errorData.Error, errorData.ErrorDescription) + } } - // Convert to QwenTokenData format and save - tokenData := &QwenTokenData{ - AccessToken: response.AccessToken, - RefreshToken: response.RefreshToken, - TokenType: response.TokenType, - ResourceURL: response.ResourceURL, - Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339), + trimmed := strings.TrimSpace(string(body)) + if status == http.StatusTooManyRequests || status >= http.StatusInternalServerError { + return nil, fmt.Errorf("%w: status %d: %s", oauthflow.ErrTransient, status, trimmed) } + return nil, fmt.Errorf("device token poll failed: status %d: %s", status, trimmed) + }) + if err != nil { + return nil, err + } + if token == nil { + return nil, fmt.Errorf("token result is nil") + } - return tokenData, nil + tokenData := &QwenTokenData{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + TokenType: token.TokenType, + Expire: token.ExpiresAt, + } + if token.Metadata != nil { + if raw, ok := token.Metadata["resource_url"]; ok { + if val, okStr := raw.(string); okStr { + tokenData.ResourceURL = strings.TrimSpace(val) + } + } } - return nil, fmt.Errorf("authentication timeout. Please restart the authentication process") + return tokenData, nil } // RefreshTokensWithRetry attempts to refresh tokens with a specified number of retries upon failure. diff --git a/internal/auth/refresher/worker.go b/internal/auth/refresher/worker.go new file mode 100644 index 000000000..725091222 --- /dev/null +++ b/internal/auth/refresher/worker.go @@ -0,0 +1,337 @@ +// Package refresher provides a background worker for proactive OAuth token refresh. +// It monitors registered tokens and refreshes them before expiry to prevent +// authentication failures during active use. +package refresher + +import ( + "context" + "errors" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// DefaultRefreshLeadTime is the default duration before expiry to trigger refresh. +const DefaultRefreshLeadTime = 10 * time.Minute + +// DefaultCheckInterval is the default interval between refresh checks. +const DefaultCheckInterval = 1 * time.Minute + +// ErrWorkerStopped is returned when operations are attempted on a stopped worker. +var ErrWorkerStopped = errors.New("refresher: worker is stopped") + +// Token represents a refreshable OAuth token. +type Token struct { + // ID is a unique identifier for this token (e.g., auth ID). + ID string + + // Provider identifies the OAuth provider (e.g., "claude", "codex"). + Provider string + + // RefreshToken is the OAuth refresh token. + RefreshToken string + + // ExpiresAt is the access token expiration time. + ExpiresAt time.Time + + // LastRefresh is the timestamp of the last refresh attempt. + LastRefresh time.Time + + // RefreshError captures the last refresh error, if any. + RefreshError error +} + +// NeedsRefresh checks if the token should be refreshed based on lead time. +func (t *Token) NeedsRefresh(leadTime time.Duration) bool { + if t == nil || t.RefreshToken == "" { + return false + } + if t.ExpiresAt.IsZero() { + return false + } + return time.Now().Add(leadTime).After(t.ExpiresAt) +} + +// Refresher is a function that performs the actual token refresh. +// It should return the new expiration time on success. +type Refresher func(ctx context.Context, token *Token) (newExpiresAt time.Time, err error) + +// Hook provides callbacks for refresh events. +type Hook interface { + // OnRefreshSuccess is called when a token is successfully refreshed. + OnRefreshSuccess(token *Token, newExpiresAt time.Time) + + // OnRefreshError is called when a token refresh fails. + OnRefreshError(token *Token, err error) +} + +// NoopHook is a no-op implementation of Hook. +type NoopHook struct{} + +func (NoopHook) OnRefreshSuccess(*Token, time.Time) {} +func (NoopHook) OnRefreshError(*Token, error) {} + +// Config configures the refresh worker. +type Config struct { + // RefreshLeadTime is how far before expiry to trigger refresh. + // Default: 10 minutes. + RefreshLeadTime time.Duration + + // CheckInterval is how often to check for tokens needing refresh. + // Default: 1 minute. + CheckInterval time.Duration + + // MaxConcurrency limits concurrent refresh operations. + // Default: 5. + MaxConcurrency int + + // RetryDelay is the delay before retrying a failed refresh. + // Default: 5 minutes. + RetryDelay time.Duration +} + +// DefaultConfig returns a Config with sensible defaults. +func DefaultConfig() Config { + return Config{ + RefreshLeadTime: DefaultRefreshLeadTime, + CheckInterval: DefaultCheckInterval, + MaxConcurrency: 5, + RetryDelay: 5 * time.Minute, + } +} + +// Worker manages background token refresh operations. +type Worker struct { + config Config + refresher Refresher + hook Hook + + mu sync.RWMutex + tokens map[string]*Token + running bool + cancel context.CancelFunc + wg sync.WaitGroup +} + +// NewWorker creates a new refresh worker. +func NewWorker(refresher Refresher, config Config, hook Hook) *Worker { + if refresher == nil { + panic("refresher: refresher function cannot be nil") + } + if hook == nil { + hook = NoopHook{} + } + if config.RefreshLeadTime <= 0 { + config.RefreshLeadTime = DefaultRefreshLeadTime + } + if config.CheckInterval <= 0 { + config.CheckInterval = DefaultCheckInterval + } + if config.MaxConcurrency <= 0 { + config.MaxConcurrency = 5 + } + if config.RetryDelay <= 0 { + config.RetryDelay = 5 * time.Minute + } + + return &Worker{ + config: config, + refresher: refresher, + hook: hook, + tokens: make(map[string]*Token), + } +} + +// Start begins the background refresh loop. +func (w *Worker) Start() { + w.mu.Lock() + if w.running { + w.mu.Unlock() + return + } + w.running = true + ctx, cancel := context.WithCancel(context.Background()) + w.cancel = cancel + w.mu.Unlock() + + w.wg.Add(1) + go w.loop(ctx) +} + +// Stop gracefully stops the worker, waiting for in-flight refreshes. +func (w *Worker) Stop() { + w.mu.Lock() + if !w.running { + w.mu.Unlock() + return + } + w.running = false + if w.cancel != nil { + w.cancel() + } + w.mu.Unlock() + + w.wg.Wait() +} + +// Register adds or updates a token for monitoring. +func (w *Worker) Register(token *Token) error { + if token == nil || token.ID == "" { + return nil + } + w.mu.Lock() + defer w.mu.Unlock() + if !w.running { + return ErrWorkerStopped + } + + // Clone to avoid external mutation + clone := &Token{ + ID: token.ID, + Provider: token.Provider, + RefreshToken: token.RefreshToken, + ExpiresAt: token.ExpiresAt, + LastRefresh: token.LastRefresh, + RefreshError: token.RefreshError, + } + w.tokens[token.ID] = clone + return nil +} + +// Unregister removes a token from monitoring. +func (w *Worker) Unregister(tokenID string) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.tokens, tokenID) +} + +// Get returns a copy of the token state, or nil if not found. +func (w *Worker) Get(tokenID string) *Token { + w.mu.RLock() + defer w.mu.RUnlock() + t, ok := w.tokens[tokenID] + if !ok || t == nil { + return nil + } + return &Token{ + ID: t.ID, + Provider: t.Provider, + RefreshToken: t.RefreshToken, + ExpiresAt: t.ExpiresAt, + LastRefresh: t.LastRefresh, + RefreshError: t.RefreshError, + } +} + +// TokenCount returns the number of registered tokens. +func (w *Worker) TokenCount() int { + w.mu.RLock() + defer w.mu.RUnlock() + return len(w.tokens) +} + +func (w *Worker) loop(ctx context.Context) { + defer w.wg.Done() + + ticker := time.NewTicker(w.config.CheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + w.checkAndRefresh(ctx) + } + } +} + +func (w *Worker) checkAndRefresh(ctx context.Context) { + // Get tokens that need refresh + var candidates []*Token + now := time.Now() + + w.mu.RLock() + for _, t := range w.tokens { + if t == nil || t.RefreshToken == "" { + continue + } + // Skip if recently tried and failed + if t.RefreshError != nil && now.Sub(t.LastRefresh) < w.config.RetryDelay { + continue + } + if t.NeedsRefresh(w.config.RefreshLeadTime) { + candidates = append(candidates, &Token{ + ID: t.ID, + Provider: t.Provider, + RefreshToken: t.RefreshToken, + ExpiresAt: t.ExpiresAt, + LastRefresh: t.LastRefresh, + }) + } + } + w.mu.RUnlock() + + if len(candidates) == 0 { + return + } + + log.Debugf("refresher: %d token(s) need refresh", len(candidates)) + + // Limit concurrency + sem := make(chan struct{}, w.config.MaxConcurrency) + var wg sync.WaitGroup + + for _, token := range candidates { + select { + case <-ctx.Done(): + return + case sem <- struct{}{}: + } + + wg.Add(1) + go func(t *Token) { + defer wg.Done() + defer func() { <-sem }() + w.refreshToken(ctx, t) + }(token) + } + + wg.Wait() +} + +func (w *Worker) refreshToken(ctx context.Context, token *Token) { + if token == nil { + return + } + + log.Debugf("refresher: refreshing token %s (%s)", token.ID, token.Provider) + + newExpiry, err := w.refresher(ctx, token) + now := time.Now() + + w.mu.Lock() + t, exists := w.tokens[token.ID] + if !exists || t == nil { + w.mu.Unlock() + return + } + + t.LastRefresh = now + + if err != nil { + t.RefreshError = err + w.mu.Unlock() + log.Warnf("refresher: failed to refresh token %s: %v", token.ID, err) + w.hook.OnRefreshError(token, err) + return + } + + t.ExpiresAt = newExpiry + t.RefreshError = nil + w.mu.Unlock() + + log.Infof("refresher: successfully refreshed token %s (expires: %s)", token.ID, newExpiry.Format(time.RFC3339)) + w.hook.OnRefreshSuccess(token, newExpiry) +} diff --git a/internal/auth/refresher/worker_test.go b/internal/auth/refresher/worker_test.go new file mode 100644 index 000000000..25f68d840 --- /dev/null +++ b/internal/auth/refresher/worker_test.go @@ -0,0 +1,328 @@ +package refresher + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" +) + +func TestToken_NeedsRefresh(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + token *Token + leadTime time.Duration + want bool + }{ + { + name: "nil token", + token: nil, + leadTime: 10 * time.Minute, + want: false, + }, + { + name: "no refresh token", + token: &Token{ID: "1", ExpiresAt: now.Add(5 * time.Minute)}, + leadTime: 10 * time.Minute, + want: false, + }, + { + name: "zero expiry", + token: &Token{ID: "1", RefreshToken: "refresh"}, + leadTime: 10 * time.Minute, + want: false, + }, + { + name: "expires within lead time", + token: &Token{ID: "1", RefreshToken: "refresh", ExpiresAt: now.Add(5 * time.Minute)}, + leadTime: 10 * time.Minute, + want: true, + }, + { + name: "expires after lead time", + token: &Token{ID: "1", RefreshToken: "refresh", ExpiresAt: now.Add(30 * time.Minute)}, + leadTime: 10 * time.Minute, + want: false, + }, + { + name: "already expired", + token: &Token{ID: "1", RefreshToken: "refresh", ExpiresAt: now.Add(-1 * time.Minute)}, + leadTime: 10 * time.Minute, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.token.NeedsRefresh(tt.leadTime); got != tt.want { + t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWorker_RegisterUnregister(t *testing.T) { + refresher := func(ctx context.Context, token *Token) (time.Time, error) { + return time.Now().Add(1 * time.Hour), nil + } + + w := NewWorker(refresher, DefaultConfig(), nil) + w.Start() + defer w.Stop() + + token := &Token{ + ID: "test-1", + Provider: "test", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(30 * time.Minute), + } + + // Register + if err := w.Register(token); err != nil { + t.Fatalf("Register() error = %v", err) + } + if w.TokenCount() != 1 { + t.Errorf("TokenCount() = %d, want 1", w.TokenCount()) + } + + // Get + got := w.Get("test-1") + if got == nil { + t.Fatal("Get() returned nil") + } + if got.ID != token.ID { + t.Errorf("Get().ID = %v, want %v", got.ID, token.ID) + } + + // Unregister + w.Unregister("test-1") + if w.TokenCount() != 0 { + t.Errorf("TokenCount() after unregister = %d, want 0", w.TokenCount()) + } + if got := w.Get("test-1"); got != nil { + t.Errorf("Get() after unregister = %v, want nil", got) + } +} + +func TestWorker_RefreshTriggered(t *testing.T) { + var refreshCount atomic.Int32 + newExpiry := time.Now().Add(2 * time.Hour) + + refresher := func(ctx context.Context, token *Token) (time.Time, error) { + refreshCount.Add(1) + return newExpiry, nil + } + + config := Config{ + RefreshLeadTime: 30 * time.Minute, + CheckInterval: 50 * time.Millisecond, + MaxConcurrency: 5, + RetryDelay: 1 * time.Second, + } + + var hookCalled atomic.Bool + hook := &testHook{ + onSuccess: func(token *Token, exp time.Time) { + hookCalled.Store(true) + }, + } + + w := NewWorker(refresher, config, hook) + w.Start() + + // Register token that needs refresh (expires in 10 minutes, lead time is 30 minutes) + token := &Token{ + ID: "test-refresh", + Provider: "test", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(10 * time.Minute), + } + if err := w.Register(token); err != nil { + t.Fatalf("Register() error = %v", err) + } + + // Wait for refresh to be triggered + time.Sleep(200 * time.Millisecond) + w.Stop() + + if refreshCount.Load() == 0 { + t.Error("refresh was not triggered") + } + if !hookCalled.Load() { + t.Error("hook.OnRefreshSuccess was not called") + } + + // Check updated expiry + got := w.Get("test-refresh") + if got == nil { + t.Fatal("Get() returned nil after refresh") + } + // The expiry should be updated + if got.ExpiresAt.Before(time.Now().Add(1 * time.Hour)) { + t.Errorf("expiry was not updated: got %v", got.ExpiresAt) + } +} + +func TestWorker_RefreshError(t *testing.T) { + refreshErr := errors.New("refresh failed") + var refreshCount atomic.Int32 + + refresher := func(ctx context.Context, token *Token) (time.Time, error) { + refreshCount.Add(1) + return time.Time{}, refreshErr + } + + config := Config{ + RefreshLeadTime: 30 * time.Minute, + CheckInterval: 50 * time.Millisecond, + MaxConcurrency: 5, + RetryDelay: 10 * time.Second, // Long retry delay + } + + var errorHookCalled atomic.Bool + hook := &testHook{ + onError: func(token *Token, err error) { + errorHookCalled.Store(true) + }, + } + + w := NewWorker(refresher, config, hook) + w.Start() + + token := &Token{ + ID: "test-error", + Provider: "test", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(10 * time.Minute), + } + if err := w.Register(token); err != nil { + t.Fatalf("Register() error = %v", err) + } + + // Wait for first refresh attempt + time.Sleep(200 * time.Millisecond) + w.Stop() + + if refreshCount.Load() == 0 { + t.Error("refresh was not attempted") + } + if !errorHookCalled.Load() { + t.Error("hook.OnRefreshError was not called") + } + + // Check error is recorded + got := w.Get("test-error") + if got == nil { + t.Fatal("Get() returned nil") + } + if got.RefreshError == nil { + t.Error("RefreshError was not recorded") + } +} + +func TestWorker_ConcurrencyLimit(t *testing.T) { + var concurrent atomic.Int32 + var maxConcurrent atomic.Int32 + + refresher := func(ctx context.Context, token *Token) (time.Time, error) { + current := concurrent.Add(1) + // Track max concurrent + for { + old := maxConcurrent.Load() + if current <= old || maxConcurrent.CompareAndSwap(old, current) { + break + } + } + time.Sleep(50 * time.Millisecond) + concurrent.Add(-1) + return time.Now().Add(1 * time.Hour), nil + } + + config := Config{ + RefreshLeadTime: 30 * time.Minute, + CheckInterval: 10 * time.Millisecond, + MaxConcurrency: 2, + RetryDelay: 1 * time.Second, + } + + w := NewWorker(refresher, config, nil) + w.Start() + + // Register multiple tokens that all need refresh + for i := 0; i < 10; i++ { + token := &Token{ + ID: "test-" + string(rune('0'+i)), + Provider: "test", + RefreshToken: "refresh", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + _ = w.Register(token) + } + + // Wait for refreshes + time.Sleep(500 * time.Millisecond) + w.Stop() + + if maxConcurrent.Load() > 2 { + t.Errorf("max concurrent = %d, want <= 2", maxConcurrent.Load()) + } +} + +func TestWorker_StopWaitsForRefreshes(t *testing.T) { + var refreshStarted atomic.Bool + var refreshCompleted atomic.Bool + + refresher := func(ctx context.Context, token *Token) (time.Time, error) { + refreshStarted.Store(true) + time.Sleep(100 * time.Millisecond) + refreshCompleted.Store(true) + return time.Now().Add(1 * time.Hour), nil + } + + config := Config{ + RefreshLeadTime: 30 * time.Minute, + CheckInterval: 10 * time.Millisecond, + MaxConcurrency: 5, + RetryDelay: 1 * time.Second, + } + + w := NewWorker(refresher, config, nil) + w.Start() + + token := &Token{ + ID: "test-stop", + Provider: "test", + RefreshToken: "refresh", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + _ = w.Register(token) + + // Wait for refresh to start + time.Sleep(50 * time.Millisecond) + + // Stop should wait for refresh to complete + w.Stop() + + // Note: Due to timing, the refresh might not have started yet + // So we just verify that Stop() returns without hanging +} + +type testHook struct { + onSuccess func(*Token, time.Time) + onError func(*Token, error) +} + +func (h *testHook) OnRefreshSuccess(token *Token, newExpiresAt time.Time) { + if h.onSuccess != nil { + h.onSuccess(token, newExpiresAt) + } +} + +func (h *testHook) OnRefreshError(token *Token, err error) { + if h.onError != nil { + h.onError(token, err) + } +} diff --git a/internal/config/generic_auth.go b/internal/config/generic_auth.go new file mode 100644 index 000000000..75e882b54 --- /dev/null +++ b/internal/config/generic_auth.go @@ -0,0 +1,24 @@ +package config + +// GenericAuth configures validation of bearer tokens against a generic OAuth2 +// introspection or userinfo endpoint. +// +// This is used by internal/auth/generic helpers and is intentionally small; most +// deployments will configure it via the access-provider system. +type GenericAuth struct { + // ProviderID is the provider key associated with this auth entry (e.g. "generic"). + ProviderID string `yaml:"provider-id,omitempty" json:"provider_id,omitempty"` + + // IntrospectionURL is the token introspection/userinfo endpoint URL. + IntrospectionURL string `yaml:"introspection-url,omitempty" json:"introspection_url,omitempty"` + + // ClientID and ClientSecret are optional client credentials for RFC 7662 introspection endpoints. + ClientID string `yaml:"client-id,omitempty" json:"client_id,omitempty"` + ClientSecret string `yaml:"client-secret,omitempty" json:"client_secret,omitempty"` + + // UserIDField is the JSON field to use as the stable user identifier (default: "sub"). + UserIDField string `yaml:"user-id-field,omitempty" json:"user_id_field,omitempty"` + + // EmailField is the JSON field to use as the user email (default: "email"). + EmailField string `yaml:"email-field,omitempty" json:"email_field,omitempty"` +} diff --git a/internal/oauthflow/auth_code_flow.go b/internal/oauthflow/auth_code_flow.go new file mode 100644 index 000000000..9c926ada8 --- /dev/null +++ b/internal/oauthflow/auth_code_flow.go @@ -0,0 +1,167 @@ +package oauthflow + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "net/http" + "strings" + "syscall" + "time" +) + +// AuthCodeFlowOptions controls loopback-based OAuth authorization code flows. +type AuthCodeFlowOptions struct { + DesiredPort int + CallbackPath string + Timeout time.Duration + SkipStateCheck bool + + // OnAuthURL is called after the callback server is started and the provider auth URL is built. + // Callers typically open the browser and/or print instructions here. + OnAuthURL func(authURL string, callbackPort int, redirectURI string) +} + +// AuthCodeFlowResult captures the output of an authorization code flow. +type AuthCodeFlowResult struct { + AuthURL string + RedirectURI string + CallbackPort int + Session OAuthSession + Token *TokenResult + CallbackError string +} + +// GenerateState returns a cryptographically secure random state string. +func GenerateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// GeneratePKCE returns RFC 7636 PKCE verifier/challenge values. +func GeneratePKCE() (verifier, challenge string, err error) { + // 96 random bytes -> 128 base64url chars without padding (same as existing provider implementations). + b := make([]byte, 96) + if _, err := rand.Read(b); err != nil { + return "", "", err + } + verifier = base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b) + hash := sha256.Sum256([]byte(verifier)) + challenge = base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:]) + return verifier, challenge, nil +} + +// RunAuthCodeFlow runs a loopback OAuth authorization code flow (RFC 8252). +// It starts a loopback-only callback server, builds the provider authorization URL, waits for the callback, +// validates state, then exchanges the code for tokens. +func RunAuthCodeFlow(ctx context.Context, provider ProviderOAuth, opts AuthCodeFlowOptions) (*AuthCodeFlowResult, error) { + if provider == nil { + return nil, fmt.Errorf("oauthflow: provider is nil") + } + if ctx == nil { + ctx = context.Background() + } + if opts.Timeout <= 0 { + opts.Timeout = 5 * time.Minute + } + + desiredPort := opts.DesiredPort + server := NewLoopbackServer(desiredPort, opts.CallbackPath) + if err := server.Start(); err != nil { + // Port in use: fall back to port 0 when a non-zero port was requested. + if errors.Is(err, syscall.EADDRINUSE) && desiredPort != 0 { + server = NewLoopbackServer(0, opts.CallbackPath) + if err2 := server.Start(); err2 != nil { + return nil, &FlowError{Kind: FlowErrorKindPortInUse, Err: err2} + } + } else if errors.Is(err, syscall.EADDRINUSE) { + return nil, &FlowError{Kind: FlowErrorKindPortInUse, Err: err} + } else { + return nil, &FlowError{Kind: FlowErrorKindServerStartFailed, Err: err} + } + } + defer func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Stop(stopCtx) + }() + + callbackPort := server.Port() + callbackPath := server.CallbackPath() + redirectURI := fmt.Sprintf("http://localhost:%d%s", callbackPort, callbackPath) + + state, err := GenerateState() + if err != nil { + return nil, fmt.Errorf("oauthflow: state generation failed: %w", err) + } + verifier, challenge, err := GeneratePKCE() + if err != nil { + return nil, fmt.Errorf("oauthflow: pkce generation failed: %w", err) + } + session := OAuthSession{ + State: state, + RedirectURI: redirectURI, + CodeVerifier: verifier, + CodeChallenge: challenge, + } + + authURL, session, err := provider.AuthorizeURL(session) + if err != nil { + return nil, &FlowError{Kind: FlowErrorKindAuthorizeURLFailed, Err: err} + } + + if opts.OnAuthURL != nil { + opts.OnAuthURL(authURL, callbackPort, redirectURI) + } + + cb, err := server.WaitForCallback(ctx, opts.Timeout) + if err != nil { + if errors.Is(err, ErrCallbackTimeout) { + return nil, &FlowError{Kind: FlowErrorKindCallbackTimeout, Err: err} + } + return nil, err + } + + if cb.Error != "" { + return &AuthCodeFlowResult{ + AuthURL: authURL, + RedirectURI: redirectURI, + CallbackPort: callbackPort, + Session: session, + CallbackError: cb.Error, + }, &FlowError{Kind: FlowErrorKindProviderError, Err: fmt.Errorf("%s", cb.Error)} + } + + if !opts.SkipStateCheck && strings.TrimSpace(session.State) != "" && cb.State != session.State { + return nil, &FlowError{Kind: FlowErrorKindInvalidState, Err: fmt.Errorf("state mismatch")} + } + + code := strings.TrimSpace(cb.Code) + if code == "" { + return nil, &FlowError{Kind: FlowErrorKindProviderError, Err: fmt.Errorf("missing authorization code")} + } + + token, err := provider.ExchangeCode(ctx, session, code) + if err != nil { + return nil, &FlowError{Kind: FlowErrorKindCodeExchangeFailed, Err: err} + } + if token != nil && strings.TrimSpace(token.TokenType) == "" { + token.TokenType = http.CanonicalHeaderKey("Bearer") + } + + return &AuthCodeFlowResult{ + AuthURL: authURL, + RedirectURI: redirectURI, + CallbackPort: callbackPort, + Session: session, + Token: token, + }, nil +} + diff --git a/internal/oauthflow/device_flow.go b/internal/oauthflow/device_flow.go new file mode 100644 index 000000000..15762b7c2 --- /dev/null +++ b/internal/oauthflow/device_flow.go @@ -0,0 +1,107 @@ +package oauthflow + +import ( + "context" + "errors" + "fmt" + "net" + "time" +) + +var ( + // ErrAuthorizationPending indicates the user has not completed authorization yet. + ErrAuthorizationPending = errors.New("oauthflow: authorization_pending") + // ErrSlowDown indicates the authorization server asked the client to poll less frequently. + ErrSlowDown = errors.New("oauthflow: slow_down") + // ErrDeviceCodeExpired indicates the device code expired before the user completed authorization. + ErrDeviceCodeExpired = errors.New("oauthflow: device_code_expired") + // ErrAccessDenied indicates the user denied the authorization request. + ErrAccessDenied = errors.New("oauthflow: access_denied") + // ErrPollingTimeout indicates polling exceeded the device code lifetime or an internal timeout. + ErrPollingTimeout = errors.New("oauthflow: polling_timeout") + // ErrTransient indicates a retryable/transient polling failure. + ErrTransient = errors.New("oauthflow: transient") +) + +const ( + defaultDevicePollInterval = 5 * time.Second + minDevicePollInterval = 1 * time.Second + maxDevicePollInterval = 10 * time.Second + maxDevicePollDuration = 15 * time.Minute +) + +// PollDeviceToken runs an RFC 8628-style polling loop. +// +// pollOnce must return: +// - (*TokenResult, nil) on success +// - ErrAuthorizationPending / ErrSlowDown to keep polling +// - ErrDeviceCodeExpired / ErrAccessDenied to abort +// - ErrTransient or a net.Error to keep polling (best-effort) +func PollDeviceToken(ctx context.Context, device *DeviceCodeResult, pollOnce func(context.Context) (*TokenResult, error)) (*TokenResult, error) { + if device == nil { + return nil, fmt.Errorf("oauthflow: device code is nil") + } + if pollOnce == nil { + return nil, fmt.Errorf("oauthflow: pollOnce is nil") + } + if ctx == nil { + ctx = context.Background() + } + + interval := time.Duration(device.Interval) * time.Second + if interval <= 0 { + interval = defaultDevicePollInterval + } + if interval < minDevicePollInterval { + interval = minDevicePollInterval + } + + deadline := time.Now().Add(maxDevicePollDuration) + if device.ExpiresIn > 0 { + expiresAt := time.Now().Add(time.Duration(device.ExpiresIn) * time.Second) + if expiresAt.Before(deadline) { + deadline = expiresAt + } + } + + for { + if err := ctx.Err(); err != nil { + return nil, err + } + if time.Now().After(deadline) { + return nil, ErrPollingTimeout + } + + token, err := pollOnce(ctx) + if err == nil { + return token, nil + } + + switch { + case errors.Is(err, ErrAuthorizationPending): + // keep interval unchanged + case errors.Is(err, ErrSlowDown): + interval += 5 * time.Second + if interval > maxDevicePollInterval { + interval = maxDevicePollInterval + } + case errors.Is(err, ErrDeviceCodeExpired), errors.Is(err, ErrAccessDenied): + return nil, err + default: + // Best-effort: keep polling on transient transport failures. + var ne net.Error + if errors.Is(err, ErrTransient) || errors.As(err, &ne) { + // keep polling + } else { + return nil, err + } + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + } +} + diff --git a/internal/oauthflow/errors.go b/internal/oauthflow/errors.go new file mode 100644 index 000000000..fc337df08 --- /dev/null +++ b/internal/oauthflow/errors.go @@ -0,0 +1,41 @@ +package oauthflow + +import ( + "fmt" +) + +// FlowErrorKind categorizes failures in OAuth flows so callers can map them to provider-specific errors. +type FlowErrorKind string + +const ( + FlowErrorKindPortInUse FlowErrorKind = "port_in_use" + FlowErrorKindServerStartFailed FlowErrorKind = "server_start_failed" + FlowErrorKindAuthorizeURLFailed FlowErrorKind = "authorize_url_failed" + FlowErrorKindCallbackTimeout FlowErrorKind = "callback_timeout" + FlowErrorKindProviderError FlowErrorKind = "provider_error" + FlowErrorKindInvalidState FlowErrorKind = "invalid_state" + FlowErrorKindCodeExchangeFailed FlowErrorKind = "code_exchange_failed" +) + +// FlowError wraps an underlying error with a stable kind for callers to inspect. +type FlowError struct { + Kind FlowErrorKind + Err error +} + +func (e *FlowError) Error() string { + if e == nil { + return "" + } + if e.Err == nil { + return fmt.Sprintf("oauthflow: %s", e.Kind) + } + return fmt.Sprintf("oauthflow: %s: %v", e.Kind, e.Err) +} + +func (e *FlowError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} diff --git a/internal/oauthflow/interfaces.go b/internal/oauthflow/interfaces.go new file mode 100644 index 000000000..855a1b428 --- /dev/null +++ b/internal/oauthflow/interfaces.go @@ -0,0 +1,63 @@ +package oauthflow + +import ( + "context" + "errors" +) + +// ErrRefreshNotSupported is returned when a provider does not support token refresh. +var ErrRefreshNotSupported = errors.New("oauthflow: refresh not supported") + +// ErrRevokeNotSupported is returned when a provider does not support token revocation. +var ErrRevokeNotSupported = errors.New("oauthflow: revoke not supported") + +// TokenResult is a provider-agnostic OAuth token payload. +type TokenResult struct { + AccessToken string + RefreshToken string + ExpiresAt string // RFC3339 when available + TokenType string + IDToken string + Metadata map[string]any +} + +// OAuthSession contains state and PKCE values for browser-based OAuth flows. +type OAuthSession struct { + State string + RedirectURI string + CodeVerifier string + CodeChallenge string +} + +// DeviceCodeResult captures a device-code authorization response. +type DeviceCodeResult struct { + DeviceCode string + UserCode string + VerificationURI string + VerificationURIComplete string + ExpiresIn int + Interval int + CodeVerifier string +} + +// ProviderOAuth describes a browser-based (authorization-code) OAuth provider. +type ProviderOAuth interface { + Provider() string + AuthorizeURL(session OAuthSession) (authURL string, updated OAuthSession, err error) + ExchangeCode(ctx context.Context, session OAuthSession, code string) (*TokenResult, error) + Refresh(ctx context.Context, refreshToken string) (*TokenResult, error) + // Revoke invalidates the given token (access or refresh) at the provider. + // Returns ErrRevokeNotSupported if the provider does not support revocation. + Revoke(ctx context.Context, token string) error +} + +// ProviderDeviceOAuth describes a device-code OAuth provider. +type ProviderDeviceOAuth interface { + Provider() string + DeviceAuthorize(ctx context.Context) (*DeviceCodeResult, error) + DevicePoll(ctx context.Context, device *DeviceCodeResult) (*TokenResult, error) + Refresh(ctx context.Context, refreshToken string) (*TokenResult, error) + // Revoke invalidates the given token (access or refresh) at the provider. + // Returns ErrRevokeNotSupported if the provider does not support revocation. + Revoke(ctx context.Context, token string) error +} diff --git a/internal/oauthflow/loopback_server.go b/internal/oauthflow/loopback_server.go new file mode 100644 index 000000000..7434be27b --- /dev/null +++ b/internal/oauthflow/loopback_server.go @@ -0,0 +1,202 @@ +package oauthflow + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "strings" + "sync" + "syscall" + "time" +) + +// ErrCallbackTimeout is returned when the local loopback server does not receive a callback in time. +var ErrCallbackTimeout = errors.New("oauthflow: callback timeout") + +// CallbackResult contains query parameters returned to the redirect URI. +type CallbackResult struct { + Code string + State string + Error string +} + +// LoopbackServer is a loopback-only HTTP server for OAuth native-app callbacks. +// It binds to 127.0.0.1 and captures a single callback result. +type LoopbackServer struct { + server *http.Server + listener net.Listener + port int + callbackPath string + + resultChan chan CallbackResult + errorChan chan error + + mu sync.Mutex + running bool +} + +// NewLoopbackServer creates a new loopback callback server. +// callbackPath must be an absolute path (e.g., "/callback"). +func NewLoopbackServer(port int, callbackPath string) *LoopbackServer { + path := strings.TrimSpace(callbackPath) + if path == "" { + path = "/callback" + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return &LoopbackServer{ + port: port, + callbackPath: path, + resultChan: make(chan CallbackResult, 1), + errorChan: make(chan error, 1), + } +} + +// CallbackPath returns the server callback path. +func (s *LoopbackServer) CallbackPath() string { + if s == nil { + return "" + } + s.mu.Lock() + defer s.mu.Unlock() + return s.callbackPath +} + +// Port returns the actual bound port once the server has started. +func (s *LoopbackServer) Port() int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.port +} + +// Start binds and serves the callback endpoint. +func (s *LoopbackServer) Start() error { + if s == nil { + return fmt.Errorf("oauthflow: server is nil") + } + s.mu.Lock() + defer s.mu.Unlock() + if s.running { + return fmt.Errorf("oauthflow: server already running") + } + + mux := http.NewServeMux() + mux.HandleFunc(s.callbackPath, s.handleCallback) + + addr := fmt.Sprintf("127.0.0.1:%d", s.port) + ln, err := net.Listen("tcp", addr) + if err != nil { + if errors.Is(err, syscall.EADDRINUSE) { + return err + } + return err + } + + s.server = &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + s.listener = ln + if tcp, ok := ln.Addr().(*net.TCPAddr); ok { + s.port = tcp.Port + } + s.running = true + + go func() { + if err := s.server.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + select { + case s.errorChan <- err: + default: + } + } + }() + + return nil +} + +// Stop gracefully shuts down the callback server. +func (s *LoopbackServer) Stop(ctx context.Context) error { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if !s.running || s.server == nil { + return nil + } + + if ctx == nil { + ctx = context.Background() + } + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + err := s.server.Shutdown(shutdownCtx) + if s.listener != nil { + _ = s.listener.Close() + s.listener = nil + } + s.running = false + s.server = nil + return err +} + +// WaitForCallback blocks until a callback result, server error, timeout, or context cancellation. +func (s *LoopbackServer) WaitForCallback(ctx context.Context, timeout time.Duration) (CallbackResult, error) { + if s == nil { + return CallbackResult{}, fmt.Errorf("oauthflow: server is nil") + } + if timeout <= 0 { + timeout = 5 * time.Minute + } + if ctx == nil { + ctx = context.Background() + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case res := <-s.resultChan: + return res, nil + case err := <-s.errorChan: + return CallbackResult{}, err + case <-ctx.Done(): + return CallbackResult{}, ctx.Err() + case <-timer.C: + return CallbackResult{}, ErrCallbackTimeout + } +} + +func (s *LoopbackServer) handleCallback(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + query := r.URL.Query() + res := CallbackResult{ + Code: strings.TrimSpace(query.Get("code")), + State: strings.TrimSpace(query.Get("state")), + Error: strings.TrimSpace(query.Get("error")), + } + + select { + case s.resultChan <- res: + default: + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if res.Error != "" || res.Code == "" { + _, _ = w.Write([]byte("

Login failed

Please check the CLI output.

You can close this window.

")) + return + } + _, _ = w.Write([]byte("

Login successful

You can close this window.

")) +} diff --git a/internal/oauthhttp/oauthhttp.go b/internal/oauthhttp/oauthhttp.go new file mode 100644 index 000000000..958cfbaf7 --- /dev/null +++ b/internal/oauthhttp/oauthhttp.go @@ -0,0 +1,211 @@ +package oauthhttp + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" +) + +// ErrBodyTooLarge is returned when the response body exceeds MaxBodyBytes. +var ErrBodyTooLarge = errors.New("oauthhttp: response body too large") + +// RetryConfig controls retry/backoff behavior for OAuth HTTP calls. +type RetryConfig struct { + // MaxAttempts is the total number of attempts (including the first). + MaxAttempts int + // InitialBackoff is the delay before the first retry. + InitialBackoff time.Duration + // BackoffFactor is the multiplier applied after each attempt. + BackoffFactor float64 + // MaxBackoff caps exponential backoff. + MaxBackoff time.Duration + // RetryOnStatus controls which HTTP status codes are retried. + RetryOnStatus map[int]struct{} + // MaxBodyBytes caps how much of the response body is read into memory (0 = 1MB default). + MaxBodyBytes int64 +} + +// DefaultRetryConfig uses conservative OAuth HTTP defaults: +// max_attempts=3, backoff=0.5s, factor=2, max_backoff=10s, retry on 429 + common 5xx. +func DefaultRetryConfig() RetryConfig { + return RetryConfig{ + MaxAttempts: 3, + InitialBackoff: 500 * time.Millisecond, + BackoffFactor: 2.0, + MaxBackoff: 10 * time.Second, + RetryOnStatus: map[int]struct{}{ + http.StatusTooManyRequests: {}, + http.StatusInternalServerError: {}, + http.StatusBadGateway: {}, + http.StatusServiceUnavailable: {}, + http.StatusGatewayTimeout: {}, + }, + MaxBodyBytes: 1 << 20, // 1MB + } +} + +// Do executes buildReq+client.Do with retry/backoff for transient OAuth failures. +// +// Returns: +// - status: HTTP status code (0 when no HTTP response was received) +// - headers: response headers (nil when no HTTP response was received) +// - body: response body (may be partial if ErrBodyTooLarge) +// - err: network/build errors, ErrBodyTooLarge when body exceeds limit, or a retry exhaustion error when the final +// response status was retryable (e.g. repeated 503s) +func Do( + ctx context.Context, + client *http.Client, + buildReq func() (*http.Request, error), + cfg RetryConfig, +) (status int, headers http.Header, body []byte, err error) { + if client == nil { + return 0, nil, nil, fmt.Errorf("oauthhttp: http client is nil") + } + if buildReq == nil { + return 0, nil, nil, fmt.Errorf("oauthhttp: buildReq is nil") + } + if cfg.MaxAttempts <= 0 { + cfg.MaxAttempts = DefaultRetryConfig().MaxAttempts + } + if cfg.InitialBackoff <= 0 { + cfg.InitialBackoff = DefaultRetryConfig().InitialBackoff + } + if cfg.BackoffFactor <= 0 { + cfg.BackoffFactor = DefaultRetryConfig().BackoffFactor + } + if cfg.MaxBackoff <= 0 { + cfg.MaxBackoff = DefaultRetryConfig().MaxBackoff + } + if cfg.RetryOnStatus == nil { + cfg.RetryOnStatus = DefaultRetryConfig().RetryOnStatus + } + if cfg.MaxBodyBytes <= 0 { + cfg.MaxBodyBytes = DefaultRetryConfig().MaxBodyBytes + } + if ctx == nil { + ctx = context.Background() + } + + backoff := cfg.InitialBackoff + + for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ { + req, errBuild := buildReq() + if errBuild != nil { + return 0, nil, nil, errBuild + } + if req == nil { + return 0, nil, nil, fmt.Errorf("oauthhttp: buildReq returned nil request") + } + if req.Context() == nil { + req = req.WithContext(ctx) + } + + resp, errDo := client.Do(req) + if errDo != nil { + // Network/transport error. + if attempt >= cfg.MaxAttempts { + return 0, nil, nil, errDo + } + if errWait := wait(ctx, backoff); errWait != nil { + return 0, nil, nil, errWait + } + backoff = nextBackoff(backoff, cfg.BackoffFactor, cfg.MaxBackoff) + continue + } + + status = resp.StatusCode + headers = resp.Header.Clone() + body, err = readLimitedAndClose(resp.Body, cfg.MaxBodyBytes) + if err != nil { + // Body read error or size cap. Don't retry; surface for visibility. + return status, headers, body, err + } + + if _, retryable := cfg.RetryOnStatus[status]; !retryable { + return status, headers, body, nil + } + if attempt >= cfg.MaxAttempts { + return status, headers, body, fmt.Errorf("oauthhttp: request failed with status %d after %d attempts", status, cfg.MaxAttempts) + } + + delay := backoff + if ra := retryAfter(headers.Get("Retry-After")); ra > delay { + delay = ra + } + if errWait := wait(ctx, delay); errWait != nil { + return status, headers, body, errWait + } + backoff = nextBackoff(backoff, cfg.BackoffFactor, cfg.MaxBackoff) + } + + return status, headers, body, nil +} + +func nextBackoff(current time.Duration, factor float64, max time.Duration) time.Duration { + if current <= 0 { + current = 500 * time.Millisecond + } + if factor <= 0 { + factor = 2.0 + } + next := time.Duration(float64(current) * factor) + if max > 0 && next > max { + return max + } + return next +} + +func wait(ctx context.Context, d time.Duration) error { + if d <= 0 { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(d): + return nil + } +} + +func readLimitedAndClose(rc io.ReadCloser, maxBytes int64) ([]byte, error) { + if rc == nil { + return nil, nil + } + defer func() { _ = rc.Close() }() + + limit := maxBytes + if limit <= 0 { + limit = 1 << 20 + } + r := io.LimitReader(rc, limit+1) + data, err := io.ReadAll(r) + if err != nil { + return data, err + } + if int64(len(data)) > limit { + return data[:limit], ErrBodyTooLarge + } + return data, nil +} + +func retryAfter(raw string) time.Duration { + s := strings.TrimSpace(raw) + if s == "" { + return 0 + } + if seconds, err := strconv.ParseFloat(s, 64); err == nil && seconds > 0 { + return time.Duration(seconds * float64(time.Second)) + } + if t, err := http.ParseTime(s); err == nil { + d := time.Until(t) + if d > 0 { + return d + } + } + return 0 +} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 2b4ec7482..813b2db6a 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -849,35 +849,40 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c now := time.Now().Unix() modelConfig := registry.GetAntigravityModelConfig() models := make([]*registry.ModelInfo, 0, len(result.Map())) + buildModelInfo := func(id string, cfg *registry.AntigravityModelConfig, useCfgName bool) *registry.ModelInfo { + modelName := id + if useCfgName && cfg != nil && cfg.Name != "" { + modelName = cfg.Name + } + modelInfo := ®istry.ModelInfo{ + ID: id, + Name: modelName, + Description: id, + DisplayName: id, + Version: id, + Object: "model", + Created: now, + OwnedBy: antigravityAuthType, + Type: antigravityAuthType, + } + if cfg != nil { + if cfg.Thinking != nil { + modelInfo.Thinking = cfg.Thinking + } + if cfg.MaxCompletionTokens > 0 { + modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens + } + } + return modelInfo + } for originalName := range result.Map() { aliasName := modelName2Alias(originalName) if aliasName != "" { cfg := modelConfig[aliasName] - modelName := aliasName - if cfg != nil && cfg.Name != "" { - modelName = cfg.Name + models = append(models, buildModelInfo(aliasName, cfg, true)) + if shouldExposeAntigravityOriginalName(originalName, aliasName) { + models = append(models, buildModelInfo(originalName, cfg, false)) } - modelInfo := ®istry.ModelInfo{ - ID: aliasName, - Name: modelName, - Description: aliasName, - DisplayName: aliasName, - Version: aliasName, - Object: "model", - Created: now, - OwnedBy: antigravityAuthType, - Type: antigravityAuthType, - } - // Look up Thinking support from static config using alias name - if cfg != nil { - if cfg.Thinking != nil { - modelInfo.Thinking = cfg.Thinking - } - if cfg.MaxCompletionTokens > 0 { - modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens - } - } - models = append(models, modelInfo) } } return models @@ -1033,6 +1038,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau } else { httpReq.Header.Set("Accept", "application/json") } + if util.IsClaudeThinkingModel(modelName) { + httpReq.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14") + } if host := resolveHost(base); host != "" { httpReq.Host = host } @@ -1278,13 +1286,18 @@ func modelName2Alias(modelName string) string { return "gemini-claude-sonnet-4-5-thinking" case "claude-opus-4-5-thinking": return "gemini-claude-opus-4-5-thinking" - case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro": - return "" default: return modelName } } +func shouldExposeAntigravityOriginalName(originalName, aliasName string) bool { + if originalName == "" || aliasName == "" || originalName == aliasName { + return false + } + return true +} + func alias2ModelName(modelName string) string { switch modelName { case "gemini-2.5-computer-use-preview-10-2025": diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 875e54a71..cead5d186 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -108,7 +108,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq } // Override default values with actual response metadata if available from the Gemini CLI response - if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", requestedModel) + } else if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) } if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { @@ -371,9 +377,6 @@ func resolveStopReason(params *Params) string { // Returns: // - string: A Claude-compatible JSON response. func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { - _ = originalRequestRawJSON - _ = requestRawJSON - root := gjson.ParseBytes(rawJSON) promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int() candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() @@ -390,7 +393,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String()) - responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String()) + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + responseJSON, _ = sjson.Set(responseJSON, "model", requestedModel) + } else { + responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String()) + } responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens) responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens) // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index 1b7866d01..9228c7bb4 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -59,7 +59,13 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` // Extract and set the model version. - if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + template, _ = sjson.Set(template, "model", requestedModel) + } else if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { template, _ = sjson.Set(template, "model", modelVersionResult.String()) } diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_response.go b/internal/translator/codex/openai/chat-completions/codex_openai_response.go index 6d86c247a..584051f27 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_response.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_response.go @@ -8,6 +8,7 @@ package chat_completions import ( "bytes" "context" + "strings" "time" "github.com/tidwall/gjson" @@ -70,7 +71,13 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR } // Extract and set the model version. - if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + template, _ = sjson.Set(template, "model", requestedModel) + } else if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() { template, _ = sjson.Set(template, "model", modelResult.String()) } @@ -178,7 +185,13 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` // Extract and set the model version. - if modelResult := responseResult.Get("model"); modelResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + template, _ = sjson.Set(template, "model", requestedModel) + } else if modelResult := responseResult.Get("model"); modelResult.Exists() { template, _ = sjson.Set(template, "model", modelResult.String()) } diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index 2f8e95488..ff524ae03 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -80,7 +80,13 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` // Override default values with actual response metadata if available from the Gemini CLI response - if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", requestedModel) + } else if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) } if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { @@ -270,14 +276,19 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Returns: // - string: A Claude-compatible JSON response. func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { - _ = originalRequestRawJSON - _ = requestRawJSON - root := gjson.ParseBytes(rawJSON) out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out, _ = sjson.Set(out, "id", root.Get("response.responseId").String()) - out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String()) + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + out, _ = sjson.Set(out, "model", requestedModel) + } else { + out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String()) + } inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int() outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int() diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go index 5a1faf510..1aa9ff035 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go @@ -57,7 +57,13 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` // Extract and set the model version. - if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + template, _ = sjson.Set(template, "model", requestedModel) + } else if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { template, _ = sjson.Set(template, "model", modelVersionResult.String()) } diff --git a/internal/translator/gemini/claude/gemini_claude_response.go b/internal/translator/gemini/claude/gemini_claude_response.go index db14c78a1..4eaf9a7ed 100644 --- a/internal/translator/gemini/claude/gemini_claude_response.go +++ b/internal/translator/gemini/claude/gemini_claude_response.go @@ -80,7 +80,13 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` // Override default values with actual response metadata if available - if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", requestedModel) + } else if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() { messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) } if responseIDResult := gjson.GetBytes(rawJSON, "responseId"); responseIDResult.Exists() { @@ -276,14 +282,19 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Returns: // - string: A Claude-compatible JSON response. func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { - _ = originalRequestRawJSON - _ = requestRawJSON - root := gjson.ParseBytes(rawJSON) out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out, _ = sjson.Set(out, "id", root.Get("responseId").String()) - out, _ = sjson.Set(out, "model", root.Get("modelVersion").String()) + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + out, _ = sjson.Set(out, "model", requestedModel) + } else { + out, _ = sjson.Set(out, "model", root.Get("modelVersion").String()) + } inputTokens := root.Get("usageMetadata.promptTokenCount").Int() outputTokens := root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int() diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index 52fbba430..0a8f70335 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -61,7 +61,13 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` // Extract and set the model version. - if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + template, _ = sjson.Set(template, "model", requestedModel) + } else if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() { template, _ = sjson.Set(template, "model", modelVersionResult.String()) } @@ -220,7 +226,13 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { var unixTimestamp int64 template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` - if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() { + requestedModel := strings.TrimSpace(gjson.GetBytes(originalRequestRawJSON, "model").String()) + if requestedModel == "" { + requestedModel = strings.TrimSpace(gjson.GetBytes(requestRawJSON, "model").String()) + } + if requestedModel != "" { + template, _ = sjson.Set(template, "model", requestedModel) + } else if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() { template, _ = sjson.Set(template, "model", modelVersionResult.String()) } diff --git a/internal/util/gemini_schema.go b/internal/util/gemini_schema.go index 2daf0a79b..b02ec9a37 100644 --- a/internal/util/gemini_schema.go +++ b/internal/util/gemini_schema.go @@ -12,6 +12,29 @@ import ( var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?") +// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API. +// It handles unsupported keywords, type flattening, and schema simplification while preserving +// semantic information as description hints. +func CleanJSONSchemaForGemini(jsonStr string) string { + // Phase 1: Convert and add hints + jsonStr = convertRefsToHints(jsonStr) + jsonStr = convertConstToEnum(jsonStr) + jsonStr = addEnumHints(jsonStr) + jsonStr = addAdditionalPropertiesHints(jsonStr) + jsonStr = moveConstraintsToDescription(jsonStr) + + // Phase 2: Flatten complex structures + jsonStr = mergeAllOf(jsonStr) + jsonStr = flattenAnyOfOneOf(jsonStr) + jsonStr = flattenTypeArrays(jsonStr) + + // Phase 3: Cleanup + jsonStr = removeUnsupportedKeywords(jsonStr) + jsonStr = cleanupRequiredFields(jsonStr) + + return jsonStr +} + // CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API. // It handles unsupported keywords, type flattening, and schema simplification while preserving // semantic information as description hints. diff --git a/internal/util/proxy.go b/internal/util/proxy.go index aea52ba8c..f48bcb26f 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -8,12 +8,89 @@ import ( "net" "net/http" "net/url" + "os" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" log "github.com/sirupsen/logrus" "golang.org/x/net/proxy" ) +var oauthProxyEnvVars = []string{ + "CLIPROXY_OAUTH_PROXY", + "CLI_PROXY_API_OAUTH_PROXY", +} + +// ResolveOAuthProxyURL returns the proxy URL to use for OAuth requests. +// +// Priority: +// 1. CLIPROXY_OAUTH_PROXY / CLI_PROXY_API_OAUTH_PROXY env vars +// 2. cfg.ProxyURL +// 3. empty string (net/http defaults apply, including HTTP(S)_PROXY env vars) +func ResolveOAuthProxyURL(cfg *config.SDKConfig) string { + for _, key := range oauthProxyEnvVars { + if v, ok := os.LookupEnv(key); ok { + if trimmed := strings.TrimSpace(v); trimmed != "" { + return trimmed + } + } + } + if cfg == nil { + return "" + } + return strings.TrimSpace(cfg.ProxyURL) +} + +// SetOAuthProxy configures the provided HTTP client with proxy settings for OAuth flows. +// It mirrors SetProxy but allows an OAuth-specific override via CLIPROXY_OAUTH_PROXY. +func SetOAuthProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client { + if httpClient == nil { + httpClient = &http.Client{} + } + + proxyURLRaw := ResolveOAuthProxyURL(cfg) + if proxyURLRaw == "" { + return httpClient + } + + var transport *http.Transport + + // Attempt to parse the proxy URL from the configuration/env. + proxyURL, errParse := url.Parse(proxyURLRaw) + if errParse == nil { + // Handle different proxy schemes. + if proxyURL.Scheme == "socks5" { + // Configure SOCKS5 proxy with optional authentication. + var proxyAuth *proxy.Auth + if proxyURL.User != nil { + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + proxyAuth = &proxy.Auth{User: username, Password: password} + } + dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) + if errSOCKS5 != nil { + log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) + return httpClient + } + // Set up a custom transport using the SOCKS5 dialer. + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + } + } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + // Configure HTTP or HTTPS proxy. + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} + } + } + + // If a new transport was created, apply it to the HTTP client. + if transport != nil { + httpClient.Transport = transport + } + return httpClient +} + // SetProxy configures the provided HTTP client with proxy settings from the configuration. // It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport // to route requests through the configured proxy server. diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index ae22f7725..a4fdbbaac 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -3,9 +3,8 @@ package auth import ( "context" "encoding/json" + "errors" "fmt" - "io" - "net" "net/http" "net/url" "strings" @@ -13,7 +12,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthhttp" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" @@ -60,121 +60,79 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o opts = &LoginOptions{} } - httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) + httpClient := util.SetOAuthProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}) - state, err := misc.GenerateRandomState() - if err != nil { - return nil, fmt.Errorf("antigravity: failed to generate state: %w", err) - } - - srv, port, cbChan, errServer := startAntigravityCallbackServer() - if errServer != nil { - return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer) - } - defer func() { - shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - _ = srv.Shutdown(shutdownCtx) - }() - - redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port) - authURL := buildAntigravityAuthURL(redirectURI, state) - - if !opts.NoBrowser { - fmt.Println("Opening browser for antigravity authentication") - if !browser.IsAvailable() { - log.Warn("No browser available; please open the URL manually") - util.PrintSSHTunnelInstructions(port) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } else if errOpen := browser.OpenURL(authURL); errOpen != nil { - log.Warnf("Failed to open browser automatically: %v", errOpen) - util.PrintSSHTunnelInstructions(port) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - } else { - util.PrintSSHTunnelInstructions(port) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - - fmt.Println("Waiting for antigravity authentication callback...") - - var cbRes callbackResult - timeoutTimer := time.NewTimer(5 * time.Minute) - defer timeoutTimer.Stop() - - var manualPromptTimer *time.Timer - var manualPromptC <-chan time.Time - if opts.Prompt != nil { - manualPromptTimer = time.NewTimer(15 * time.Second) - manualPromptC = manualPromptTimer.C - defer manualPromptTimer.Stop() - } - -waitForCallback: - for { - select { - case res := <-cbChan: - cbRes = res - break waitForCallback - case <-manualPromptC: - manualPromptC = nil - if manualPromptTimer != nil { - manualPromptTimer.Stop() - } - select { - case res := <-cbChan: - cbRes = res - break waitForCallback - default: - } - input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ") - if errPrompt != nil { - return nil, errPrompt - } - parsed, errParse := misc.ParseOAuthCallback(input) - if errParse != nil { - return nil, errParse + desiredPort := antigravityCallbackPort + provider := newAntigravityOAuthProvider(httpClient) + + flow, err := oauthflow.RunAuthCodeFlow(ctx, provider, oauthflow.AuthCodeFlowOptions{ + DesiredPort: desiredPort, + CallbackPath: "/oauth-callback", + Timeout: 5 * time.Minute, + OnAuthURL: func(authURL string, callbackPort int, redirectURI string) { + if desiredPort != 0 && callbackPort != desiredPort { + log.Warnf("antigravity oauth callback port %d is busy; falling back to an ephemeral port", desiredPort) } - if parsed == nil { - continue + + if !opts.NoBrowser { + fmt.Println("Opening browser for antigravity authentication") + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } else if errOpen := browser.OpenURL(authURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + } else { + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) } - cbRes = callbackResult{ - Code: parsed.Code, - State: parsed.State, - Error: parsed.Error, + + fmt.Println("Waiting for antigravity authentication callback...") + }, + }) + if err != nil { + var flowErr *oauthflow.FlowError + if errors.As(err, &flowErr) && flowErr != nil { + switch flowErr.Kind { + case oauthflow.FlowErrorKindPortInUse: + return nil, fmt.Errorf("antigravity auth callback port in use: %w", err) + case oauthflow.FlowErrorKindServerStartFailed: + return nil, fmt.Errorf("antigravity auth callback server failed: %w", err) + case oauthflow.FlowErrorKindCallbackTimeout: + return nil, fmt.Errorf("antigravity auth: callback wait failed: %w", err) + case oauthflow.FlowErrorKindProviderError: + if flow != nil && flow.CallbackError != "" { + return nil, fmt.Errorf("antigravity auth: provider returned error %s", flow.CallbackError) + } + return nil, fmt.Errorf("antigravity auth: provider returned error") + case oauthflow.FlowErrorKindInvalidState: + return nil, fmt.Errorf("antigravity auth: state mismatch") + case oauthflow.FlowErrorKindCodeExchangeFailed: + return nil, fmt.Errorf("antigravity token exchange failed: %w", flowErr.Err) } - break waitForCallback - case <-timeoutTimer.C: - return nil, fmt.Errorf("antigravity: authentication timed out") } + return nil, err } - - if cbRes.Error != "" { - return nil, fmt.Errorf("antigravity: authentication failed: %s", cbRes.Error) - } - if cbRes.State != state { - return nil, fmt.Errorf("antigravity: invalid state") - } - if cbRes.Code == "" { - return nil, fmt.Errorf("antigravity: missing authorization code") + if flow == nil || flow.Token == nil { + return nil, fmt.Errorf("antigravity authentication failed: missing token result") } - tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient) - if errToken != nil { - return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken) - } + token := flow.Token email := "" - if tokenResp.AccessToken != "" { - if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" { + if token.AccessToken != "" { + if info, errInfo := fetchAntigravityUserInfo(ctx, token.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" { email = strings.TrimSpace(info.Email) } } // Fetch project ID via loadCodeAssist (same approach as Gemini CLI) projectID := "" - if tokenResp.AccessToken != "" { - fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient) + if token.AccessToken != "" { + fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, token.AccessToken, httpClient) if errProject != nil { log.Warnf("antigravity: failed to fetch project ID: %v", errProject) } else { @@ -184,13 +142,28 @@ waitForCallback: } now := time.Now() + expiresIn := int64(0) + if token.Metadata != nil { + switch v := token.Metadata["expires_in"].(type) { + case int: + expiresIn = int64(v) + case int64: + expiresIn = v + case float64: + expiresIn = int64(v) + } + } + expiredAt := strings.TrimSpace(token.ExpiresAt) + if expiredAt == "" && expiresIn > 0 { + expiredAt = now.Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) + } metadata := map[string]any{ "type": "antigravity", - "access_token": tokenResp.AccessToken, - "refresh_token": tokenResp.RefreshToken, - "expires_in": tokenResp.ExpiresIn, + "access_token": token.AccessToken, + "refresh_token": token.RefreshToken, + "expires_in": expiresIn, "timestamp": now.UnixMilli(), - "expired": now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339), + "expired": expiredAt, } if email != "" { metadata["email"] = email @@ -218,45 +191,107 @@ waitForCallback: }, nil } -type callbackResult struct { - Code string - Error string - State string +type antigravityOAuthProvider struct { + httpClient *http.Client +} + +func newAntigravityOAuthProvider(httpClient *http.Client) *antigravityOAuthProvider { + if httpClient == nil { + httpClient = &http.Client{Timeout: 30 * time.Second} + } + return &antigravityOAuthProvider{httpClient: httpClient} +} + +func (p *antigravityOAuthProvider) Provider() string { + return "antigravity" } -func startAntigravityCallbackServer() (*http.Server, int, <-chan callbackResult, error) { - addr := fmt.Sprintf(":%d", antigravityCallbackPort) - listener, err := net.Listen("tcp", addr) +func (p *antigravityOAuthProvider) AuthorizeURL(session oauthflow.OAuthSession) (string, oauthflow.OAuthSession, error) { + if p == nil { + return "", session, fmt.Errorf("antigravity oauth provider: provider is nil") + } + redirectURI := strings.TrimSpace(session.RedirectURI) + if redirectURI == "" { + return "", session, fmt.Errorf("antigravity oauth provider: redirect URI is empty") + } + authURL := buildAntigravityAuthURL(redirectURI, session.State, session.CodeChallenge) + return authURL, session, nil +} + +func (p *antigravityOAuthProvider) ExchangeCode(ctx context.Context, session oauthflow.OAuthSession, code string) (*oauthflow.TokenResult, error) { + if p == nil { + return nil, fmt.Errorf("antigravity oauth provider: provider is nil") + } + tokenResp, err := exchangeAntigravityCode(ctx, code, session.RedirectURI, session.CodeVerifier, p.httpClient) if err != nil { - return nil, 0, nil, err - } - port := listener.Addr().(*net.TCPAddr).Port - resultCh := make(chan callbackResult, 1) - - mux := http.NewServeMux() - mux.HandleFunc("/oauth-callback", func(w http.ResponseWriter, r *http.Request) { - q := r.URL.Query() - res := callbackResult{ - Code: strings.TrimSpace(q.Get("code")), - Error: strings.TrimSpace(q.Get("error")), - State: strings.TrimSpace(q.Get("state")), - } - resultCh <- res - if res.Code != "" && res.Error == "" { - _, _ = w.Write([]byte("

Login successful

You can close this window.

")) - } else { - _, _ = w.Write([]byte("

Login failed

Please check the CLI output.

")) - } - }) + return nil, err + } + if tokenResp == nil { + return nil, fmt.Errorf("antigravity oauth provider: token response is nil") + } - srv := &http.Server{Handler: mux} - go func() { - if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") { - log.Warnf("antigravity callback server error: %v", errServe) - } - }() + tokenType := strings.TrimSpace(tokenResp.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + expiresAt := "" + if tokenResp.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339) + } + + meta := map[string]any{ + "expires_in": tokenResp.ExpiresIn, + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(tokenResp.AccessToken), + RefreshToken: strings.TrimSpace(tokenResp.RefreshToken), + ExpiresAt: expiresAt, + TokenType: tokenType, + Metadata: meta, + }, nil +} + +func (p *antigravityOAuthProvider) Refresh(ctx context.Context, refreshToken string) (*oauthflow.TokenResult, error) { + if p == nil { + return nil, fmt.Errorf("antigravity oauth provider: provider is nil") + } + refreshToken = strings.TrimSpace(refreshToken) + if refreshToken == "" { + return nil, fmt.Errorf("antigravity oauth provider: refresh token is empty") + } + tokenResp, err := refreshAntigravityTokens(ctx, refreshToken, p.httpClient) + if err != nil { + return nil, err + } + if tokenResp == nil { + return nil, fmt.Errorf("antigravity oauth provider: refresh response is nil") + } + + tokenType := strings.TrimSpace(tokenResp.TokenType) + if tokenType == "" { + tokenType = "Bearer" + } + expiresAt := "" + if tokenResp.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339) + } + + meta := map[string]any{ + "expires_in": tokenResp.ExpiresIn, + } + + return &oauthflow.TokenResult{ + AccessToken: strings.TrimSpace(tokenResp.AccessToken), + RefreshToken: refreshToken, + ExpiresAt: expiresAt, + TokenType: tokenType, + Metadata: meta, + }, nil +} - return srv, port, resultCh, nil +func (p *antigravityOAuthProvider) Revoke(ctx context.Context, token string) error { + return oauthflow.ErrRevokeNotSupported } type antigravityTokenResponse struct { @@ -266,37 +301,103 @@ type antigravityTokenResponse struct { TokenType string `json:"token_type"` } -func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) { +func exchangeAntigravityCode(ctx context.Context, code, redirectURI, codeVerifier string, httpClient *http.Client) (*antigravityTokenResponse, error) { + if ctx == nil { + ctx = context.Background() + } data := url.Values{} data.Set("code", code) data.Set("client_id", antigravityClientID) data.Set("client_secret", antigravityClientSecret) data.Set("redirect_uri", redirectURI) data.Set("grant_type", "authorization_code") - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode())) + if strings.TrimSpace(codeVerifier) != "" { + data.Set("code_verifier", strings.TrimSpace(codeVerifier)) + } + + encoded := data.Encode() + status, _, body, err := oauthhttp.Do( + ctx, + httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(encoded)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { + return nil, err + } + if status < http.StatusOK || status >= http.StatusMultipleChoices { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("oauth token exchange failed: status %d: %s: %w", status, msg, err) + } + return nil, fmt.Errorf("oauth token exchange failed: status %d: %s", status, msg) + } if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, errDo := httpClient.Do(req) - if errDo != nil { - return nil, errDo + var token antigravityTokenResponse + if errDecode := json.Unmarshal(body, &token); errDecode != nil { + return nil, errDecode } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity token exchange: close body error: %v", errClose) + return &token, nil +} + +func refreshAntigravityTokens(ctx context.Context, refreshToken string, httpClient *http.Client) (*antigravityTokenResponse, error) { + if ctx == nil { + ctx = context.Background() + } + refreshToken = strings.TrimSpace(refreshToken) + if refreshToken == "" { + return nil, fmt.Errorf("refresh token is empty") + } + data := url.Values{} + data.Set("refresh_token", refreshToken) + data.Set("client_id", antigravityClientID) + data.Set("client_secret", antigravityClientSecret) + data.Set("grant_type", "refresh_token") + + encoded := data.Encode() + status, _, body, err := oauthhttp.Do( + ctx, + httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(encoded)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { + return nil, err + } + if status < http.StatusOK || status >= http.StatusMultipleChoices { + msg := strings.TrimSpace(string(body)) + if err != nil { + return nil, fmt.Errorf("oauth token refresh failed: status %d: %s: %w", status, msg, err) } - }() + return nil, fmt.Errorf("oauth token refresh failed: status %d: %s", status, msg) + } + if err != nil { + return nil, err + } var token antigravityTokenResponse - if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil { + if errDecode := json.Unmarshal(body, &token); errDecode != nil { return nil, errDecode } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode) - } return &token, nil } @@ -308,33 +409,40 @@ func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClien if strings.TrimSpace(accessToken) == "" { return &antigravityUserInfo{}, nil } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) - if err != nil { - return nil, err + if ctx == nil { + ctx = context.Background() } - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return nil, errDo + status, _, body, err := oauthhttp.Do( + ctx, + httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { + return nil, err } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity userinfo: close body error: %v", errClose) - } - }() - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if status < http.StatusOK || status >= http.StatusMultipleChoices { return &antigravityUserInfo{}, nil } + if err != nil { + return nil, err + } var info antigravityUserInfo - if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil { + if errDecode := json.Unmarshal(body, &info); errDecode != nil { return nil, errDecode } return &info, nil } -func buildAntigravityAuthURL(redirectURI, state string) string { +func buildAntigravityAuthURL(redirectURI, state, codeChallenge string) string { params := url.Values{} params.Set("access_type", "offline") params.Set("client_id", antigravityClientID) @@ -343,6 +451,10 @@ func buildAntigravityAuthURL(redirectURI, state string) string { params.Set("response_type", "code") params.Set("scope", strings.Join(antigravityScopes, " ")) params.Set("state", state) + if strings.TrimSpace(codeChallenge) != "" { + params.Set("code_challenge", strings.TrimSpace(codeChallenge)) + params.Set("code_challenge_method", "S256") + } return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() } @@ -371,6 +483,9 @@ func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie // fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist. // This uses the same approach as Gemini CLI to get the cloudaicompanionProject. func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) { + if ctx == nil { + ctx = context.Background() + } // Call loadCodeAssist to get the project loadReqBody := map[string]any{ "metadata": map[string]string{ @@ -386,33 +501,37 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie } endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) - if err != nil { - return "", fmt.Errorf("create request: %w", err) + status, _, bodyBytes, err := oauthhttp.Do( + ctx, + httpClient, + func() (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", antigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) + req.Header.Set("Client-Metadata", antigravityClientMetadata) + return req, nil + }, + oauthhttp.DefaultRetryConfig(), + ) + if err != nil && status == 0 { + return "", fmt.Errorf("execute request: %w", err) } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", antigravityAPIUserAgent) - req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) - req.Header.Set("Client-Metadata", antigravityClientMetadata) - resp, errDo := httpClient.Do(req) - if errDo != nil { - return "", fmt.Errorf("execute request: %w", errDo) - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose) + if status < http.StatusOK || status >= http.StatusMultipleChoices { + msg := strings.TrimSpace(string(bodyBytes)) + if err != nil { + return "", fmt.Errorf("request failed with status %d: %s: %w", status, msg, err) } - }() - - bodyBytes, errRead := io.ReadAll(resp.Body) - if errRead != nil { - return "", fmt.Errorf("read response: %w", errRead) + return "", fmt.Errorf("request failed with status %d: %s", status, msg) } - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) + if err != nil { + return "", fmt.Errorf("execute request: %w", err) } var loadResp map[string]any diff --git a/sdk/auth/claude.go b/sdk/auth/claude.go index c43b78cd9..48c697e02 100644 --- a/sdk/auth/claude.go +++ b/sdk/auth/claude.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "fmt" "net/http" "strings" @@ -11,7 +12,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" // legacy client removed "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" @@ -47,155 +48,97 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt opts = &LoginOptions{} } - pkceCodes, err := claude.GeneratePKCECodes() - if err != nil { - return nil, fmt.Errorf("claude pkce generation failed: %w", err) - } - - state, err := misc.GenerateRandomState() - if err != nil { - return nil, fmt.Errorf("claude state generation failed: %w", err) - } - - oauthServer := claude.NewOAuthServer(a.CallbackPort) - if err = oauthServer.Start(); err != nil { - if strings.Contains(err.Error(), "already in use") { - return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err) - } - return nil, claude.NewAuthenticationError(claude.ErrServerStartFailed, err) - } - defer func() { - stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if stopErr := oauthServer.Stop(stopCtx); stopErr != nil { - log.Warnf("claude oauth server stop error: %v", stopErr) - } - }() - + desiredPort := a.CallbackPort authSvc := claude.NewClaudeAuth(cfg) + provider := claude.NewOAuthProvider(authSvc) + + flow, err := oauthflow.RunAuthCodeFlow(ctx, provider, oauthflow.AuthCodeFlowOptions{ + DesiredPort: desiredPort, + CallbackPath: "/callback", + Timeout: 5 * time.Minute, + OnAuthURL: func(authURL string, callbackPort int, redirectURI string) { + if desiredPort != 0 && callbackPort != desiredPort { + log.Warnf("claude oauth callback port %d is busy; falling back to an ephemeral port", desiredPort) + } - authURL, returnedState, err := authSvc.GenerateAuthURL(state, pkceCodes) - if err != nil { - return nil, fmt.Errorf("claude authorization url generation failed: %w", err) - } - state = returnedState - - if !opts.NoBrowser { - fmt.Println("Opening browser for Claude authentication") - if !browser.IsAvailable() { - log.Warn("No browser available; please open the URL manually") - util.PrintSSHTunnelInstructions(a.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } else if err = browser.OpenURL(authURL); err != nil { - log.Warnf("Failed to open browser automatically: %v", err) - util.PrintSSHTunnelInstructions(a.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - } else { - util.PrintSSHTunnelInstructions(a.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - - fmt.Println("Waiting for Claude authentication callback...") - - callbackCh := make(chan *claude.OAuthResult, 1) - callbackErrCh := make(chan error, 1) - manualDescription := "" - - go func() { - result, errWait := oauthServer.WaitForCallback(5 * time.Minute) - if errWait != nil { - callbackErrCh <- errWait - return - } - callbackCh <- result - }() - - var result *claude.OAuthResult - var manualPromptTimer *time.Timer - var manualPromptC <-chan time.Time - if opts.Prompt != nil { - manualPromptTimer = time.NewTimer(15 * time.Second) - manualPromptC = manualPromptTimer.C - defer manualPromptTimer.Stop() - } + if !opts.NoBrowser { + fmt.Println("Opening browser for Claude authentication") + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } else if errOpen := browser.OpenURL(authURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + } else { + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } -waitForCallback: - for { - select { - case result = <-callbackCh: - break waitForCallback - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { + fmt.Println("Waiting for Claude authentication callback...") + }, + }) + if err != nil { + var flowErr *oauthflow.FlowError + if errors.As(err, &flowErr) && flowErr != nil { + switch flowErr.Kind { + case oauthflow.FlowErrorKindPortInUse: + return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err) + case oauthflow.FlowErrorKindServerStartFailed: + return nil, claude.NewAuthenticationError(claude.ErrServerStartFailed, err) + case oauthflow.FlowErrorKindAuthorizeURLFailed: + return nil, fmt.Errorf("claude authorization url generation failed: %w", flowErr.Err) + case oauthflow.FlowErrorKindCallbackTimeout: return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) - } - return nil, err - case <-manualPromptC: - manualPromptC = nil - if manualPromptTimer != nil { - manualPromptTimer.Stop() - } - select { - case result = <-callbackCh: - break waitForCallback - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + case oauthflow.FlowErrorKindProviderError: + code := strings.TrimSpace(flow.CallbackError) + if code == "" { + code = strings.TrimSpace(flowErr.Err.Error()) } - return nil, err - default: - } - input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ") - if errPrompt != nil { - return nil, errPrompt - } - parsed, errParse := misc.ParseOAuthCallback(input) - if errParse != nil { - return nil, errParse - } - if parsed == nil { - continue - } - manualDescription = parsed.ErrorDescription - result = &claude.OAuthResult{ - Code: parsed.Code, - State: parsed.State, - Error: parsed.Error, + if code == "" { + code = "oauth_error" + } + return nil, claude.NewOAuthError(code, "", http.StatusBadRequest) + case oauthflow.FlowErrorKindInvalidState: + return nil, claude.NewAuthenticationError(claude.ErrInvalidState, err) + case oauthflow.FlowErrorKindCodeExchangeFailed: + return nil, claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, err) } - break waitForCallback } + return nil, err } - - if result.Error != "" { - return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) + if flow == nil || flow.Token == nil { + return nil, fmt.Errorf("claude authentication failed: missing token result") } - if result.State != state { - return nil, claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("state mismatch")) + email := "" + if flow.Token.Metadata != nil { + if raw, ok := flow.Token.Metadata["email"]; ok { + if s, okStr := raw.(string); okStr { + email = strings.TrimSpace(s) + } + } } - - log.Debug("Claude authorization code received; exchanging for tokens") - - authBundle, err := authSvc.ExchangeCodeForTokens(ctx, result.Code, state, pkceCodes) - if err != nil { - return nil, claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, err) + if email == "" { + return nil, fmt.Errorf("claude token storage missing account information") } - tokenStorage := authSvc.CreateTokenStorage(authBundle) - - if tokenStorage == nil || tokenStorage.Email == "" { - return nil, fmt.Errorf("claude token storage missing account information") + tokenStorage := &claude.ClaudeTokenStorage{ + AccessToken: flow.Token.AccessToken, + RefreshToken: flow.Token.RefreshToken, + LastRefresh: time.Now().Format(time.RFC3339), + Email: email, + Expire: flow.Token.ExpiresAt, } - fileName := fmt.Sprintf("claude-%s.json", tokenStorage.Email) + fileName := fmt.Sprintf("claude-%s.json", email) metadata := map[string]any{ - "email": tokenStorage.Email, + "email": email, } fmt.Println("Claude authentication successful") - if authBundle.APIKey != "" { - fmt.Println("Claude API key obtained and stored") - } return &coreauth.Auth{ ID: fileName, diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index 999925251..c48840e93 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "fmt" "net/http" "strings" @@ -11,7 +12,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" // legacy client removed "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" @@ -47,154 +48,104 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts opts = &LoginOptions{} } - pkceCodes, err := codex.GeneratePKCECodes() - if err != nil { - return nil, fmt.Errorf("codex pkce generation failed: %w", err) - } - - state, err := misc.GenerateRandomState() - if err != nil { - return nil, fmt.Errorf("codex state generation failed: %w", err) - } - - oauthServer := codex.NewOAuthServer(a.CallbackPort) - if err = oauthServer.Start(); err != nil { - if strings.Contains(err.Error(), "already in use") { - return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err) - } - return nil, codex.NewAuthenticationError(codex.ErrServerStartFailed, err) - } - defer func() { - stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if stopErr := oauthServer.Stop(stopCtx); stopErr != nil { - log.Warnf("codex oauth server stop error: %v", stopErr) - } - }() - + desiredPort := a.CallbackPort authSvc := codex.NewCodexAuth(cfg) + provider := codex.NewOAuthProvider(authSvc) + flow, err := oauthflow.RunAuthCodeFlow(ctx, provider, oauthflow.AuthCodeFlowOptions{ + DesiredPort: desiredPort, + CallbackPath: "/auth/callback", + Timeout: 5 * time.Minute, + OnAuthURL: func(authURL string, callbackPort int, redirectURI string) { + if desiredPort != 0 && callbackPort != desiredPort { + log.Warnf("codex oauth callback port %d is busy; falling back to an ephemeral port", desiredPort) + } - authURL, err := authSvc.GenerateAuthURL(state, pkceCodes) - if err != nil { - return nil, fmt.Errorf("codex authorization url generation failed: %w", err) - } - - if !opts.NoBrowser { - fmt.Println("Opening browser for Codex authentication") - if !browser.IsAvailable() { - log.Warn("No browser available; please open the URL manually") - util.PrintSSHTunnelInstructions(a.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } else if err = browser.OpenURL(authURL); err != nil { - log.Warnf("Failed to open browser automatically: %v", err) - util.PrintSSHTunnelInstructions(a.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - } else { - util.PrintSSHTunnelInstructions(a.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } - - fmt.Println("Waiting for Codex authentication callback...") - - callbackCh := make(chan *codex.OAuthResult, 1) - callbackErrCh := make(chan error, 1) - manualDescription := "" - - go func() { - result, errWait := oauthServer.WaitForCallback(5 * time.Minute) - if errWait != nil { - callbackErrCh <- errWait - return - } - callbackCh <- result - }() - - var result *codex.OAuthResult - var manualPromptTimer *time.Timer - var manualPromptC <-chan time.Time - if opts.Prompt != nil { - manualPromptTimer = time.NewTimer(15 * time.Second) - manualPromptC = manualPromptTimer.C - defer manualPromptTimer.Stop() - } + if !opts.NoBrowser { + fmt.Println("Opening browser for Codex authentication") + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } else if errOpen := browser.OpenURL(authURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + } else { + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } -waitForCallback: - for { - select { - case result = <-callbackCh: - break waitForCallback - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { + fmt.Println("Waiting for Codex authentication callback...") + }, + }) + if err != nil { + var flowErr *oauthflow.FlowError + if errors.As(err, &flowErr) && flowErr != nil { + switch flowErr.Kind { + case oauthflow.FlowErrorKindPortInUse: + return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err) + case oauthflow.FlowErrorKindServerStartFailed: + return nil, codex.NewAuthenticationError(codex.ErrServerStartFailed, err) + case oauthflow.FlowErrorKindAuthorizeURLFailed: + return nil, fmt.Errorf("codex authorization url generation failed: %w", flowErr.Err) + case oauthflow.FlowErrorKindCallbackTimeout: return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) - } - return nil, err - case <-manualPromptC: - manualPromptC = nil - if manualPromptTimer != nil { - manualPromptTimer.Stop() - } - select { - case result = <-callbackCh: - break waitForCallback - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + case oauthflow.FlowErrorKindProviderError: + code := strings.TrimSpace(flow.CallbackError) + if code == "" { + code = strings.TrimSpace(flowErr.Err.Error()) } - return nil, err - default: - } - input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ") - if errPrompt != nil { - return nil, errPrompt - } - parsed, errParse := misc.ParseOAuthCallback(input) - if errParse != nil { - return nil, errParse - } - if parsed == nil { - continue - } - manualDescription = parsed.ErrorDescription - result = &codex.OAuthResult{ - Code: parsed.Code, - State: parsed.State, - Error: parsed.Error, + if code == "" { + code = "oauth_error" + } + return nil, codex.NewOAuthError(code, "", http.StatusBadRequest) + case oauthflow.FlowErrorKindInvalidState: + return nil, codex.NewAuthenticationError(codex.ErrInvalidState, err) + case oauthflow.FlowErrorKindCodeExchangeFailed: + return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) } - break waitForCallback } + return nil, err } - - if result.Error != "" { - return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) + if flow == nil || flow.Token == nil { + return nil, fmt.Errorf("codex authentication failed: missing token result") } - if result.State != state { - return nil, codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("state mismatch")) + email := "" + accountID := "" + if flow.Token.Metadata != nil { + if raw, ok := flow.Token.Metadata["email"]; ok { + if s, okStr := raw.(string); okStr { + email = strings.TrimSpace(s) + } + } + if raw, ok := flow.Token.Metadata["account_id"]; ok { + if s, okStr := raw.(string); okStr { + accountID = strings.TrimSpace(s) + } + } } - - log.Debug("Codex authorization code received; exchanging for tokens") - - authBundle, err := authSvc.ExchangeCodeForTokens(ctx, result.Code, pkceCodes) - if err != nil { - return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) + if email == "" { + return nil, fmt.Errorf("codex token storage missing account information") } - tokenStorage := authSvc.CreateTokenStorage(authBundle) - - if tokenStorage == nil || tokenStorage.Email == "" { - return nil, fmt.Errorf("codex token storage missing account information") + tokenStorage := &codex.CodexTokenStorage{ + IDToken: flow.Token.IDToken, + AccessToken: flow.Token.AccessToken, + RefreshToken: flow.Token.RefreshToken, + AccountID: accountID, + LastRefresh: time.Now().Format(time.RFC3339), + Email: email, + Expire: flow.Token.ExpiresAt, } - fileName := fmt.Sprintf("codex-%s.json", tokenStorage.Email) + fileName := fmt.Sprintf("codex-%s.json", email) metadata := map[string]any{ - "email": tokenStorage.Email, + "email": email, } fmt.Println("Codex authentication successful") - if authBundle.APIKey != "" { - fmt.Println("Codex API key obtained and stored") - } return &coreauth.Auth{ ID: fileName, diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index 3fd82f1d3..0d55fe9cb 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "fmt" "strings" "time" @@ -9,7 +10,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/oauthflow" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" @@ -43,129 +44,99 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts } authSvc := iflow.NewIFlowAuth(cfg) + desiredPort := iflow.CallbackPort + provider := iflow.NewOAuthProvider(authSvc) + + flow, err := oauthflow.RunAuthCodeFlow(ctx, provider, oauthflow.AuthCodeFlowOptions{ + DesiredPort: desiredPort, + CallbackPath: "/oauth2callback", + Timeout: 5 * time.Minute, + OnAuthURL: func(authURL string, callbackPort int, redirectURI string) { + if desiredPort != 0 && callbackPort != desiredPort { + log.Warnf("iflow oauth callback port %d is busy; falling back to an ephemeral port", desiredPort) + } - oauthServer := iflow.NewOAuthServer(iflow.CallbackPort) - if err := oauthServer.Start(); err != nil { - if strings.Contains(err.Error(), "already in use") { - return nil, fmt.Errorf("iflow authentication server port in use: %w", err) - } - return nil, fmt.Errorf("iflow authentication server failed: %w", err) - } - defer func() { - stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if stopErr := oauthServer.Stop(stopCtx); stopErr != nil { - log.Warnf("iflow oauth server stop error: %v", stopErr) - } - }() + if !opts.NoBrowser { + fmt.Println("Opening browser for iFlow authentication") + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } else if errOpen := browser.OpenURL(authURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + } else { + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } - state, err := misc.GenerateRandomState() + fmt.Println("Waiting for iFlow authentication callback...") + }, + }) if err != nil { - return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err) - } - - authURL, redirectURI := authSvc.AuthorizationURL(state, iflow.CallbackPort) - - if !opts.NoBrowser { - fmt.Println("Opening browser for iFlow authentication") - if !browser.IsAvailable() { - log.Warn("No browser available; please open the URL manually") - util.PrintSSHTunnelInstructions(iflow.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) - } else if err = browser.OpenURL(authURL); err != nil { - log.Warnf("Failed to open browser automatically: %v", err) - util.PrintSSHTunnelInstructions(iflow.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + var flowErr *oauthflow.FlowError + if errors.As(err, &flowErr) && flowErr != nil { + switch flowErr.Kind { + case oauthflow.FlowErrorKindPortInUse: + return nil, fmt.Errorf("iflow authentication server port in use: %w", err) + case oauthflow.FlowErrorKindServerStartFailed: + return nil, fmt.Errorf("iflow authentication server failed: %w", err) + case oauthflow.FlowErrorKindCallbackTimeout: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + case oauthflow.FlowErrorKindProviderError: + if flow != nil && flow.CallbackError != "" { + return nil, fmt.Errorf("iflow auth: provider returned error %s", flow.CallbackError) + } + return nil, fmt.Errorf("iflow auth: provider returned error") + case oauthflow.FlowErrorKindInvalidState: + return nil, fmt.Errorf("iflow auth: state mismatch") + case oauthflow.FlowErrorKindCodeExchangeFailed: + return nil, fmt.Errorf("iflow authentication failed: %w", flowErr.Err) + } } - } else { - util.PrintSSHTunnelInstructions(iflow.CallbackPort) - fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + return nil, err } - - fmt.Println("Waiting for iFlow authentication callback...") - - callbackCh := make(chan *iflow.OAuthResult, 1) - callbackErrCh := make(chan error, 1) - - go func() { - result, errWait := oauthServer.WaitForCallback(5 * time.Minute) - if errWait != nil { - callbackErrCh <- errWait - return - } - callbackCh <- result - }() - - var result *iflow.OAuthResult - var manualPromptTimer *time.Timer - var manualPromptC <-chan time.Time - if opts.Prompt != nil { - manualPromptTimer = time.NewTimer(15 * time.Second) - manualPromptC = manualPromptTimer.C - defer manualPromptTimer.Stop() + if flow == nil || flow.Token == nil { + return nil, fmt.Errorf("iflow authentication failed: missing token result") } -waitForCallback: - for { - select { - case result = <-callbackCh: - break waitForCallback - case err = <-callbackErrCh: - return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) - case <-manualPromptC: - manualPromptC = nil - if manualPromptTimer != nil { - manualPromptTimer.Stop() + email := "" + apiKey := "" + if flow.Token.Metadata != nil { + if raw, ok := flow.Token.Metadata["email"]; ok { + if s, okStr := raw.(string); okStr { + email = strings.TrimSpace(s) } - select { - case result = <-callbackCh: - break waitForCallback - case err = <-callbackErrCh: - return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) - default: - } - input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ") - if errPrompt != nil { - return nil, errPrompt - } - parsed, errParse := misc.ParseOAuthCallback(input) - if errParse != nil { - return nil, errParse - } - if parsed == nil { - continue - } - result = &iflow.OAuthResult{ - Code: parsed.Code, - State: parsed.State, - Error: parsed.Error, + } + if raw, ok := flow.Token.Metadata["api_key"]; ok { + if s, okStr := raw.(string); okStr { + apiKey = strings.TrimSpace(s) } - break waitForCallback } } - if result.Error != "" { - return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) - } - if result.State != state { - return nil, fmt.Errorf("iflow auth: state mismatch") + if email == "" { + return nil, fmt.Errorf("iflow authentication failed: missing account identifier") } - - tokenData, err := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI) - if err != nil { - return nil, fmt.Errorf("iflow authentication failed: %w", err) + if apiKey == "" { + return nil, fmt.Errorf("iflow authentication failed: missing api key") } - tokenStorage := authSvc.CreateTokenStorage(tokenData) - - email := strings.TrimSpace(tokenStorage.Email) - if email == "" { - return nil, fmt.Errorf("iflow authentication failed: missing account identifier") + tokenStorage := &iflow.IFlowTokenStorage{ + AccessToken: flow.Token.AccessToken, + RefreshToken: flow.Token.RefreshToken, + LastRefresh: time.Now().Format(time.RFC3339), + Expire: flow.Token.ExpiresAt, + APIKey: apiKey, + Email: email, + TokenType: flow.Token.TokenType, } fileName := fmt.Sprintf("iflow-%s-%d.json", email, time.Now().Unix()) metadata := map[string]any{ "email": email, - "api_key": tokenStorage.APIKey, + "api_key": apiKey, "access_token": tokenStorage.AccessToken, "refresh_token": tokenStorage.RefreshToken, "expired": tokenStorage.Expire, @@ -180,7 +151,7 @@ waitForCallback: Storage: tokenStorage, Metadata: metadata, Attributes: map[string]string{ - "api_key": tokenStorage.APIKey, + "api_key": apiKey, }, }, nil } diff --git a/sdk/auth/qwen.go b/sdk/auth/qwen.go index 151fba681..3de778eb0 100644 --- a/sdk/auth/qwen.go +++ b/sdk/auth/qwen.go @@ -44,12 +44,17 @@ func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts authSvc := qwen.NewQwenAuth(cfg) - deviceFlow, err := authSvc.InitiateDeviceFlow(ctx) + deviceProvider := qwen.NewDeviceOAuthProvider(authSvc) + + deviceFlow, err := deviceProvider.DeviceAuthorize(ctx) if err != nil { return nil, fmt.Errorf("qwen device flow initiation failed: %w", err) } - authURL := deviceFlow.VerificationURIComplete + authURL := strings.TrimSpace(deviceFlow.VerificationURIComplete) + if authURL == "" { + authURL = strings.TrimSpace(deviceFlow.VerificationURI) + } if !opts.NoBrowser { fmt.Println("Opening browser for Qwen authentication") @@ -66,11 +71,24 @@ func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Println("Waiting for Qwen authentication...") - tokenData, err := authSvc.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) + tokenResult, err := deviceProvider.DevicePoll(ctx, deviceFlow) if err != nil { return nil, fmt.Errorf("qwen authentication failed: %w", err) } + tokenData := &qwen.QwenTokenData{ + AccessToken: tokenResult.AccessToken, + RefreshToken: tokenResult.RefreshToken, + TokenType: tokenResult.TokenType, + Expire: tokenResult.ExpiresAt, + } + if tokenResult.Metadata != nil { + if raw, ok := tokenResult.Metadata["resource_url"]; ok { + if val, okStr := raw.(string); okStr { + tokenData.ResourceURL = strings.TrimSpace(val) + } + } + } tokenStorage := authSvc.CreateTokenStorage(tokenData) email := ""