From 3374292cb312a88677dc04f2b5d65e8bb8108f99 Mon Sep 17 00:00:00 2001 From: Evan Date: Mon, 29 Dec 2025 07:24:00 +0700 Subject: [PATCH 01/17] feat(quota): implement quota management and fetching for multiple providers - Added quota fetching endpoints in management handler. - Introduced quota manager to orchestrate fetching and caching of quota data. - Implemented Antigravity and Codex quota fetchers to retrieve usage information. - Created caching mechanism for quota data to improve performance. - Added types and structures for handling quota responses and subscription information. - Implemented refresh functionality for quota data to ensure up-to-date information. --- internal/api/handlers/management/handler.go | 8 +- .../api/handlers/management/quota_fetchers.go | 117 ++++++ internal/api/server.go | 7 + internal/quota/antigravity.go | 345 +++++++++++++++++ internal/quota/cache.go | 119 ++++++ internal/quota/codex.go | 223 +++++++++++ internal/quota/fetcher.go | 25 ++ internal/quota/manager.go | 356 ++++++++++++++++++ internal/quota/types.go | 119 ++++++ 9 files changed, 1318 insertions(+), 1 deletion(-) create mode 100644 internal/api/handlers/management/quota_fetchers.go create mode 100644 internal/quota/antigravity.go create mode 100644 internal/quota/cache.go create mode 100644 internal/quota/codex.go create mode 100644 internal/quota/fetcher.go create mode 100644 internal/quota/manager.go create mode 100644 internal/quota/types.go diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index d3ccbda6c..3a0f3d99b 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -15,6 +15,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/quota" "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -36,6 +37,7 @@ type Handler struct { authManager *coreauth.Manager usageStats *usage.RequestStatistics tokenStore coreauth.Store + quotaManager *quota.Manager localPassword string allowRemoteOverride bool envSecret string @@ -47,13 +49,17 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD") envSecret = strings.TrimSpace(envSecret) + tokenStore := sdkAuth.GetTokenStore() + quotaManager := quota.NewManager(tokenStore, nil) + return &Handler{ cfg: cfg, configFilePath: configFilePath, failedAttempts: make(map[string]*attemptInfo), authManager: manager, usageStats: usage.GetRequestStatistics(), - tokenStore: sdkAuth.GetTokenStore(), + tokenStore: tokenStore, + quotaManager: quotaManager, allowRemoteOverride: envSecret != "", envSecret: envSecret, } diff --git a/internal/api/handlers/management/quota_fetchers.go b/internal/api/handlers/management/quota_fetchers.go new file mode 100644 index 000000000..530f4e375 --- /dev/null +++ b/internal/api/handlers/management/quota_fetchers.go @@ -0,0 +1,117 @@ +package management + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/quota" +) + +// GetAllQuotas returns quota for all connected accounts. +// GET /v0/management/quotas +func (h *Handler) GetAllQuotas(c *gin.Context) { + if h.quotaManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "quota manager not initialized"}) + return + } + + ctx := c.Request.Context() + quotas, err := h.quotaManager.FetchAllQuotas(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, quotas) +} + +// GetProviderQuotas returns quota for a specific provider. +// GET /v0/management/quotas/:provider +func (h *Handler) GetProviderQuotas(c *gin.Context) { + if h.quotaManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "quota manager not initialized"}) + return + } + + provider := c.Param("provider") + if provider == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "provider is required"}) + return + } + + ctx := c.Request.Context() + quotas, err := h.quotaManager.FetchProviderQuotas(ctx, provider) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, quotas) +} + +// GetAccountQuota returns quota for a specific account. +// GET /v0/management/quotas/:provider/:account +func (h *Handler) GetAccountQuota(c *gin.Context) { + if h.quotaManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "quota manager not initialized"}) + return + } + + provider := c.Param("provider") + account := c.Param("account") + if provider == "" || account == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "provider and account are required"}) + return + } + + ctx := c.Request.Context() + quotaResp, err := h.quotaManager.FetchAccountQuota(ctx, provider, account) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, quotaResp) +} + +// RefreshQuotas forces a quota refresh for all or specific providers. +// POST /v0/management/quotas/refresh +func (h *Handler) RefreshQuotas(c *gin.Context) { + if h.quotaManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "quota manager not initialized"}) + return + } + + var req quota.RefreshRequest + if err := c.ShouldBindJSON(&req); err != nil { + // Allow empty body - refresh all + req.Providers = nil + } + + ctx := c.Request.Context() + quotas, err := h.quotaManager.RefreshQuotas(ctx, req.Providers) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, quotas) +} + +// GetSubscriptionInfo returns subscription/tier info for Antigravity accounts. +// GET /v0/management/subscription-info +func (h *Handler) GetSubscriptionInfo(c *gin.Context) { + if h.quotaManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "quota manager not initialized"}) + return + } + + ctx := c.Request.Context() + info, err := h.quotaManager.GetSubscriptionInfo(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, info) +} \ No newline at end of file diff --git a/internal/api/server.go b/internal/api/server.go index 9a195db07..b9790d49c 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -595,6 +595,13 @@ func (s *Server) registerManagementRoutes() { mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) + + // Quota fetching endpoints + mgmt.GET("/quotas", s.mgmt.GetAllQuotas) + mgmt.GET("/quotas/:provider", s.mgmt.GetProviderQuotas) + mgmt.GET("/quotas/:provider/:account", s.mgmt.GetAccountQuota) + mgmt.POST("/quotas/refresh", s.mgmt.RefreshQuotas) + mgmt.GET("/subscription-info", s.mgmt.GetSubscriptionInfo) } } diff --git a/internal/quota/antigravity.go b/internal/quota/antigravity.go new file mode 100644 index 000000000..d0e35cb22 --- /dev/null +++ b/internal/quota/antigravity.go @@ -0,0 +1,345 @@ +package quota + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +const ( + antigravityQuotaAPIURL = "https://cloudcode-pa.googleapis.com/v1internal:fetchAvailableModels" + antigravityLoadProjectURL = "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" + antigravityAPIUserAgent = "antigravity/1.11.3 Darwin/arm64" + antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" + antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` +) + +// AntigravityFetcher implements quota fetching for Antigravity and Gemini-CLI providers. +// Both providers use the same Google Cloud Code API for quota information. +type AntigravityFetcher struct { + httpClient *http.Client +} + +// NewAntigravityFetcher creates a new Antigravity quota fetcher. +func NewAntigravityFetcher(httpClient *http.Client) *AntigravityFetcher { + if httpClient == nil { + httpClient = &http.Client{Timeout: 30 * time.Second} + } + return &AntigravityFetcher{httpClient: httpClient} +} + +// Provider returns the primary provider name. +func (f *AntigravityFetcher) Provider() string { + return "antigravity" +} + +// SupportedProviders returns all provider names this fetcher supports. +func (f *AntigravityFetcher) SupportedProviders() []string { + return []string{"antigravity", "gemini-cli"} +} + +// CanFetch returns true if this fetcher can handle the given provider. +func (f *AntigravityFetcher) CanFetch(provider string) bool { + provider = strings.ToLower(strings.TrimSpace(provider)) + return provider == "antigravity" || provider == "gemini-cli" +} + +// FetchQuota fetches quota for an Antigravity or Gemini-CLI auth credential. +func (f *AntigravityFetcher) FetchQuota(ctx context.Context, auth *coreauth.Auth) (*ProviderQuotaData, error) { + if auth == nil { + return nil, fmt.Errorf("auth is nil") + } + + // Get access token from metadata + accessToken := f.extractAccessToken(auth) + if accessToken == "" { + return UnavailableQuota("no access token available"), nil + } + + // Get project ID from metadata or fetch it + projectID := f.extractProjectID(auth) + if projectID == "" { + // Try to fetch project ID + var err error + projectID, err = f.fetchProjectID(ctx, accessToken) + if err != nil { + log.Warnf("antigravity quota: failed to fetch project ID: %v", err) + return UnavailableQuota(fmt.Sprintf("failed to fetch project ID: %v", err)), nil + } + } + + // Fetch available models and quota info + quotaData, err := f.fetchQuotaData(ctx, accessToken, projectID) + if err != nil { + log.Warnf("antigravity quota: failed to fetch quota: %v", err) + return UnavailableQuota(fmt.Sprintf("failed to fetch quota: %v", err)), nil + } + + return quotaData, nil +} + +// extractAccessToken extracts the access token from auth metadata. +func (f *AntigravityFetcher) extractAccessToken(auth *coreauth.Auth) string { + if auth.Metadata == nil { + return "" + } + if token, ok := auth.Metadata["access_token"].(string); ok { + return strings.TrimSpace(token) + } + return "" +} + +// extractProjectID extracts the project ID from auth metadata. +func (f *AntigravityFetcher) extractProjectID(auth *coreauth.Auth) string { + if auth.Metadata == nil { + return "" + } + if pid, ok := auth.Metadata["project_id"].(string); ok { + return strings.TrimSpace(pid) + } + return "" +} + +// fetchProjectID fetches the project ID using the loadCodeAssist API. +func (f *AntigravityFetcher) fetchProjectID(ctx context.Context, accessToken string) (string, error) { + // Reference: Antigravity-Manager uses just {"metadata": {"ideType": "ANTIGRAVITY"}} + reqBody := map[string]any{ + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + }, + } + + rawBody, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, antigravityLoadProjectURL, strings.NewReader(string(rawBody))) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + + f.setRequestHeaders(req, accessToken) + + resp, err := f.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var result map[string]any + if err := json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } + + // Extract project ID + if projectID, ok := result["cloudaicompanionProject"].(string); ok { + return strings.TrimSpace(projectID), nil + } + if projectMap, ok := result["cloudaicompanionProject"].(map[string]any); ok { + if id, okID := projectMap["id"].(string); okID { + return strings.TrimSpace(id), nil + } + } + + return "", fmt.Errorf("no cloudaicompanionProject in response") +} + +// fetchQuotaData fetches the quota data from the fetchAvailableModels API. +func (f *AntigravityFetcher) fetchQuotaData(ctx context.Context, accessToken, projectID string) (*ProviderQuotaData, error) { + // Reference: Antigravity-Manager uses just {"project": "project-id"} + reqBody := map[string]any{} + if projectID != "" { + reqBody["project"] = projectID + } + + rawBody, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, antigravityQuotaAPIURL, strings.NewReader(string(rawBody))) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + f.setRequestHeaders(req, accessToken) + + resp, err := f.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode == http.StatusForbidden { + return &ProviderQuotaData{ + Models: []ModelQuota{}, + LastUpdated: time.Now(), + IsForbidden: true, + Error: "account forbidden", + }, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var result fetchAvailableModelsResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + return f.parseQuotaResponse(&result), nil +} + +// setRequestHeaders sets the standard headers for Antigravity API requests. +func (f *AntigravityFetcher) setRequestHeaders(req *http.Request, accessToken string) { + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", antigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) + req.Header.Set("Client-Metadata", antigravityClientMetadata) +} + +// API response structures + +type fetchAvailableModelsResponse struct { + // Models is a map from model name to model info (not an array!) + Models map[string]modelInfo `json:"models"` + CurrentTier *tierInfo `json:"currentTier"` + AvailableTiers []tierInfo `json:"availableTiers"` + ProjectID string `json:"cloudaicompanionProject"` + TierUpgradeURL string `json:"tierUpgradeUrl"` +} + +type modelInfo struct { + // QuotaInfo contains remaining quota information + QuotaInfo *quotaInfo `json:"quotaInfo"` +} + +type quotaInfo struct { + RemainingFraction *float64 `json:"remainingFraction"` + ResetTime string `json:"resetTime"` +} + +type tierInfo struct { + TierID string `json:"tierId"` + DisplayName string `json:"displayName"` + Description string `json:"description"` + IsDefault bool `json:"isDefault"` +} + +// parseQuotaResponse converts the API response to our unified ProviderQuotaData format. +func (f *AntigravityFetcher) parseQuotaResponse(resp *fetchAvailableModelsResponse) *ProviderQuotaData { + var models []ModelQuota + + // Models is a map from model name to model info + for modelName, modelInfo := range resp.Models { + mq := ModelQuota{ + Name: modelName, + Percentage: -1, // Default to unavailable + } + + if modelInfo.QuotaInfo != nil { + if modelInfo.QuotaInfo.RemainingFraction != nil { + // remainingFraction is 0-1, convert to percentage + mq.Percentage = *modelInfo.QuotaInfo.RemainingFraction * 100 + } + mq.ResetTime = modelInfo.QuotaInfo.ResetTime + } + + models = append(models, mq) + } + + data := &ProviderQuotaData{ + Models: models, + LastUpdated: time.Now(), + IsForbidden: false, + } + + // Set plan type from current tier + if resp.CurrentTier != nil { + data.PlanType = resp.CurrentTier.TierID + data.Extra = map[string]any{ + "tier_name": resp.CurrentTier.DisplayName, + "tier_description": resp.CurrentTier.Description, + "project_id": resp.ProjectID, + } + if resp.TierUpgradeURL != "" { + data.Extra["upgrade_url"] = resp.TierUpgradeURL + } + } + + return data +} + +// GetSubscriptionInfo fetches subscription/tier information for an Antigravity account. +func (f *AntigravityFetcher) GetSubscriptionInfo(ctx context.Context, auth *coreauth.Auth) (*SubscriptionInfo, error) { + if auth == nil { + return nil, fmt.Errorf("auth is nil") + } + + accessToken := f.extractAccessToken(auth) + if accessToken == "" { + return nil, fmt.Errorf("no access token available") + } + + projectID := f.extractProjectID(auth) + if projectID == "" { + var err error + projectID, err = f.fetchProjectID(ctx, accessToken) + if err != nil { + return nil, fmt.Errorf("failed to fetch project ID: %w", err) + } + } + + quotaData, err := f.fetchQuotaData(ctx, accessToken, projectID) + if err != nil { + return nil, fmt.Errorf("failed to fetch quota data: %w", err) + } + + info := &SubscriptionInfo{ + ProjectID: projectID, + } + + if quotaData.Extra != nil { + if tierName, ok := quotaData.Extra["tier_name"].(string); ok { + info.CurrentTier = &SubscriptionTier{ + ID: quotaData.PlanType, + Name: tierName, + } + if desc, ok := quotaData.Extra["tier_description"].(string); ok { + info.CurrentTier.Description = desc + } + } + if url, ok := quotaData.Extra["upgrade_url"].(string); ok { + info.UpgradeURL = url + if info.CurrentTier != nil { + info.CurrentTier.UpgradeURL = url + } + } + } + + return info, nil +} \ No newline at end of file diff --git a/internal/quota/cache.go b/internal/quota/cache.go new file mode 100644 index 000000000..94c43e7f5 --- /dev/null +++ b/internal/quota/cache.go @@ -0,0 +1,119 @@ +package quota + +import ( + "sync" + "time" +) + +// DefaultCacheTTL is the default time-to-live for cached quota data. +const DefaultCacheTTL = 60 * time.Second + +// CacheEntry represents a cached quota entry with expiration. +type CacheEntry struct { + Data *ProviderQuotaData + ExpiresAt time.Time +} + +// IsExpired returns true if the cache entry has expired. +func (e *CacheEntry) IsExpired() bool { + return time.Now().After(e.ExpiresAt) +} + +// QuotaCache provides TTL-based caching for quota data. +type QuotaCache struct { + mu sync.RWMutex + entries map[string]*CacheEntry + ttl time.Duration +} + +// NewQuotaCache creates a new quota cache with the given TTL. +func NewQuotaCache(ttl time.Duration) *QuotaCache { + if ttl <= 0 { + ttl = DefaultCacheTTL + } + return &QuotaCache{ + entries: make(map[string]*CacheEntry), + ttl: ttl, + } +} + +// cacheKey generates a cache key from provider and account ID. +func cacheKey(provider, accountID string) string { + return provider + ":" + accountID +} + +// Get retrieves quota data from cache if available and not expired. +func (c *QuotaCache) Get(provider, accountID string) (*ProviderQuotaData, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + key := cacheKey(provider, accountID) + entry, exists := c.entries[key] + if !exists || entry.IsExpired() { + return nil, false + } + return entry.Data, true +} + +// Set stores quota data in cache with the configured TTL. +func (c *QuotaCache) Set(provider, accountID string, data *ProviderQuotaData) { + c.mu.Lock() + defer c.mu.Unlock() + + key := cacheKey(provider, accountID) + c.entries[key] = &CacheEntry{ + Data: data, + ExpiresAt: time.Now().Add(c.ttl), + } +} + +// Invalidate removes a specific entry from cache. +func (c *QuotaCache) Invalidate(provider, accountID string) { + c.mu.Lock() + defer c.mu.Unlock() + + key := cacheKey(provider, accountID) + delete(c.entries, key) +} + +// InvalidateProvider removes all entries for a specific provider. +func (c *QuotaCache) InvalidateProvider(provider string) { + c.mu.Lock() + defer c.mu.Unlock() + + prefix := provider + ":" + for key := range c.entries { + if len(key) > len(prefix) && key[:len(prefix)] == prefix { + delete(c.entries, key) + } + } +} + +// InvalidateAll clears the entire cache. +func (c *QuotaCache) InvalidateAll() { + c.mu.Lock() + defer c.mu.Unlock() + c.entries = make(map[string]*CacheEntry) +} + +// Cleanup removes expired entries from the cache. +func (c *QuotaCache) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for key, entry := range c.entries { + if now.After(entry.ExpiresAt) { + delete(c.entries, key) + } + } +} + +// SetTTL updates the TTL for new cache entries. +func (c *QuotaCache) SetTTL(ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + if ttl > 0 { + c.ttl = ttl + } +} \ No newline at end of file diff --git a/internal/quota/codex.go b/internal/quota/codex.go new file mode 100644 index 000000000..6612ced2f --- /dev/null +++ b/internal/quota/codex.go @@ -0,0 +1,223 @@ +package quota + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +const ( + codexUsageAPIURL = "https://chatgpt.com/backend-api/wham/usage" +) + +// CodexFetcher implements quota fetching for Codex/OpenAI provider. +type CodexFetcher struct { + httpClient *http.Client +} + +// NewCodexFetcher creates a new Codex quota fetcher. +func NewCodexFetcher(httpClient *http.Client) *CodexFetcher { + if httpClient == nil { + httpClient = &http.Client{Timeout: 30 * time.Second} + } + return &CodexFetcher{httpClient: httpClient} +} + +// Provider returns the provider name. +func (f *CodexFetcher) Provider() string { + return "codex" +} + +// SupportedProviders returns all provider names this fetcher supports. +func (f *CodexFetcher) SupportedProviders() []string { + return []string{"codex"} +} + +// CanFetch returns true if this fetcher can handle the given provider. +func (f *CodexFetcher) CanFetch(provider string) bool { + provider = strings.ToLower(strings.TrimSpace(provider)) + return provider == "codex" +} + +// FetchQuota fetches quota for a Codex auth credential. +func (f *CodexFetcher) FetchQuota(ctx context.Context, auth *coreauth.Auth) (*ProviderQuotaData, error) { + if auth == nil { + return nil, fmt.Errorf("auth is nil") + } + + // Get access token from metadata + accessToken := f.extractAccessToken(auth) + if accessToken == "" { + return UnavailableQuota("no access token available"), nil + } + + // Fetch usage data + quotaData, err := f.fetchUsageData(ctx, accessToken) + if err != nil { + log.Warnf("codex quota: failed to fetch usage: %v", err) + return UnavailableQuota(fmt.Sprintf("failed to fetch usage: %v", err)), nil + } + + return quotaData, nil +} + +// extractAccessToken extracts the access token from auth metadata. +func (f *CodexFetcher) extractAccessToken(auth *coreauth.Auth) string { + if auth.Metadata == nil { + return "" + } + if token, ok := auth.Metadata["access_token"].(string); ok { + return strings.TrimSpace(token) + } + return "" +} + +// fetchUsageData fetches usage data from the ChatGPT API. +func (f *CodexFetcher) fetchUsageData(ctx context.Context, accessToken string) (*ProviderQuotaData, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, codexUsageAPIURL, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36") + + resp, err := f.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized { + return &ProviderQuotaData{ + Models: []ModelQuota{}, + LastUpdated: time.Now(), + IsForbidden: true, + Error: fmt.Sprintf("HTTP %d", resp.StatusCode), + }, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var result codexUsageResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + return f.parseUsageResponse(&result), nil +} + +// API response structures based on Quotio's CodexCLIQuotaFetcher + +type codexUsageResponse struct { + RateLimit *rateLimitInfo `json:"rate_limit"` + Plan *planInfo `json:"plan"` +} + +type rateLimitInfo struct { + PrimaryWindow *windowInfo `json:"primary_window"` + SecondaryWindow *windowInfo `json:"secondary_window"` +} + +type windowInfo struct { + Used int64 `json:"used"` + Limit int64 `json:"limit"` + ResetAt string `json:"reset_at"` +} + +type planInfo struct { + Name string `json:"name"` + DisplayName string `json:"display_name"` +} + +// parseUsageResponse converts the API response to our unified ProviderQuotaData format. +func (f *CodexFetcher) parseUsageResponse(resp *codexUsageResponse) *ProviderQuotaData { + var models []ModelQuota + + if resp.RateLimit != nil { + // Primary window (session usage) + if resp.RateLimit.PrimaryWindow != nil { + pw := resp.RateLimit.PrimaryWindow + percentage := float64(100) + if pw.Limit > 0 { + remaining := pw.Limit - pw.Used + percentage = float64(remaining) / float64(pw.Limit) * 100 + if percentage < 0 { + percentage = 0 + } + } + used := pw.Used + limit := pw.Limit + remaining := pw.Limit - pw.Used + if remaining < 0 { + remaining = 0 + } + models = append(models, ModelQuota{ + Name: "codex-session", + Percentage: percentage, + ResetTime: pw.ResetAt, + Used: &used, + Limit: &limit, + Remaining: &remaining, + }) + } + + // Secondary window (weekly usage) + if resp.RateLimit.SecondaryWindow != nil { + sw := resp.RateLimit.SecondaryWindow + percentage := float64(100) + if sw.Limit > 0 { + remaining := sw.Limit - sw.Used + percentage = float64(remaining) / float64(sw.Limit) * 100 + if percentage < 0 { + percentage = 0 + } + } + used := sw.Used + limit := sw.Limit + remaining := sw.Limit - sw.Used + if remaining < 0 { + remaining = 0 + } + models = append(models, ModelQuota{ + Name: "codex-weekly", + Percentage: percentage, + ResetTime: sw.ResetAt, + Used: &used, + Limit: &limit, + Remaining: &remaining, + }) + } + } + + data := &ProviderQuotaData{ + Models: models, + LastUpdated: time.Now(), + IsForbidden: false, + } + + // Set plan type + if resp.Plan != nil { + data.PlanType = resp.Plan.Name + data.Extra = map[string]any{ + "plan_display_name": resp.Plan.DisplayName, + } + } + + return data +} \ No newline at end of file diff --git a/internal/quota/fetcher.go b/internal/quota/fetcher.go new file mode 100644 index 000000000..08cda0d60 --- /dev/null +++ b/internal/quota/fetcher.go @@ -0,0 +1,25 @@ +package quota + +import ( + "context" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// Fetcher defines the interface for quota fetchers. +// Each provider implements this interface to fetch quota data. +type Fetcher interface { + // Provider returns the provider name (e.g., "antigravity", "codex"). + Provider() string + + // SupportedProviders returns all provider names this fetcher supports. + // For example, antigravity fetcher may support both "antigravity" and "gemini-cli". + SupportedProviders() []string + + // FetchQuota fetches quota for a single auth credential. + // Returns nil ProviderQuotaData if the provider doesn't support quota fetching. + FetchQuota(ctx context.Context, auth *coreauth.Auth) (*ProviderQuotaData, error) + + // CanFetch returns true if this fetcher can handle the given provider. + CanFetch(provider string) bool +} \ No newline at end of file diff --git a/internal/quota/manager.go b/internal/quota/manager.go new file mode 100644 index 000000000..788b6eaa5 --- /dev/null +++ b/internal/quota/manager.go @@ -0,0 +1,356 @@ +package quota + +import ( + "context" + "net/http" + "strings" + "sync" + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// unsupportedProviders lists providers that don't have quota APIs. +var unsupportedProviders = map[string]bool{ + "claude": true, + "gemini": true, // API key-based Gemini (pay-per-use) + "vertex": true, + "iflow": true, + "qwen": true, + "aistudio": true, +} + +// Manager orchestrates quota fetching for all providers. +type Manager struct { + mu sync.RWMutex + fetchers []Fetcher + cache *QuotaCache + authStore coreauth.Store + httpClient *http.Client +} + +// NewManager creates a new quota manager with the given auth store. +func NewManager(authStore coreauth.Store, httpClient *http.Client) *Manager { + if httpClient == nil { + httpClient = &http.Client{Timeout: 30 * time.Second} + } + + m := &Manager{ + fetchers: make([]Fetcher, 0), + cache: NewQuotaCache(DefaultCacheTTL), + authStore: authStore, + httpClient: httpClient, + } + + // Register default fetchers + m.RegisterFetcher(NewAntigravityFetcher(httpClient)) + m.RegisterFetcher(NewCodexFetcher(httpClient)) + + return m +} + +// RegisterFetcher adds a quota fetcher to the manager. +func (m *Manager) RegisterFetcher(fetcher Fetcher) { + if fetcher == nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + m.fetchers = append(m.fetchers, fetcher) +} + +// SetCacheTTL updates the cache TTL. +func (m *Manager) SetCacheTTL(ttl time.Duration) { + m.cache.SetTTL(ttl) +} + +// SetAuthStore updates the auth store. +func (m *Manager) SetAuthStore(store coreauth.Store) { + m.mu.Lock() + defer m.mu.Unlock() + m.authStore = store +} + +// getFetcherForProvider returns the fetcher that can handle the given provider. +func (m *Manager) getFetcherForProvider(provider string) Fetcher { + m.mu.RLock() + defer m.mu.RUnlock() + + provider = strings.ToLower(strings.TrimSpace(provider)) + for _, fetcher := range m.fetchers { + if fetcher.CanFetch(provider) { + return fetcher + } + } + return nil +} + +// isUnsupportedProvider returns true if the provider doesn't support quota fetching. +func isUnsupportedProvider(provider string) bool { + provider = strings.ToLower(strings.TrimSpace(provider)) + return unsupportedProviders[provider] +} + +// FetchAllQuotas fetches quota for all connected accounts. +func (m *Manager) FetchAllQuotas(ctx context.Context) (*QuotaResponse, error) { + m.mu.RLock() + authStore := m.authStore + m.mu.RUnlock() + + if authStore == nil { + return &QuotaResponse{ + Quotas: make(map[string]map[string]*ProviderQuotaData), + LastUpdated: time.Now(), + }, nil + } + + // List all auth records + auths, err := authStore.List(ctx) + if err != nil { + return nil, err + } + + response := &QuotaResponse{ + Quotas: make(map[string]map[string]*ProviderQuotaData), + LastUpdated: time.Now(), + } + + // Group auths by provider + for _, auth := range auths { + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + accountID := m.getAccountID(auth) + + if response.Quotas[provider] == nil { + response.Quotas[provider] = make(map[string]*ProviderQuotaData) + } + + quotaData := m.fetchQuotaForAuth(ctx, auth, false) + response.Quotas[provider][accountID] = quotaData + } + + return response, nil +} + +// FetchProviderQuotas fetches quota for all accounts of a specific provider. +func (m *Manager) FetchProviderQuotas(ctx context.Context, provider string) (*ProviderQuotaResponse, error) { + m.mu.RLock() + authStore := m.authStore + m.mu.RUnlock() + + provider = strings.ToLower(strings.TrimSpace(provider)) + + if authStore == nil { + return &ProviderQuotaResponse{ + Provider: provider, + Accounts: make(map[string]*ProviderQuotaData), + LastUpdated: time.Now(), + }, nil + } + + // List all auth records + auths, err := authStore.List(ctx) + if err != nil { + return nil, err + } + + response := &ProviderQuotaResponse{ + Provider: provider, + Accounts: make(map[string]*ProviderQuotaData), + LastUpdated: time.Now(), + } + + // Filter and fetch for this provider + for _, auth := range auths { + authProvider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if authProvider != provider { + continue + } + + accountID := m.getAccountID(auth) + quotaData := m.fetchQuotaForAuth(ctx, auth, false) + response.Accounts[accountID] = quotaData + } + + return response, nil +} + +// FetchAccountQuota fetches quota for a specific account. +func (m *Manager) FetchAccountQuota(ctx context.Context, provider, accountID string) (*AccountQuotaResponse, error) { + m.mu.RLock() + authStore := m.authStore + m.mu.RUnlock() + + provider = strings.ToLower(strings.TrimSpace(provider)) + + if authStore == nil { + return &AccountQuotaResponse{ + Provider: provider, + Account: accountID, + Quota: UnavailableQuota("auth store not configured"), + }, nil + } + + // List all auth records + auths, err := authStore.List(ctx) + if err != nil { + return nil, err + } + + // Find the matching auth + for _, auth := range auths { + authProvider := strings.ToLower(strings.TrimSpace(auth.Provider)) + authAccountID := m.getAccountID(auth) + if authProvider != provider || authAccountID != accountID { + continue + } + + quotaData := m.fetchQuotaForAuth(ctx, auth, false) + return &AccountQuotaResponse{ + Provider: provider, + Account: accountID, + Quota: quotaData, + }, nil + } + + return &AccountQuotaResponse{ + Provider: provider, + Account: accountID, + Quota: UnavailableQuota("account not found"), + }, nil +} + +// RefreshQuotas forces a refresh of quota data. +func (m *Manager) RefreshQuotas(ctx context.Context, providers []string) (*QuotaResponse, error) { + // Invalidate cache + if len(providers) == 0 { + m.cache.InvalidateAll() + } else { + for _, p := range providers { + m.cache.InvalidateProvider(p) + } + } + + // Fetch fresh data + return m.FetchAllQuotas(ctx) +} + +// GetSubscriptionInfo fetches subscription info for Antigravity/Gemini-CLI accounts. +func (m *Manager) GetSubscriptionInfo(ctx context.Context) (*SubscriptionInfoResponse, error) { + m.mu.RLock() + authStore := m.authStore + m.mu.RUnlock() + + response := &SubscriptionInfoResponse{ + Subscriptions: make(map[string]*SubscriptionInfo), + } + + if authStore == nil { + return response, nil + } + + // List all auth records + auths, err := authStore.List(ctx) + if err != nil { + return nil, err + } + + // Get subscription info for Antigravity/Gemini-CLI accounts + for _, auth := range auths { + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if provider != "antigravity" && provider != "gemini-cli" { + continue + } + + accountID := m.getAccountID(auth) + fetcher := m.getFetcherForProvider(provider) + if fetcher == nil { + continue + } + + // Type assert to get subscription info + if af, ok := fetcher.(*AntigravityFetcher); ok { + info, err := af.GetSubscriptionInfo(ctx, auth) + if err != nil { + log.Warnf("quota manager: failed to get subscription info for %s: %v", accountID, err) + continue + } + response.Subscriptions[accountID] = info + } + } + + return response, nil +} + +// fetchQuotaForAuth fetches quota for a single auth record. +func (m *Manager) fetchQuotaForAuth(ctx context.Context, auth *coreauth.Auth, bypassCache bool) *ProviderQuotaData { + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + accountID := m.getAccountID(auth) + + // Check cache first + if !bypassCache { + if cached, ok := m.cache.Get(provider, accountID); ok { + return cached + } + } + + // Check if provider is unsupported + if isUnsupportedProvider(provider) { + data := UnavailableQuota("quota API not available for this provider") + m.cache.Set(provider, accountID, data) + return data + } + + // Find appropriate fetcher + fetcher := m.getFetcherForProvider(provider) + if fetcher == nil { + data := UnavailableQuota("no fetcher available for this provider") + m.cache.Set(provider, accountID, data) + return data + } + + // Fetch quota + data, err := fetcher.FetchQuota(ctx, auth) + if err != nil { + log.Warnf("quota manager: fetch failed for %s/%s: %v", provider, accountID, err) + data = UnavailableQuota(err.Error()) + } + + if data == nil { + data = UnavailableQuota("fetcher returned nil") + } + + // Cache the result + m.cache.Set(provider, accountID, data) + + return data +} + +// getAccountID extracts a human-readable account ID from an auth record. +func (m *Manager) getAccountID(auth *coreauth.Auth) string { + // Try to get email from metadata + if auth.Metadata != nil { + if email, ok := auth.Metadata["email"].(string); ok && email != "" { + return strings.TrimSpace(email) + } + } + + // Try to get from attributes + if auth.Attributes != nil { + if email, ok := auth.Attributes["email"]; ok && email != "" { + return strings.TrimSpace(email) + } + } + + // Fall back to ID + if auth.ID != "" { + return auth.ID + } + + // Fall back to label + if auth.Label != "" { + return auth.Label + } + + return "unknown" +} \ No newline at end of file diff --git a/internal/quota/types.go b/internal/quota/types.go new file mode 100644 index 000000000..4bf50a5af --- /dev/null +++ b/internal/quota/types.go @@ -0,0 +1,119 @@ +// Package quota provides quota fetching functionality for various AI providers. +// It allows clients to check remaining usage quota for connected accounts. +package quota + +import ( + "time" +) + +// ModelQuota represents quota information for a single model or quota category. +type ModelQuota struct { + // Name is the model or category identifier (e.g., "gemini-3-pro-high", "codex-session") + Name string `json:"name"` + // Percentage is the remaining quota as a percentage (0-100). -1 means unavailable. + Percentage float64 `json:"percentage"` + // ResetTime is when the quota resets, in RFC3339 format. + ResetTime string `json:"reset_time,omitempty"` + // Used is the amount of quota used (optional, provider-specific). + Used *int64 `json:"used,omitempty"` + // Limit is the total quota limit (optional, provider-specific). + Limit *int64 `json:"limit,omitempty"` + // Remaining is the remaining quota (optional, provider-specific). + Remaining *int64 `json:"remaining,omitempty"` +} + +// ProviderQuotaData represents quota data for one account of a provider. +type ProviderQuotaData struct { + // Models contains quota info for each model/category. + Models []ModelQuota `json:"models"` + // LastUpdated is when the quota was last fetched. + LastUpdated time.Time `json:"last_updated"` + // IsForbidden indicates if the account has been blocked/forbidden. + IsForbidden bool `json:"is_forbidden"` + // PlanType is the subscription plan type (e.g., "g1-pro-tier", "plus"). + PlanType string `json:"plan_type,omitempty"` + // Error contains any error message from fetching quota. + Error string `json:"error,omitempty"` + // Extra contains provider-specific additional data. + Extra map[string]any `json:"extra,omitempty"` +} + +// QuotaResponse is the full API response for quota requests. +type QuotaResponse struct { + // Quotas maps provider -> account -> quota data. + Quotas map[string]map[string]*ProviderQuotaData `json:"quotas"` + // LastUpdated is when the overall response was generated. + LastUpdated time.Time `json:"last_updated"` +} + +// ProviderQuotaResponse is the response for a specific provider's quota. +type ProviderQuotaResponse struct { + // Provider is the provider name. + Provider string `json:"provider"` + // Accounts maps account ID -> quota data. + Accounts map[string]*ProviderQuotaData `json:"accounts"` + // LastUpdated is when the response was generated. + LastUpdated time.Time `json:"last_updated"` +} + +// AccountQuotaResponse is the response for a specific account's quota. +type AccountQuotaResponse struct { + // Provider is the provider name. + Provider string `json:"provider"` + // Account is the account identifier (email or ID). + Account string `json:"account"` + // Quota is the quota data for this account. + Quota *ProviderQuotaData `json:"quota"` +} + +// SubscriptionTier represents subscription tier information. +type SubscriptionTier struct { + // ID is the tier identifier (e.g., "g1-pro-tier"). + ID string `json:"id"` + // Name is the human-readable tier name. + Name string `json:"name"` + // Description provides details about the tier. + Description string `json:"description,omitempty"` + // IsDefault indicates if this is the default tier. + IsDefault bool `json:"is_default,omitempty"` + // UpgradeURL is the URL to upgrade the subscription. + UpgradeURL string `json:"upgrade_url,omitempty"` +} + +// SubscriptionInfo represents subscription info for an account. +type SubscriptionInfo struct { + // CurrentTier is the current subscription tier. + CurrentTier *SubscriptionTier `json:"current_tier,omitempty"` + // ProjectID is the GCP project ID (for Antigravity/Gemini-CLI). + ProjectID string `json:"project_id,omitempty"` + // UpgradeURL is the URL to upgrade the subscription. + UpgradeURL string `json:"upgrade_url,omitempty"` +} + +// SubscriptionInfoResponse is the response for subscription info. +type SubscriptionInfoResponse struct { + // Subscriptions maps account ID -> subscription info. + Subscriptions map[string]*SubscriptionInfo `json:"subscriptions"` +} + +// RefreshRequest is the request body for forcing quota refresh. +type RefreshRequest struct { + // Providers limits refresh to specific providers. If empty, refresh all. + Providers []string `json:"providers,omitempty"` +} + +// UnavailableQuota returns a ProviderQuotaData indicating quota is not available. +func UnavailableQuota(reason string) *ProviderQuotaData { + return &ProviderQuotaData{ + Models: []ModelQuota{ + { + Name: "quota", + Percentage: -1, + }, + }, + LastUpdated: time.Now(), + IsForbidden: false, + PlanType: "unavailable", + Error: reason, + } +} \ No newline at end of file From 38c721a7f51b4dc20ca4b6d09e0090c061b333f0 Mon Sep 17 00:00:00 2001 From: evann Date: Mon, 29 Dec 2025 08:13:58 +0700 Subject: [PATCH 02/17] feat(quota): enhance quota data structure with rate limit windows for Codex --- internal/quota/codex.go | 109 ++++++++++++++++++---------------------- internal/quota/types.go | 26 ++++++++-- 2 files changed, 71 insertions(+), 64 deletions(-) diff --git a/internal/quota/codex.go b/internal/quota/codex.go index 6612ced2f..314b84406 100644 --- a/internal/quota/codex.go +++ b/internal/quota/codex.go @@ -103,7 +103,6 @@ func (f *CodexFetcher) fetchUsageData(ctx context.Context, accessToken string) ( if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized { return &ProviderQuotaData{ - Models: []ModelQuota{}, LastUpdated: time.Now(), IsForbidden: true, Error: fmt.Sprintf("HTTP %d", resp.StatusCode), @@ -122,102 +121,90 @@ func (f *CodexFetcher) fetchUsageData(ctx context.Context, accessToken string) ( return f.parseUsageResponse(&result), nil } -// API response structures based on Quotio's CodexCLIQuotaFetcher +// API response structures matching the ChatGPT backend API type codexUsageResponse struct { - RateLimit *rateLimitInfo `json:"rate_limit"` - Plan *planInfo `json:"plan"` + PlanType string `json:"plan_type"` + RateLimit *rateLimitInfo `json:"rate_limit"` + CodeReviewRateLimit *rateLimitInfo `json:"code_review_rate_limit"` } type rateLimitInfo struct { + Allowed bool `json:"allowed"` + LimitReached bool `json:"limit_reached"` PrimaryWindow *windowInfo `json:"primary_window"` SecondaryWindow *windowInfo `json:"secondary_window"` } type windowInfo struct { - Used int64 `json:"used"` - Limit int64 `json:"limit"` - ResetAt string `json:"reset_at"` -} - -type planInfo struct { - Name string `json:"name"` - DisplayName string `json:"display_name"` + UsedPercent float64 `json:"used_percent"` + LimitWindowSeconds int64 `json:"limit_window_seconds"` + ResetAfterSeconds int64 `json:"reset_after_seconds"` + ResetAt int64 `json:"reset_at"` } // parseUsageResponse converts the API response to our unified ProviderQuotaData format. func (f *CodexFetcher) parseUsageResponse(resp *codexUsageResponse) *ProviderQuotaData { - var models []ModelQuota + windows := &RateLimitWindows{} if resp.RateLimit != nil { - // Primary window (session usage) + // Primary window (session usage - typically 5 hours) if resp.RateLimit.PrimaryWindow != nil { pw := resp.RateLimit.PrimaryWindow - percentage := float64(100) - if pw.Limit > 0 { - remaining := pw.Limit - pw.Used - percentage = float64(remaining) / float64(pw.Limit) * 100 - if percentage < 0 { - percentage = 0 - } + // remaining percentage = 100 - used_percent + percentage := 100.0 - pw.UsedPercent + if percentage < 0 { + percentage = 0 } - used := pw.Used - limit := pw.Limit - remaining := pw.Limit - pw.Used - if remaining < 0 { - remaining = 0 + + // Convert Unix timestamp to ISO 8601 string for ResetTime + resetTime := time.Unix(pw.ResetAt, 0).UTC().Format(time.RFC3339) + + windows.Session = &WindowQuota{ + Percentage: percentage, + ResetTime: resetTime, + WindowSeconds: pw.LimitWindowSeconds, } - models = append(models, ModelQuota{ - Name: "codex-session", - Percentage: percentage, - ResetTime: pw.ResetAt, - Used: &used, - Limit: &limit, - Remaining: &remaining, - }) } - // Secondary window (weekly usage) + // Secondary window (weekly usage - typically 7 days) if resp.RateLimit.SecondaryWindow != nil { sw := resp.RateLimit.SecondaryWindow - percentage := float64(100) - if sw.Limit > 0 { - remaining := sw.Limit - sw.Used - percentage = float64(remaining) / float64(sw.Limit) * 100 - if percentage < 0 { - percentage = 0 - } + // remaining percentage = 100 - used_percent + percentage := 100.0 - sw.UsedPercent + if percentage < 0 { + percentage = 0 } - used := sw.Used - limit := sw.Limit - remaining := sw.Limit - sw.Used - if remaining < 0 { - remaining = 0 + + // Convert Unix timestamp to ISO 8601 string for ResetTime + resetTime := time.Unix(sw.ResetAt, 0).UTC().Format(time.RFC3339) + + windows.Weekly = &WindowQuota{ + Percentage: percentage, + ResetTime: resetTime, + WindowSeconds: sw.LimitWindowSeconds, } - models = append(models, ModelQuota{ - Name: "codex-weekly", - Percentage: percentage, - ResetTime: sw.ResetAt, - Used: &used, - Limit: &limit, - Remaining: &remaining, - }) } } data := &ProviderQuotaData{ - Models: models, + Windows: windows, LastUpdated: time.Now(), IsForbidden: false, } - // Set plan type - if resp.Plan != nil { - data.PlanType = resp.Plan.Name + // Set plan type from root level + if resp.PlanType != "" { + data.PlanType = resp.PlanType + } + + // Add extra info about rate limit status + if resp.RateLimit != nil { data.Extra = map[string]any{ - "plan_display_name": resp.Plan.DisplayName, + "allowed": resp.RateLimit.Allowed, + "limit_reached": resp.RateLimit.LimitReached, } } return data -} \ No newline at end of file +} diff --git a/internal/quota/types.go b/internal/quota/types.go index 4bf50a5af..04fe27d01 100644 --- a/internal/quota/types.go +++ b/internal/quota/types.go @@ -22,10 +22,30 @@ type ModelQuota struct { Remaining *int64 `json:"remaining,omitempty"` } +// WindowQuota represents quota information for a rate limit window (used by Codex). +type WindowQuota struct { + // Percentage is the remaining quota as a percentage (0-100). + Percentage float64 `json:"percentage"` + // ResetTime is when the quota resets, in RFC3339 format. + ResetTime string `json:"reset_time,omitempty"` + // WindowSeconds is the duration of the rate limit window in seconds. + WindowSeconds int64 `json:"window_seconds,omitempty"` +} + +// RateLimitWindows represents rate limit windows for providers like Codex. +type RateLimitWindows struct { + // Session is the session-based rate limit window (e.g., 5 hours). + Session *WindowQuota `json:"session,omitempty"` + // Weekly is the weekly rate limit window (e.g., 7 days). + Weekly *WindowQuota `json:"weekly,omitempty"` +} + // ProviderQuotaData represents quota data for one account of a provider. type ProviderQuotaData struct { - // Models contains quota info for each model/category. - Models []ModelQuota `json:"models"` + // Models contains quota info for each model/category (used by Antigravity, Gemini-CLI, etc.). + Models []ModelQuota `json:"models,omitempty"` + // Windows contains rate limit window info (used by Codex). + Windows *RateLimitWindows `json:"windows,omitempty"` // LastUpdated is when the quota was last fetched. LastUpdated time.Time `json:"last_updated"` // IsForbidden indicates if the account has been blocked/forbidden. @@ -116,4 +136,4 @@ func UnavailableQuota(reason string) *ProviderQuotaData { PlanType: "unavailable", Error: reason, } -} \ No newline at end of file +} From 8f9c99c6ac86fa9a0fb039d5fe8f3fa6e6f7a332 Mon Sep 17 00:00:00 2001 From: evann Date: Mon, 29 Dec 2025 08:19:39 +0700 Subject: [PATCH 03/17] feat(quota): add error handling for unknown providers in quota fetching --- .../api/handlers/management/quota_fetchers.go | 17 +++- internal/quota/manager.go | 85 ++++++++++++++++--- 2 files changed, 91 insertions(+), 11 deletions(-) diff --git a/internal/api/handlers/management/quota_fetchers.go b/internal/api/handlers/management/quota_fetchers.go index 530f4e375..ecc5437d3 100644 --- a/internal/api/handlers/management/quota_fetchers.go +++ b/internal/api/handlers/management/quota_fetchers.go @@ -1,6 +1,7 @@ package management import ( + "errors" "net/http" "github.com/gin-gonic/gin" @@ -42,6 +43,13 @@ func (h *Handler) GetProviderQuotas(c *gin.Context) { ctx := c.Request.Context() quotas, err := h.quotaManager.FetchProviderQuotas(ctx, provider) if err != nil { + if errors.Is(err, quota.ErrUnknownProvider) { + c.JSON(http.StatusNotFound, gin.H{ + "error": err.Error(), + "known_providers": h.quotaManager.GetKnownProviders(), + }) + return + } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -67,6 +75,13 @@ func (h *Handler) GetAccountQuota(c *gin.Context) { ctx := c.Request.Context() quotaResp, err := h.quotaManager.FetchAccountQuota(ctx, provider, account) if err != nil { + if errors.Is(err, quota.ErrUnknownProvider) { + c.JSON(http.StatusNotFound, gin.H{ + "error": err.Error(), + "known_providers": h.quotaManager.GetKnownProviders(), + }) + return + } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -114,4 +129,4 @@ func (h *Handler) GetSubscriptionInfo(c *gin.Context) { } c.JSON(http.StatusOK, info) -} \ No newline at end of file +} diff --git a/internal/quota/manager.go b/internal/quota/manager.go index 788b6eaa5..607030167 100644 --- a/internal/quota/manager.go +++ b/internal/quota/manager.go @@ -2,6 +2,8 @@ package quota import ( "context" + "errors" + "fmt" "net/http" "strings" "sync" @@ -11,13 +13,17 @@ import ( log "github.com/sirupsen/logrus" ) +// ErrUnknownProvider is returned when a provider is not recognized. +var ErrUnknownProvider = errors.New("unknown provider") + // unsupportedProviders lists providers that don't have quota APIs. +// These are "known" providers but don't support quota fetching. var unsupportedProviders = map[string]bool{ - "claude": true, - "gemini": true, // API key-based Gemini (pay-per-use) - "vertex": true, - "iflow": true, - "qwen": true, + "claude": true, + "gemini": true, // API key-based Gemini (pay-per-use) + "vertex": true, + "iflow": true, + "qwen": true, "aistudio": true, } @@ -92,6 +98,55 @@ func isUnsupportedProvider(provider string) bool { return unsupportedProviders[provider] } +// GetKnownProviders returns a list of all known provider names. +// This includes both supported providers (with fetchers) and unsupported providers. +func (m *Manager) GetKnownProviders() []string { + m.mu.RLock() + defer m.mu.RUnlock() + + providers := make(map[string]bool) + + // Add providers from fetchers + for _, fetcher := range m.fetchers { + for _, p := range fetcher.SupportedProviders() { + providers[strings.ToLower(p)] = true + } + } + + // Add unsupported providers (they're still "known") + for p := range unsupportedProviders { + providers[p] = true + } + + result := make([]string, 0, len(providers)) + for p := range providers { + result = append(result, p) + } + return result +} + +// IsKnownProvider returns true if the provider is recognized. +func (m *Manager) IsKnownProvider(provider string) bool { + provider = strings.ToLower(strings.TrimSpace(provider)) + + // Check if it's an unsupported but known provider + if unsupportedProviders[provider] { + return true + } + + // Check if any fetcher supports this provider + m.mu.RLock() + defer m.mu.RUnlock() + + for _, fetcher := range m.fetchers { + if fetcher.CanFetch(provider) { + return true + } + } + + return false +} + // FetchAllQuotas fetches quota for all connected accounts. func (m *Manager) FetchAllQuotas(ctx context.Context) (*QuotaResponse, error) { m.mu.RLock() @@ -134,12 +189,17 @@ func (m *Manager) FetchAllQuotas(ctx context.Context) (*QuotaResponse, error) { // FetchProviderQuotas fetches quota for all accounts of a specific provider. func (m *Manager) FetchProviderQuotas(ctx context.Context, provider string) (*ProviderQuotaResponse, error) { + provider = strings.ToLower(strings.TrimSpace(provider)) + + // Validate provider is known + if !m.IsKnownProvider(provider) { + return nil, fmt.Errorf("%w: %s", ErrUnknownProvider, provider) + } + m.mu.RLock() authStore := m.authStore m.mu.RUnlock() - provider = strings.ToLower(strings.TrimSpace(provider)) - if authStore == nil { return &ProviderQuotaResponse{ Provider: provider, @@ -177,12 +237,17 @@ func (m *Manager) FetchProviderQuotas(ctx context.Context, provider string) (*Pr // FetchAccountQuota fetches quota for a specific account. func (m *Manager) FetchAccountQuota(ctx context.Context, provider, accountID string) (*AccountQuotaResponse, error) { + provider = strings.ToLower(strings.TrimSpace(provider)) + + // Validate provider is known + if !m.IsKnownProvider(provider) { + return nil, fmt.Errorf("%w: %s", ErrUnknownProvider, provider) + } + m.mu.RLock() authStore := m.authStore m.mu.RUnlock() - provider = strings.ToLower(strings.TrimSpace(provider)) - if authStore == nil { return &AccountQuotaResponse{ Provider: provider, @@ -353,4 +418,4 @@ func (m *Manager) getAccountID(auth *coreauth.Auth) string { } return "unknown" -} \ No newline at end of file +} From 623274f91ee4d964b4785cc86cfe076be633bd7f Mon Sep 17 00:00:00 2001 From: evann Date: Mon, 29 Dec 2025 09:00:08 +0700 Subject: [PATCH 04/17] feat(quota): update comments and improve handling for unsupported providers in quota management --- internal/quota/antigravity.go | 22 +++++++++++----------- internal/quota/manager.go | 2 +- internal/quota/types.go | 9 ++------- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/internal/quota/antigravity.go b/internal/quota/antigravity.go index d0e35cb22..060680576 100644 --- a/internal/quota/antigravity.go +++ b/internal/quota/antigravity.go @@ -14,11 +14,11 @@ import ( ) const ( - antigravityQuotaAPIURL = "https://cloudcode-pa.googleapis.com/v1internal:fetchAvailableModels" - antigravityLoadProjectURL = "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" - antigravityAPIUserAgent = "antigravity/1.11.3 Darwin/arm64" - antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" - antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` + antigravityQuotaAPIURL = "https://cloudcode-pa.googleapis.com/v1internal:fetchAvailableModels" + antigravityLoadProjectURL = "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" + antigravityAPIUserAgent = "antigravity/1.11.3 Darwin/arm64" + antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" + antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` ) // AntigravityFetcher implements quota fetching for Antigravity and Gemini-CLI providers. @@ -226,11 +226,11 @@ func (f *AntigravityFetcher) setRequestHeaders(req *http.Request, accessToken st type fetchAvailableModelsResponse struct { // Models is a map from model name to model info (not an array!) - Models map[string]modelInfo `json:"models"` - CurrentTier *tierInfo `json:"currentTier"` - AvailableTiers []tierInfo `json:"availableTiers"` - ProjectID string `json:"cloudaicompanionProject"` - TierUpgradeURL string `json:"tierUpgradeUrl"` + Models map[string]modelInfo `json:"models"` + CurrentTier *tierInfo `json:"currentTier"` + AvailableTiers []tierInfo `json:"availableTiers"` + ProjectID string `json:"cloudaicompanionProject"` + TierUpgradeURL string `json:"tierUpgradeUrl"` } type modelInfo struct { @@ -342,4 +342,4 @@ func (f *AntigravityFetcher) GetSubscriptionInfo(ctx context.Context, auth *core } return info, nil -} \ No newline at end of file +} diff --git a/internal/quota/manager.go b/internal/quota/manager.go index 607030167..9fd863235 100644 --- a/internal/quota/manager.go +++ b/internal/quota/manager.go @@ -20,7 +20,7 @@ var ErrUnknownProvider = errors.New("unknown provider") // These are "known" providers but don't support quota fetching. var unsupportedProviders = map[string]bool{ "claude": true, - "gemini": true, // API key-based Gemini (pay-per-use) + "gemini": true, // Gemini CLI doesn't have a public quota API "vertex": true, "iflow": true, "qwen": true, diff --git a/internal/quota/types.go b/internal/quota/types.go index 04fe27d01..1cc67ec95 100644 --- a/internal/quota/types.go +++ b/internal/quota/types.go @@ -123,17 +123,12 @@ type RefreshRequest struct { } // UnavailableQuota returns a ProviderQuotaData indicating quota is not available. +// This is used for providers that don't have a public quota API. func UnavailableQuota(reason string) *ProviderQuotaData { return &ProviderQuotaData{ - Models: []ModelQuota{ - { - Name: "quota", - Percentage: -1, - }, - }, + // No models/windows - quota info is not available for this provider LastUpdated: time.Now(), IsForbidden: false, - PlanType: "unavailable", Error: reason, } } From a66c580c0900dc39b2f6ab41bdd84b26d068751b Mon Sep 17 00:00:00 2001 From: evann Date: Mon, 29 Dec 2025 09:24:32 +0700 Subject: [PATCH 05/17] feat(quota): implement background quota refresh mechanism and related configuration --- config.example.yaml | 7 +- internal/api/handlers/management/handler.go | 23 ++++ internal/api/server.go | 19 ++++ internal/config/config.go | 6 ++ internal/quota/manager.go | 67 ++++++++++++ internal/quota/worker.go | 114 ++++++++++++++++++++ sdk/cliproxy/service.go | 5 + 7 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 internal/quota/worker.go diff --git a/config.example.yaml b/config.example.yaml index 85e006503..e11580bbd 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -52,6 +52,12 @@ logs-max-total-size-mb: 0 # When false, disable in-memory usage statistics aggregation usage-statistics-enabled: false +# Quota refresh interval in seconds. When set to a positive value, the server +# will periodically fetch quota data for all configured providers in the background +# and cache it in memory. This eliminates the need to manually refresh quota data. +# Set to 0 to disable background refresh (fetch on-demand only). Default: 0 +# quota-refresh-interval: 300 # 5 minutes + # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ proxy-url: "" @@ -75,7 +81,6 @@ routing: # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false - # Streaming behavior (SSE keep-alives + safe bootstrap retries). # streaming: # keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives. diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index 3a0f3d99b..9206ef374 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -3,6 +3,7 @@ package management import ( + "context" "crypto/subtle" "fmt" "net/http" @@ -52,6 +53,12 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man tokenStore := sdkAuth.GetTokenStore() quotaManager := quota.NewManager(tokenStore, nil) + // Configure quota refresh interval if set in config + if cfg != nil && cfg.QuotaRefreshInterval > 0 { + interval := time.Duration(cfg.QuotaRefreshInterval) * time.Second + quotaManager.SetRefreshInterval(interval) + } + return &Handler{ cfg: cfg, configFilePath: configFilePath, @@ -95,6 +102,22 @@ func (h *Handler) SetLogDirectory(dir string) { h.logDir = dir } +// StartBackgroundWorkers starts background workers (quota refresh, etc.) +// This should be called when the server is ready to handle requests. +func (h *Handler) StartBackgroundWorkers(ctx context.Context) { + if h.quotaManager != nil { + h.quotaManager.StartWorker(ctx) + } +} + +// StopBackgroundWorkers stops all background workers. +// This should be called during server shutdown. +func (h *Handler) StopBackgroundWorkers() { + if h.quotaManager != nil { + h.quotaManager.StopWorker() + } +} + // Middleware enforces access control for management endpoints. // All requests (local and remote) require a valid management key. // Additionally, remote access requires allow-remote-management=true. diff --git a/internal/api/server.go b/internal/api/server.go index b9790d49c..99623e764 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -771,6 +771,22 @@ func (s *Server) Start() error { return nil } +// StartBackgroundWorkers starts background workers such as quota refresh. +// This should be called before Start() with the run context. +func (s *Server) StartBackgroundWorkers(ctx context.Context) { + if s.mgmt != nil { + s.mgmt.StartBackgroundWorkers(ctx) + } +} + +// StopBackgroundWorkers stops all background workers. +// This is called automatically by Stop() but can be called manually if needed. +func (s *Server) StopBackgroundWorkers() { + if s.mgmt != nil { + s.mgmt.StopBackgroundWorkers() + } +} + // Stop gracefully shuts down the API server without interrupting any // active connections. // @@ -782,6 +798,9 @@ func (s *Server) Start() error { func (s *Server) Stop(ctx context.Context) error { log.Debug("Stopping API server...") + // Stop background workers first + s.StopBackgroundWorkers() + if s.keepAliveEnabled { select { case s.keepAliveStop <- struct{}{}: diff --git a/internal/config/config.go b/internal/config/config.go index dea56dff9..b92655a3b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -55,6 +55,12 @@ type Config struct { // DisableCooling disables quota cooldown scheduling when true. DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"` + // QuotaRefreshInterval is the interval in seconds for background quota refresh. + // When set to a positive value, the server will periodically fetch quota data + // for all configured providers in the background and cache it in memory. + // Set to 0 to disable background refresh (fetch on-demand only). Default: 0 + QuotaRefreshInterval int `yaml:"quota-refresh-interval" json:"quota-refresh-interval"` + // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. diff --git a/internal/quota/manager.go b/internal/quota/manager.go index 9fd863235..6bd08ddce 100644 --- a/internal/quota/manager.go +++ b/internal/quota/manager.go @@ -34,6 +34,7 @@ type Manager struct { cache *QuotaCache authStore coreauth.Store httpClient *http.Client + worker *Worker } // NewManager creates a new quota manager with the given auth store. @@ -78,6 +79,72 @@ func (m *Manager) SetAuthStore(store coreauth.Store) { m.authStore = store } +// SetRefreshInterval configures the background refresh interval. +// If interval is > 0, creates a new worker that will be started when StartWorker is called. +// If interval is <= 0, stops any existing worker and disables background refresh. +func (m *Manager) SetRefreshInterval(interval time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + + // Stop existing worker if running + if m.worker != nil { + m.worker.Stop() + m.worker = nil + } + + // Create new worker if interval is positive + if interval > 0 { + m.worker = NewWorker(m, interval) + } +} + +// StartWorker starts the background quota refresh worker if configured. +// This should be called after the server is ready to handle requests. +func (m *Manager) StartWorker(ctx context.Context) { + m.mu.RLock() + worker := m.worker + m.mu.RUnlock() + + if worker != nil { + worker.Start(ctx) + } +} + +// StopWorker stops the background quota refresh worker. +// This should be called during server shutdown. +func (m *Manager) StopWorker() { + m.mu.RLock() + worker := m.worker + m.mu.RUnlock() + + if worker != nil { + worker.Stop() + } +} + +// WorkerStatus returns information about the background worker. +// Returns nil if no worker is configured. +func (m *Manager) WorkerStatus() *WorkerStatus { + m.mu.RLock() + worker := m.worker + m.mu.RUnlock() + + if worker == nil { + return nil + } + + return &WorkerStatus{ + Running: worker.IsRunning(), + Interval: worker.Interval(), + } +} + +// WorkerStatus contains information about the quota refresh worker. +type WorkerStatus struct { + Running bool `json:"running"` + Interval time.Duration `json:"interval"` +} + // getFetcherForProvider returns the fetcher that can handle the given provider. func (m *Manager) getFetcherForProvider(provider string) Fetcher { m.mu.RLock() diff --git a/internal/quota/worker.go b/internal/quota/worker.go new file mode 100644 index 000000000..f8701f54e --- /dev/null +++ b/internal/quota/worker.go @@ -0,0 +1,114 @@ +package quota + +import ( + "context" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// Worker handles periodic quota refresh in the background. +// It uses a time.Ticker to periodically fetch quota data for all configured +// providers and cache it in memory. +type Worker struct { + manager *Manager + interval time.Duration + mu sync.Mutex + running bool + stopCh chan struct{} +} + +// NewWorker creates a new quota refresh worker. +// The interval specifies how often the worker should refresh quota data. +// If interval is <= 0, the worker will not start when Start is called. +func NewWorker(manager *Manager, interval time.Duration) *Worker { + return &Worker{ + manager: manager, + interval: interval, + } +} + +// Start begins the background refresh loop. +// Returns immediately if interval is <= 0 or if the worker is already running. +// The worker will perform an initial fetch immediately, then continue at the +// configured interval until Stop is called or the context is cancelled. +func (w *Worker) Start(ctx context.Context) { + w.mu.Lock() + if w.interval <= 0 || w.running { + w.mu.Unlock() + return + } + w.running = true + w.stopCh = make(chan struct{}) + w.mu.Unlock() + + log.Infof("quota worker: starting with interval %v", w.interval) + go w.run(ctx) +} + +// Stop halts the background refresh loop. +// This method is safe to call multiple times or if the worker is not running. +func (w *Worker) Stop() { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.running { + return + } + + close(w.stopCh) + w.running = false +} + +// IsRunning returns true if the worker is currently running. +func (w *Worker) IsRunning() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.running +} + +// Interval returns the configured refresh interval. +func (w *Worker) Interval() time.Duration { + return w.interval +} + +func (w *Worker) run(ctx context.Context) { + ticker := time.NewTicker(w.interval) + defer ticker.Stop() + + // Initial fetch on startup + w.refresh(ctx) + + for { + select { + case <-ctx.Done(): + log.Info("quota worker: shutting down due to context cancellation") + w.mu.Lock() + w.running = false + w.mu.Unlock() + return + case <-w.stopCh: + log.Info("quota worker: stopped") + return + case <-ticker.C: + w.refresh(ctx) + } + } +} + +func (w *Worker) refresh(ctx context.Context) { + log.Debug("quota worker: starting refresh") + + // Use a timeout context to prevent hanging on slow API calls + refreshCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + _, err := w.manager.FetchAllQuotas(refreshCtx) + if err != nil { + log.Warnf("quota worker: refresh failed: %v", err) + return + } + + log.Debug("quota worker: refresh completed") +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 6e81e4016..c12ea83be 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -500,6 +500,11 @@ func (s *Service) Run(ctx context.Context) error { time.Sleep(100 * time.Millisecond) fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port) + // Start background workers (quota refresh, etc.) after server is running + if s.server != nil { + s.server.StartBackgroundWorkers(ctx) + } + if s.hooks.OnAfterStart != nil { s.hooks.OnAfterStart(s) } From 11b08eca1b2d224dadefb42fc7c93561633d11da Mon Sep 17 00:00:00 2001 From: evann Date: Mon, 29 Dec 2025 09:36:09 +0700 Subject: [PATCH 06/17] feat(quota): implement access token management with refresh functionality for Antigravity and Codex providers --- internal/quota/antigravity.go | 158 ++++++++++++++++++++++++++++++++- internal/quota/codex.go | 160 +++++++++++++++++++++++++++++++++- 2 files changed, 314 insertions(+), 4 deletions(-) diff --git a/internal/quota/antigravity.go b/internal/quota/antigravity.go index 060680576..e572c6eca 100644 --- a/internal/quota/antigravity.go +++ b/internal/quota/antigravity.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "time" @@ -19,6 +20,11 @@ const ( antigravityAPIUserAgent = "antigravity/1.11.3 Darwin/arm64" antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` + // OAuth credentials for token refresh + antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + antigravityClientSecret = "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" + // refreshSkew is the time buffer before token expiry to trigger refresh + refreshSkew = 5 * time.Minute ) // AntigravityFetcher implements quota fetching for Antigravity and Gemini-CLI providers. @@ -57,8 +63,12 @@ func (f *AntigravityFetcher) FetchQuota(ctx context.Context, auth *coreauth.Auth return nil, fmt.Errorf("auth is nil") } - // Get access token from metadata - accessToken := f.extractAccessToken(auth) + // Ensure we have a valid access token (refresh if expired) + accessToken, err := f.ensureAccessToken(ctx, auth) + if err != nil { + log.Warnf("antigravity quota: failed to ensure access token: %v", err) + return UnavailableQuota(fmt.Sprintf("failed to ensure access token: %v", err)), nil + } if accessToken == "" { return UnavailableQuota("no access token available"), nil } @@ -85,6 +95,45 @@ func (f *AntigravityFetcher) FetchQuota(ctx context.Context, auth *coreauth.Auth return quotaData, nil } +// ensureAccessToken ensures we have a valid access token, refreshing if expired. +func (f *AntigravityFetcher) ensureAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) { + if auth.Metadata == nil { + return "", nil + } + + accessToken := f.extractAccessToken(auth) + if accessToken == "" { + return "", nil + } + + // Check if token is expired or about to expire + expiry := f.tokenExpiry(auth) + if expiry.After(time.Now().Add(refreshSkew)) { + // Token is still valid + return accessToken, nil + } + + // Token is expired or about to expire, try to refresh + log.Debugf("antigravity quota: access token expired or expiring soon, attempting refresh") + + refreshToken := f.extractRefreshToken(auth) + if refreshToken == "" { + // No refresh token, return existing access token (it may still work) + log.Debugf("antigravity quota: no refresh token available, using existing access token") + return accessToken, nil + } + + // Refresh the token + newToken, err := f.refreshAccessToken(ctx, auth, refreshToken) + if err != nil { + log.Warnf("antigravity quota: failed to refresh token: %v", err) + // Return existing token anyway - API might still accept it + return accessToken, nil + } + + return newToken, nil +} + // extractAccessToken extracts the access token from auth metadata. func (f *AntigravityFetcher) extractAccessToken(auth *coreauth.Auth) string { if auth.Metadata == nil { @@ -96,6 +145,111 @@ func (f *AntigravityFetcher) extractAccessToken(auth *coreauth.Auth) string { return "" } +// extractRefreshToken extracts the refresh token from auth metadata. +func (f *AntigravityFetcher) extractRefreshToken(auth *coreauth.Auth) string { + if auth.Metadata == nil { + return "" + } + if token, ok := auth.Metadata["refresh_token"].(string); ok { + return strings.TrimSpace(token) + } + return "" +} + +// tokenExpiry extracts the token expiry time from auth metadata. +func (f *AntigravityFetcher) tokenExpiry(auth *coreauth.Auth) time.Time { + if auth.Metadata == nil { + return time.Time{} + } + + // Check various expiry field names + var expiryStr string + if v, ok := auth.Metadata["expired"].(string); ok { + expiryStr = v + } else if v, ok := auth.Metadata["expires_at"].(string); ok { + expiryStr = v + } else if v, ok := auth.Metadata["expiry"].(string); ok { + expiryStr = v + } + + if expiryStr == "" { + return time.Time{} + } + + // Try parsing with various formats + formats := []string{ + time.RFC3339, + time.RFC3339Nano, + "2006-01-02T15:04:05-07:00", + "2006-01-02T15:04:05Z07:00", + } + for _, format := range formats { + if t, err := time.Parse(format, expiryStr); err == nil { + return t + } + } + + return time.Time{} +} + +// refreshAccessToken refreshes the access token using the refresh token. +func (f *AntigravityFetcher) refreshAccessToken(ctx context.Context, auth *coreauth.Auth, refreshToken string) (string, error) { + form := url.Values{} + form.Set("client_id", antigravityClientID) + form.Set("client_secret", antigravityClientSecret) + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", refreshToken) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) + if err != nil { + return "", fmt.Errorf("create refresh request: %w", err) + } + + req.Header.Set("Host", "oauth2.googleapis.com") + req.Header.Set("User-Agent", antigravityAPIUserAgent) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := f.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("execute refresh request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("refresh failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", fmt.Errorf("decode refresh response: %w", err) + } + + // Update auth metadata with new tokens + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["access_token"] = tokenResp.AccessToken + if tokenResp.RefreshToken != "" { + auth.Metadata["refresh_token"] = tokenResp.RefreshToken + } + auth.Metadata["expires_in"] = tokenResp.ExpiresIn + auth.Metadata["expired"] = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339) + auth.Metadata["type"] = "antigravity" + + log.Debugf("antigravity quota: token refreshed successfully, new expiry: %s", auth.Metadata["expired"]) + + return tokenResp.AccessToken, nil +} + // extractProjectID extracts the project ID from auth metadata. func (f *AntigravityFetcher) extractProjectID(auth *coreauth.Auth) string { if auth.Metadata == nil { diff --git a/internal/quota/codex.go b/internal/quota/codex.go index 314b84406..9059ebe50 100644 --- a/internal/quota/codex.go +++ b/internal/quota/codex.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "time" @@ -15,6 +16,10 @@ import ( const ( codexUsageAPIURL = "https://chatgpt.com/backend-api/wham/usage" + codexTokenURL = "https://auth.openai.com/oauth/token" + codexClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + // codexRefreshSkew is the time buffer before token expiry to trigger refresh + codexRefreshSkew = 5 * time.Minute ) // CodexFetcher implements quota fetching for Codex/OpenAI provider. @@ -52,8 +57,12 @@ func (f *CodexFetcher) FetchQuota(ctx context.Context, auth *coreauth.Auth) (*Pr return nil, fmt.Errorf("auth is nil") } - // Get access token from metadata - accessToken := f.extractAccessToken(auth) + // Ensure we have a valid access token (refresh if expired) + accessToken, err := f.ensureAccessToken(ctx, auth) + if err != nil { + log.Warnf("codex quota: failed to ensure access token: %v", err) + return UnavailableQuota(fmt.Sprintf("failed to ensure access token: %v", err)), nil + } if accessToken == "" { return UnavailableQuota("no access token available"), nil } @@ -68,6 +77,45 @@ func (f *CodexFetcher) FetchQuota(ctx context.Context, auth *coreauth.Auth) (*Pr return quotaData, nil } +// ensureAccessToken ensures we have a valid access token, refreshing if expired. +func (f *CodexFetcher) ensureAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) { + if auth.Metadata == nil { + return "", nil + } + + accessToken := f.extractAccessToken(auth) + if accessToken == "" { + return "", nil + } + + // Check if token is expired or about to expire + expiry := f.tokenExpiry(auth) + if expiry.After(time.Now().Add(codexRefreshSkew)) { + // Token is still valid + return accessToken, nil + } + + // Token is expired or about to expire, try to refresh + log.Debugf("codex quota: access token expired or expiring soon, attempting refresh") + + refreshToken := f.extractRefreshToken(auth) + if refreshToken == "" { + // No refresh token, return existing access token (it may still work) + log.Debugf("codex quota: no refresh token available, using existing access token") + return accessToken, nil + } + + // Refresh the token + newToken, err := f.refreshAccessToken(ctx, auth, refreshToken) + if err != nil { + log.Warnf("codex quota: failed to refresh token: %v", err) + // Return existing token anyway - API might still accept it + return accessToken, nil + } + + return newToken, nil +} + // extractAccessToken extracts the access token from auth metadata. func (f *CodexFetcher) extractAccessToken(auth *coreauth.Auth) string { if auth.Metadata == nil { @@ -79,6 +127,114 @@ func (f *CodexFetcher) extractAccessToken(auth *coreauth.Auth) string { return "" } +// extractRefreshToken extracts the refresh token from auth metadata. +func (f *CodexFetcher) extractRefreshToken(auth *coreauth.Auth) string { + if auth.Metadata == nil { + return "" + } + if token, ok := auth.Metadata["refresh_token"].(string); ok { + return strings.TrimSpace(token) + } + return "" +} + +// tokenExpiry extracts the token expiry time from auth metadata. +func (f *CodexFetcher) tokenExpiry(auth *coreauth.Auth) time.Time { + if auth.Metadata == nil { + return time.Time{} + } + + // Check various expiry field names + var expiryStr string + if v, ok := auth.Metadata["expired"].(string); ok { + expiryStr = v + } else if v, ok := auth.Metadata["expires_at"].(string); ok { + expiryStr = v + } else if v, ok := auth.Metadata["expiry"].(string); ok { + expiryStr = v + } + + if expiryStr == "" { + return time.Time{} + } + + // Try parsing with various formats + formats := []string{ + time.RFC3339, + time.RFC3339Nano, + "2006-01-02T15:04:05-07:00", + "2006-01-02T15:04:05Z07:00", + } + for _, format := range formats { + if t, err := time.Parse(format, expiryStr); err == nil { + return t + } + } + + return time.Time{} +} + +// refreshAccessToken refreshes the access token using the refresh token. +func (f *CodexFetcher) refreshAccessToken(ctx context.Context, auth *coreauth.Auth, refreshToken string) (string, error) { + form := url.Values{} + form.Set("client_id", codexClientID) + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", refreshToken) + form.Set("scope", "openid profile email") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexTokenURL, strings.NewReader(form.Encode())) + if err != nil { + return "", fmt.Errorf("create refresh request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := f.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("execute refresh request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("refresh failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + ExpiresIn int64 `json:"expires_in"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", fmt.Errorf("decode refresh response: %w", err) + } + + // Update auth metadata with new tokens + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["access_token"] = tokenResp.AccessToken + if tokenResp.RefreshToken != "" { + auth.Metadata["refresh_token"] = tokenResp.RefreshToken + } + if tokenResp.IDToken != "" { + auth.Metadata["id_token"] = tokenResp.IDToken + } + auth.Metadata["expires_in"] = tokenResp.ExpiresIn + auth.Metadata["expired"] = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339) + auth.Metadata["type"] = "codex" + + log.Debugf("codex quota: token refreshed successfully, new expiry: %s", auth.Metadata["expired"]) + + return tokenResp.AccessToken, nil +} + // fetchUsageData fetches usage data from the ChatGPT API. func (f *CodexFetcher) fetchUsageData(ctx context.Context, accessToken string) (*ProviderQuotaData, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, codexUsageAPIURL, nil) From 369f52d940f51e9e1dae4d5af24fca8d967cc89d Mon Sep 17 00:00:00 2001 From: evann Date: Mon, 29 Dec 2025 09:58:13 +0700 Subject: [PATCH 07/17] feat(quota): enhance refresh interval handling and update LastUpdated timestamps for quota responses --- internal/quota/manager.go | 51 ++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/internal/quota/manager.go b/internal/quota/manager.go index 6bd08ddce..ed0db8b0a 100644 --- a/internal/quota/manager.go +++ b/internal/quota/manager.go @@ -81,7 +81,8 @@ func (m *Manager) SetAuthStore(store coreauth.Store) { // SetRefreshInterval configures the background refresh interval. // If interval is > 0, creates a new worker that will be started when StartWorker is called. -// If interval is <= 0, stops any existing worker and disables background refresh. +// The cache TTL is also set to match the interval so cache stays valid between worker ticks. +// If interval is <= 0, stops any existing worker, disables background refresh, and resets cache TTL to default. func (m *Manager) SetRefreshInterval(interval time.Duration) { m.mu.Lock() defer m.mu.Unlock() @@ -95,6 +96,12 @@ func (m *Manager) SetRefreshInterval(interval time.Duration) { // Create new worker if interval is positive if interval > 0 { m.worker = NewWorker(m, interval) + // Set cache TTL to match worker interval so cache stays valid between ticks + // Add a small buffer (10%) to avoid race conditions at interval boundaries + m.cache.SetTTL(interval + interval/10) + } else { + // Reset cache TTL to default when worker is disabled + m.cache.SetTTL(DefaultCacheTTL) } } @@ -234,10 +241,12 @@ func (m *Manager) FetchAllQuotas(ctx context.Context) (*QuotaResponse, error) { } response := &QuotaResponse{ - Quotas: make(map[string]map[string]*ProviderQuotaData), - LastUpdated: time.Now(), + Quotas: make(map[string]map[string]*ProviderQuotaData), } + // Track the earliest LastUpdated from all quota data + var earliestUpdate time.Time + // Group auths by provider for _, auth := range auths { provider := strings.ToLower(strings.TrimSpace(auth.Provider)) @@ -249,6 +258,20 @@ func (m *Manager) FetchAllQuotas(ctx context.Context) (*QuotaResponse, error) { quotaData := m.fetchQuotaForAuth(ctx, auth, false) response.Quotas[provider][accountID] = quotaData + + // Track earliest update time from actual data + if !quotaData.LastUpdated.IsZero() { + if earliestUpdate.IsZero() || quotaData.LastUpdated.Before(earliestUpdate) { + earliestUpdate = quotaData.LastUpdated + } + } + } + + // Use earliest data timestamp, or current time if no data + if earliestUpdate.IsZero() { + response.LastUpdated = time.Now() + } else { + response.LastUpdated = earliestUpdate } return response, nil @@ -282,11 +305,13 @@ func (m *Manager) FetchProviderQuotas(ctx context.Context, provider string) (*Pr } response := &ProviderQuotaResponse{ - Provider: provider, - Accounts: make(map[string]*ProviderQuotaData), - LastUpdated: time.Now(), + Provider: provider, + Accounts: make(map[string]*ProviderQuotaData), } + // Track the earliest LastUpdated from all quota data + var earliestUpdate time.Time + // Filter and fetch for this provider for _, auth := range auths { authProvider := strings.ToLower(strings.TrimSpace(auth.Provider)) @@ -297,6 +322,20 @@ func (m *Manager) FetchProviderQuotas(ctx context.Context, provider string) (*Pr accountID := m.getAccountID(auth) quotaData := m.fetchQuotaForAuth(ctx, auth, false) response.Accounts[accountID] = quotaData + + // Track earliest update time from actual data + if !quotaData.LastUpdated.IsZero() { + if earliestUpdate.IsZero() || quotaData.LastUpdated.Before(earliestUpdate) { + earliestUpdate = quotaData.LastUpdated + } + } + } + + // Use earliest data timestamp, or current time if no data + if earliestUpdate.IsZero() { + response.LastUpdated = time.Now() + } else { + response.LastUpdated = earliestUpdate } return response, nil From 173a01fd2b2bebc805f8c89a304d99869c4b8686 Mon Sep 17 00:00:00 2001 From: evann Date: Mon, 29 Dec 2025 10:55:50 +0700 Subject: [PATCH 08/17] feat(quota): improve quota refresh handling and update LastUpdated logic to track latest timestamps --- .../api/handlers/management/quota_fetchers.go | 8 +++-- internal/quota/cache.go | 4 +-- internal/quota/manager.go | 32 +++++++++---------- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/internal/api/handlers/management/quota_fetchers.go b/internal/api/handlers/management/quota_fetchers.go index ecc5437d3..e7d4a7fc6 100644 --- a/internal/api/handlers/management/quota_fetchers.go +++ b/internal/api/handlers/management/quota_fetchers.go @@ -98,9 +98,11 @@ func (h *Handler) RefreshQuotas(c *gin.Context) { } var req quota.RefreshRequest - if err := c.ShouldBindJSON(&req); err != nil { - // Allow empty body - refresh all - req.Providers = nil + if c.Request.ContentLength > 0 { + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body: " + err.Error()}) + return + } } ctx := c.Request.Context() diff --git a/internal/quota/cache.go b/internal/quota/cache.go index 94c43e7f5..f78769087 100644 --- a/internal/quota/cache.go +++ b/internal/quota/cache.go @@ -83,7 +83,7 @@ func (c *QuotaCache) InvalidateProvider(provider string) { prefix := provider + ":" for key := range c.entries { - if len(key) > len(prefix) && key[:len(prefix)] == prefix { + if len(key) >= len(prefix) && key[:len(prefix)] == prefix { delete(c.entries, key) } } @@ -116,4 +116,4 @@ func (c *QuotaCache) SetTTL(ttl time.Duration) { if ttl > 0 { c.ttl = ttl } -} \ No newline at end of file +} diff --git a/internal/quota/manager.go b/internal/quota/manager.go index ed0db8b0a..b59372a2c 100644 --- a/internal/quota/manager.go +++ b/internal/quota/manager.go @@ -244,8 +244,8 @@ func (m *Manager) FetchAllQuotas(ctx context.Context) (*QuotaResponse, error) { Quotas: make(map[string]map[string]*ProviderQuotaData), } - // Track the earliest LastUpdated from all quota data - var earliestUpdate time.Time + // Track the latest LastUpdated from all quota data + var latestUpdate time.Time // Group auths by provider for _, auth := range auths { @@ -259,19 +259,19 @@ func (m *Manager) FetchAllQuotas(ctx context.Context) (*QuotaResponse, error) { quotaData := m.fetchQuotaForAuth(ctx, auth, false) response.Quotas[provider][accountID] = quotaData - // Track earliest update time from actual data + // Track latest update time from actual data if !quotaData.LastUpdated.IsZero() { - if earliestUpdate.IsZero() || quotaData.LastUpdated.Before(earliestUpdate) { - earliestUpdate = quotaData.LastUpdated + if latestUpdate.IsZero() || quotaData.LastUpdated.After(latestUpdate) { + latestUpdate = quotaData.LastUpdated } } } - // Use earliest data timestamp, or current time if no data - if earliestUpdate.IsZero() { + // Use latest data timestamp, or current time if no data + if latestUpdate.IsZero() { response.LastUpdated = time.Now() } else { - response.LastUpdated = earliestUpdate + response.LastUpdated = latestUpdate } return response, nil @@ -309,8 +309,8 @@ func (m *Manager) FetchProviderQuotas(ctx context.Context, provider string) (*Pr Accounts: make(map[string]*ProviderQuotaData), } - // Track the earliest LastUpdated from all quota data - var earliestUpdate time.Time + // Track the latest LastUpdated from all quota data + var latestUpdate time.Time // Filter and fetch for this provider for _, auth := range auths { @@ -323,19 +323,19 @@ func (m *Manager) FetchProviderQuotas(ctx context.Context, provider string) (*Pr quotaData := m.fetchQuotaForAuth(ctx, auth, false) response.Accounts[accountID] = quotaData - // Track earliest update time from actual data + // Track latest update time from actual data if !quotaData.LastUpdated.IsZero() { - if earliestUpdate.IsZero() || quotaData.LastUpdated.Before(earliestUpdate) { - earliestUpdate = quotaData.LastUpdated + if latestUpdate.IsZero() || quotaData.LastUpdated.After(latestUpdate) { + latestUpdate = quotaData.LastUpdated } } } - // Use earliest data timestamp, or current time if no data - if earliestUpdate.IsZero() { + // Use latest data timestamp, or current time if no data + if latestUpdate.IsZero() { response.LastUpdated = time.Now() } else { - response.LastUpdated = earliestUpdate + response.LastUpdated = latestUpdate } return response, nil From d3fe406a1c83e06bbc3e2434ae20d7471880f138 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 28 Dec 2025 19:04:31 +0800 Subject: [PATCH 09/17] fix(logging): improve request/response capture --- sdk/api/handlers/handlers.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 86ed92767..840fd8940 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -618,7 +618,23 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro } body := BuildErrorResponseBody(status, errText) - c.Set("API_RESPONSE", bytes.Clone(body)) + // Check if this error body was already recorded by the executor (to avoid duplicate logging) + // This can happen when the last retry fails and both executor and handler try to log the same error + shouldAppend := true + if existing, exists := c.Get("API_RESPONSE"); exists { + if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 { + trimmedBody := bytes.TrimSpace(body) + if len(trimmedBody) > 0 && bytes.Contains(existingBytes, trimmedBody) { + // Error already logged by executor, skip appending + shouldAppend = false + } + } + } + if shouldAppend { + // Use appendAPIResponse to preserve any previously captured API response data + // (such as formatted upstream response logs from logging_helpers.go) + appendAPIResponse(c, body) + } if !c.Writer.Written() { c.Writer.Header().Set("Content-Type", "application/json") From cc2725c1095f55c626b462dd351a82ceef229e08 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 28 Dec 2025 19:35:36 +0800 Subject: [PATCH 10/17] fix(handlers): match raw error text before JSON body for duplicate detection --- sdk/api/handlers/handlers.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 840fd8940..f70a9dca6 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -618,15 +618,21 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro } body := BuildErrorResponseBody(status, errText) - // Check if this error body was already recorded by the executor (to avoid duplicate logging) - // This can happen when the last retry fails and both executor and handler try to log the same error + // Check if this error was already recorded by the executor (to avoid duplicate logging) + // The executor logs raw error text, so match errText before falling back to the JSON body. shouldAppend := true if existing, exists := c.Get("API_RESPONSE"); exists { if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 { - trimmedBody := bytes.TrimSpace(body) - if len(trimmedBody) > 0 && bytes.Contains(existingBytes, trimmedBody) { + trimmedErrText := strings.TrimSpace(errText) + if trimmedErrText != "" && bytes.Contains(existingBytes, []byte(trimmedErrText)) { // Error already logged by executor, skip appending shouldAppend = false + } else { + trimmedBody := bytes.TrimSpace(body) + if len(trimmedBody) > 0 && bytes.Contains(existingBytes, trimmedBody) { + // Error already logged by executor, skip appending + shouldAppend = false + } } } } From 5bca7afbb130068c04449ce2b79ba1df7fb6ff1e Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 28 Dec 2025 22:35:36 +0800 Subject: [PATCH 11/17] fix(handlers): preserve upstream response logs before duplicate detection --- sdk/api/handlers/handlers.go | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index f70a9dca6..5a24c63a4 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -618,28 +618,21 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro } body := BuildErrorResponseBody(status, errText) - // Check if this error was already recorded by the executor (to avoid duplicate logging) - // The executor logs raw error text, so match errText before falling back to the JSON body. - shouldAppend := true + // Append first to preserve upstream response logs, then drop duplicate payloads if already recorded. + var previous []byte if existing, exists := c.Get("API_RESPONSE"); exists { if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 { - trimmedErrText := strings.TrimSpace(errText) - if trimmedErrText != "" && bytes.Contains(existingBytes, []byte(trimmedErrText)) { - // Error already logged by executor, skip appending - shouldAppend = false - } else { - trimmedBody := bytes.TrimSpace(body) - if len(trimmedBody) > 0 && bytes.Contains(existingBytes, trimmedBody) { - // Error already logged by executor, skip appending - shouldAppend = false - } - } + previous = bytes.Clone(existingBytes) } } - if shouldAppend { - // Use appendAPIResponse to preserve any previously captured API response data - // (such as formatted upstream response logs from logging_helpers.go) - appendAPIResponse(c, body) + appendAPIResponse(c, body) + trimmedErrText := strings.TrimSpace(errText) + trimmedBody := bytes.TrimSpace(body) + if len(previous) > 0 { + if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) || + (len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) { + c.Set("API_RESPONSE", previous) + } } if !c.Writer.Written() { From 9a2e319a6900100a492bc3f0248b0168a6f423cf Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 29 Dec 2025 08:42:29 +0800 Subject: [PATCH 12/17] chore: add codex, agents, and opencode dirs to ignore files --- .dockerignore | 5 ++++- .gitignore | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.dockerignore b/.dockerignore index 8e4e8b1c0..ef021aea0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -23,11 +23,14 @@ config.yaml # Development/editor bin/* -.claude/* .vscode/* +.claude/* +.codex/* .gemini/* .serena/* .agent/* +.agents/* +.opencode/* .bmad/* _bmad/* _bmad-output/* diff --git a/.gitignore b/.gitignore index 5cfea71ed..183138f96 100644 --- a/.gitignore +++ b/.gitignore @@ -33,10 +33,14 @@ GEMINI.md # Tooling metadata .vscode/* +.codex/* .claude/* .gemini/* .serena/* .agent/* +.agents/* +.agents/* +.opencode/* .bmad/* _bmad/* _bmad-output/* From 8efb91184bfd818b25e52f7d01d987dbfb931f96 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 29 Dec 2025 11:54:26 +0800 Subject: [PATCH 13/17] fix(translators): correct key path for `system_instruction.parts` in Claude request logic --- internal/translator/gemini/claude/gemini_claude_request.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index 4ab6ab970..c410aad80 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -56,7 +56,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction) } } else if systemResult.Type == gjson.String { - out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String()) + out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String()) } // contents From 74bf4a658d62f8b1b66e1cc77e73dbb57a44b1c1 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 29 Dec 2025 12:26:25 +0800 Subject: [PATCH 14/17] feat(amp): add per-client upstream API key mapping support --- config.example.yaml | 13 ++ .../api/handlers/management/config_lists.go | 148 +++++++++++++++++ internal/api/modules/amp/amp.go | 89 +++++++++- internal/api/modules/amp/amp_test.go | 38 +++++ internal/api/modules/amp/proxy.go | 35 ++++ internal/api/modules/amp/proxy_test.go | 157 ++++++++++++++++++ internal/api/modules/amp/routes.go | 39 +++++ internal/api/modules/amp/secret.go | 82 +++++++++ internal/api/modules/amp/secret_test.go | 86 ++++++++++ internal/api/server.go | 4 + internal/config/config.go | 16 ++ internal/util/gemini_schema_test.go | 65 -------- internal/watcher/diff/config_diff.go | 45 +++++ test/amp_management_test.go | 88 ++++++++++ 14 files changed, 836 insertions(+), 69 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index e11580bbd..730a6b3fd 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -35,6 +35,7 @@ auth-dir: "~/.cli-proxy-api" api-keys: - "your-api-key-1" - "your-api-key-2" + - "your-api-key-3" # Enable debug logging debug: false @@ -171,6 +172,18 @@ ws-auth: false # upstream-url: "https://ampcode.com" # # Optional: Override API key for Amp upstream (otherwise uses env or file) # upstream-api-key: "" +# # Per-client upstream API key mapping +# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys. +# # Useful when different clients need to use different Amp accounts/quotas. +# # If a client key isn't mapped, falls back to upstream-api-key (default behavior). +# upstream-api-keys: +# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients +# api-keys: # Client keys that use this upstream key +# - "your-api-key-1" +# - "your-api-key-2" +# - upstream-api-key: "amp_key_for_team_b" +# api-keys: +# - "your-api-key-3" # # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false) # restrict-management-to-localhost: false # # Force model mappings to run before checking local API keys (default: false) diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index cc99ce3a0..e3636fd83 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -940,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) { func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) } + +// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping. +func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}}) + return + } + c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys}) +} + +// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings. +func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) { + var body struct { + Value []config.AmpUpstreamAPIKeyEntry `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + // Normalize entries: trim whitespace, filter empty + normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value) + h.cfg.AmpCode.UpstreamAPIKeys = normalized + h.persist(c) +} + +// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries. +// Matching is done by upstream-api-key value. +func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) { + var body struct { + Value []config.AmpUpstreamAPIKeyEntry `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + + existing := make(map[string]int) + for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys { + existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i + } + + for _, newEntry := range body.Value { + upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey) + if upstreamKey == "" { + continue + } + normalizedEntry := config.AmpUpstreamAPIKeyEntry{ + UpstreamAPIKey: upstreamKey, + APIKeys: normalizeAPIKeysList(newEntry.APIKeys), + } + if idx, ok := existing[upstreamKey]; ok { + h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry + } else { + h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry) + existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1 + } + } + h.persist(c) +} + +// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries. +// Body must be JSON: {"value": ["", ...]}. +// If "value" is an empty array, clears all entries. +// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change. +func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) { + var body struct { + Value []string `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + + if body.Value == nil { + c.JSON(400, gin.H{"error": "missing value"}) + return + } + + // Empty array means clear all + if len(body.Value) == 0 { + h.cfg.AmpCode.UpstreamAPIKeys = nil + h.persist(c) + return + } + + toRemove := make(map[string]bool) + for _, key := range body.Value { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + continue + } + toRemove[trimmed] = true + } + if len(toRemove) == 0 { + c.JSON(400, gin.H{"error": "empty value"}) + return + } + + newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys)) + for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys { + if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] { + newEntries = append(newEntries, entry) + } + } + h.cfg.AmpCode.UpstreamAPIKeys = newEntries + h.persist(c) +} + +// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries. +func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry { + if len(entries) == 0 { + return nil + } + out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries)) + for _, entry := range entries { + upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey) + if upstreamKey == "" { + continue + } + apiKeys := normalizeAPIKeysList(entry.APIKeys) + out = append(out, config.AmpUpstreamAPIKeyEntry{ + UpstreamAPIKey: upstreamKey, + APIKeys: apiKeys, + }) + } + if len(out) == 0 { + return nil + } + return out +} + +// normalizeAPIKeysList trims and filters empty strings from a list of API keys. +func normalizeAPIKeysList(keys []string) []string { + if len(keys) == 0 { + return nil + } + out := make([]string, 0, len(keys)) + for _, k := range keys { + trimmed := strings.TrimSpace(k) + if trimmed != "" { + out = append(out, trimmed) + } + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index 924b34529..b5626ce9c 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { } } - // Check API key change + // Check API key change (both default and per-client mappings) apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings) - if apiKeyChanged { + upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings) + if apiKeyChanged || upstreamAPIKeysChanged { if m.secretSource != nil { - if ms, ok := m.secretSource.(*MultiSourceSecret); ok { + if ms, ok := m.secretSource.(*MappedSecretSource); ok { + if apiKeyChanged { + ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey) + ms.InvalidateCache() + } + if upstreamAPIKeysChanged { + ms.UpdateMappings(newSettings.UpstreamAPIKeys) + } + } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok { ms.UpdateExplicitKey(newSettings.UpstreamAPIKey) ms.InvalidateCache() } @@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error { if m.secretSource == nil { - m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */) + // Create MultiSourceSecret as the default source, then wrap with MappedSecretSource + defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */) + mappedSource := NewMappedSecretSource(defaultSource) + mappedSource.UpdateMappings(settings.UpstreamAPIKeys) + m.secretSource = mappedSource + } else if ms, ok := m.secretSource.(*MappedSecretSource); ok { + ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey) + ms.InvalidateCache() + ms.UpdateMappings(settings.UpstreamAPIKeys) } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok { + // Legacy path: wrap existing MultiSourceSecret with MappedSecretSource ms.UpdateExplicitKey(settings.UpstreamAPIKey) ms.InvalidateCache() + mappedSource := NewMappedSecretSource(ms) + mappedSource.UpdateMappings(settings.UpstreamAPIKeys) + m.secretSource = mappedSource } proxy, err := createReverseProxy(upstreamURL, m.secretSource) @@ -313,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b return oldKey != newKey } +// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings. +func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool { + if old == nil { + return len(new.UpstreamAPIKeys) > 0 + } + + if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) { + return true + } + + // Build map for comparison: upstreamKey -> set of clientKeys + type entryInfo struct { + upstreamKey string + clientKeys map[string]struct{} + } + oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys)) + for i, entry := range old.UpstreamAPIKeys { + clientKeys := make(map[string]struct{}, len(entry.APIKeys)) + for _, k := range entry.APIKeys { + trimmed := strings.TrimSpace(k) + if trimmed == "" { + continue + } + clientKeys[trimmed] = struct{}{} + } + oldEntries[i] = entryInfo{ + upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey), + clientKeys: clientKeys, + } + } + + for i, newEntry := range new.UpstreamAPIKeys { + if i >= len(oldEntries) { + return true + } + oldE := oldEntries[i] + if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey { + return true + } + newKeys := make(map[string]struct{}, len(newEntry.APIKeys)) + for _, k := range newEntry.APIKeys { + trimmed := strings.TrimSpace(k) + if trimmed == "" { + continue + } + newKeys[trimmed] = struct{}{} + } + if len(newKeys) != len(oldE.clientKeys) { + return true + } + for k := range newKeys { + if _, ok := oldE.clientKeys[k]; !ok { + return true + } + } + } + + return false +} + // GetModelMapper returns the model mapper instance (for testing/debugging). func (m *AmpModule) GetModelMapper() *DefaultModelMapper { return m.modelMapper diff --git a/internal/api/modules/amp/amp_test.go b/internal/api/modules/amp/amp_test.go index fcfc3174c..430c4b62a 100644 --- a/internal/api/modules/amp/amp_test.go +++ b/internal/api/modules/amp/amp_test.go @@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) { }) } } + +func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) { + m := &AmpModule{} + + oldCfg := &config.AmpCode{ + UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ + {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}}, + }, + } + newCfg := &config.AmpCode{ + UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ + {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}}, + }, + } + + if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) { + t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates") + } +} + +func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) { + m := &AmpModule{} + + oldCfg := &config.AmpCode{ + UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ + {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}}, + }, + } + newCfg := &config.AmpCode{ + UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{ + {UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}}, + }, + } + + if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) { + t.Fatal("expected no change when only whitespace/empty entries differ") + } +} diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 3c4ef3082..c460a0d60 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -15,6 +15,33 @@ import ( log "github.com/sirupsen/logrus" ) +func removeQueryValuesMatching(req *http.Request, key string, match string) { + if req == nil || req.URL == nil || match == "" { + return + } + + q := req.URL.Query() + values, ok := q[key] + if !ok || len(values) == 0 { + return + } + + kept := make([]string, 0, len(values)) + for _, v := range values { + if v == match { + continue + } + kept = append(kept, v) + } + + if len(kept) == 0 { + q.Del(key) + } else { + q[key] = kept + } + req.URL.RawQuery = q.Encode() +} + // readCloser wraps a reader and forwards Close to a separate closer. // Used to restore peeked bytes while preserving upstream body Close behavior. type readCloser struct { @@ -45,6 +72,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi // We will set our own Authorization using the configured upstream-api-key req.Header.Del("Authorization") req.Header.Del("X-Api-Key") + req.Header.Del("X-Goog-Api-Key") + + // Remove query-based credentials if they match the authenticated client API key. + // This prevents leaking client auth material to the Amp upstream while avoiding + // breaking unrelated upstream query parameters. + clientKey := getClientAPIKeyFromContext(req.Context()) + removeQueryValuesMatching(req, "key", clientKey) + removeQueryValuesMatching(req, "auth_token", clientKey) // Preserve correlation headers for debugging if req.Header.Get("X-Request-ID") == "" { diff --git a/internal/api/modules/amp/proxy_test.go b/internal/api/modules/amp/proxy_test.go index 95edc12d2..ff23e3986 100644 --- a/internal/api/modules/amp/proxy_test.go +++ b/internal/api/modules/amp/proxy_test.go @@ -3,11 +3,15 @@ package amp import ( "bytes" "compress/gzip" + "context" "fmt" "io" "net/http" "net/http/httptest" + "strings" "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" ) // Helper: compress data with gzip @@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) { } } +func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) { + type captured struct { + headers http.Header + query string + } + got := make(chan captured, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery} + w.WriteHeader(200) + w.Write([]byte(`ok`)) + })) + defer upstream.Close() + + proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream")) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate clientAPIKeyMiddleware injection (per-request) + ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key") + proxy.ServeHTTP(w, r.WithContext(ctx)) + })) + defer srv.Close() + + req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer client-key") + req.Header.Set("X-Api-Key", "client-key") + req.Header.Set("X-Goog-Api-Key", "client-key") + + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + c := <-got + + // These are client-provided credentials and must not reach the upstream. + if v := c.headers.Get("X-Goog-Api-Key"); v != "" { + t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v) + } + + // We inject upstream Authorization/X-Api-Key, so the client auth must not survive. + if v := c.headers.Get("Authorization"); v != "Bearer upstream" { + t.Fatalf("Authorization should be upstream-injected, got: %q", v) + } + if v := c.headers.Get("X-Api-Key"); v != "upstream" { + t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v) + } + + // Query-based credentials should be stripped only when they match the authenticated client key. + // Should keep unrelated values and parameters. + if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") { + t.Fatalf("query credentials should be stripped, got raw query: %q", c.query) + } + if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") { + t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query) + } +} + +func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) { + gotHeaders := make(chan http.Header, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders <- r.Header.Clone() + w.WriteHeader(200) + w.Write([]byte(`ok`)) + })) + defer upstream.Close() + + defaultSource := NewStaticSecretSource("default") + mapped := NewMappedSecretSource(defaultSource) + mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ + { + UpstreamAPIKey: "u1", + APIKeys: []string{"k1"}, + }, + }) + + proxy, err := createReverseProxy(upstream.URL, mapped) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate clientAPIKeyMiddleware injection (per-request) + ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1") + proxy.ServeHTTP(w, r.WithContext(ctx)) + })) + defer srv.Close() + + res, err := http.Get(srv.URL + "/test") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + hdr := <-gotHeaders + if hdr.Get("X-Api-Key") != "u1" { + t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key")) + } + if hdr.Get("Authorization") != "Bearer u1" { + t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization")) + } +} + +func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) { + gotHeaders := make(chan http.Header, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders <- r.Header.Clone() + w.WriteHeader(200) + w.Write([]byte(`ok`)) + })) + defer upstream.Close() + + defaultSource := NewStaticSecretSource("default") + mapped := NewMappedSecretSource(defaultSource) + mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ + { + UpstreamAPIKey: "u1", + APIKeys: []string{"k1"}, + }, + }) + + proxy, err := createReverseProxy(upstream.URL, mapped) + if err != nil { + t.Fatal(err) + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2") + proxy.ServeHTTP(w, r.WithContext(ctx)) + })) + defer srv.Close() + + res, err := http.Get(srv.URL + "/test") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + hdr := <-gotHeaders + if hdr.Get("X-Api-Key") != "default" { + t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key")) + } + if hdr.Get("Authorization") != "Bearer default" { + t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization")) + } +} + func TestReverseProxy_ErrorHandler(t *testing.T) { // Point proxy to a non-routable address to trigger error proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource("")) diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index a37c0a155..456a50ac1 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -1,6 +1,7 @@ package amp import ( + "context" "errors" "net" "net/http" @@ -16,6 +17,37 @@ import ( log "github.com/sirupsen/logrus" ) +// clientAPIKeyContextKey is the context key used to pass the client API key +// from gin.Context to the request context for SecretSource lookup. +type clientAPIKeyContextKey struct{} + +// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"] +// into the request context so that SecretSource can look it up for per-client upstream routing. +func clientAPIKeyMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Extract the client API key from gin context (set by AuthMiddleware) + if apiKey, exists := c.Get("apiKey"); exists { + if keyStr, ok := apiKey.(string); ok && keyStr != "" { + // Inject into request context for SecretSource.Get(ctx) to read + ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr) + c.Request = c.Request.WithContext(ctx) + } + } + c.Next() + } +} + +// getClientAPIKeyFromContext retrieves the client API key from request context. +// Returns empty string if not present. +func getClientAPIKeyFromContext(ctx context.Context) string { + if val := ctx.Value(clientAPIKeyContextKey{}); val != nil { + if keyStr, ok := val.(string); ok { + return keyStr + } + } + return "" +} + // localhostOnlyMiddleware returns a middleware that dynamically checks the module's // localhost restriction setting. This allows hot-reload of the restriction without restarting. func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc { @@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings") } + // Inject client API key into request context for per-client upstream routing + ampAPI.Use(clientAPIKeyMiddleware()) + // Dynamic proxy handler that uses m.getProxy() for hot-reload support proxyHandler := func(c *gin.Context) { // Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces @@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha if authWithBypass != nil { rootMiddleware = append(rootMiddleware, authWithBypass) } + // Add clientAPIKeyMiddleware after auth for per-client upstream routing + rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware()) engine.GET("/threads", append(rootMiddleware, proxyHandler)...) engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/docs", append(rootMiddleware, proxyHandler)...) @@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han if auth != nil { ampProviders.Use(auth) } + // Inject client API key into request context for per-client upstream routing + ampProviders.Use(clientAPIKeyMiddleware()) provider := ampProviders.Group("/:provider") diff --git a/internal/api/modules/amp/secret.go b/internal/api/modules/amp/secret.go index a7ebf3cb9..f91c72ba9 100644 --- a/internal/api/modules/amp/secret.go +++ b/internal/api/modules/amp/secret.go @@ -9,6 +9,9 @@ import ( "strings" "sync" "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + log "github.com/sirupsen/logrus" ) // SecretSource provides Amp API keys with configurable precedence and caching @@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource { func (s *StaticSecretSource) Get(ctx context.Context) (string, error) { return s.key, nil } + +// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping. +// When a request context contains a client API key that matches a configured mapping, +// the corresponding upstream key is returned. Otherwise, falls back to the default source. +type MappedSecretSource struct { + defaultSource SecretSource + mu sync.RWMutex + lookup map[string]string // clientKey -> upstreamKey +} + +// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source. +func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource { + return &MappedSecretSource{ + defaultSource: defaultSource, + lookup: make(map[string]string), + } +} + +// Get retrieves the Amp API key, checking per-client mappings first. +// If the request context contains a client API key that matches a configured mapping, +// returns the corresponding upstream key. Otherwise, falls back to the default source. +func (s *MappedSecretSource) Get(ctx context.Context) (string, error) { + // Try to get client API key from request context + clientKey := getClientAPIKeyFromContext(ctx) + if clientKey != "" { + s.mu.RLock() + if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" { + s.mu.RUnlock() + return upstreamKey, nil + } + s.mu.RUnlock() + } + + // Fall back to default source + return s.defaultSource.Get(ctx) +} + +// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries. +// If the same client key appears in multiple entries, logs a warning and uses the first one. +func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) { + newLookup := make(map[string]string) + + for _, entry := range entries { + upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey) + if upstreamKey == "" { + continue + } + for _, clientKey := range entry.APIKeys { + trimmedKey := strings.TrimSpace(clientKey) + if trimmedKey == "" { + continue + } + if _, exists := newLookup[trimmedKey]; exists { + // Log warning for duplicate client key, first one wins + log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.") + continue + } + newLookup[trimmedKey] = upstreamKey + } + } + + s.mu.Lock() + s.lookup = newLookup + s.mu.Unlock() +} + +// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable). +func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) { + if ms, ok := s.defaultSource.(*MultiSourceSecret); ok { + ms.UpdateExplicitKey(key) + } +} + +// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable). +func (s *MappedSecretSource) InvalidateCache() { + if ms, ok := s.defaultSource.(*MultiSourceSecret); ok { + ms.InvalidateCache() + } +} diff --git a/internal/api/modules/amp/secret_test.go b/internal/api/modules/amp/secret_test.go index 9c3e820a1..6a6f6ba26 100644 --- a/internal/api/modules/amp/secret_test.go +++ b/internal/api/modules/amp/secret_test.go @@ -8,6 +8,10 @@ import ( "sync" "testing" "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" ) func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) { @@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) { t.Fatalf("after cache expiry, expected new-value, got %q", got3) } } + +func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) { + defaultSource := NewStaticSecretSource("default") + s := NewMappedSecretSource(defaultSource) + s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ + { + UpstreamAPIKey: "u1", + APIKeys: []string{"k1"}, + }, + }) + + ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") + got, err := s.Get(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "u1" { + t.Fatalf("want u1, got %q", got) + } + + ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2") + got, err = s.Get(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "default" { + t.Fatalf("want default fallback, got %q", got) + } +} + +func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) { + defaultSource := NewStaticSecretSource("default") + s := NewMappedSecretSource(defaultSource) + s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ + { + UpstreamAPIKey: "u1", + APIKeys: []string{"k1"}, + }, + { + UpstreamAPIKey: "u2", + APIKeys: []string{"k1"}, + }, + }) + + ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") + got, err := s.Get(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "u1" { + t.Fatalf("want u1 (first wins), got %q", got) + } +} + +func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) { + hook := test.NewLocal(log.StandardLogger()) + defer hook.Reset() + + defaultSource := NewStaticSecretSource("default") + s := NewMappedSecretSource(defaultSource) + s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ + { + UpstreamAPIKey: "u1", + APIKeys: []string{"k1"}, + }, + { + UpstreamAPIKey: "u2", + APIKeys: []string{"k1"}, + }, + }) + + foundWarning := false + for _, entry := range hook.AllEntries() { + if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." { + foundWarning = true + break + } + } + if !foundWarning { + t.Fatal("expected warning log for duplicate client key, but none was found") + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 99623e764..940c98bcd 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -551,6 +551,10 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) + mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys) + mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys) + mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys) + mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/config/config.go b/internal/config/config.go index b92655a3b..e76a22d59 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -169,6 +169,11 @@ type AmpCode struct { // UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls. UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"` + // UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys. + // When a client authenticates with a key that matches an entry, that upstream key is used. + // If no match is found, falls back to UpstreamAPIKey (default behavior). + UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"` + // RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.) // to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by // browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient). @@ -184,6 +189,17 @@ type AmpCode struct { ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"` } +// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key. +// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey +// is used for the upstream Amp request. +type AmpUpstreamAPIKeyEntry struct { + // UpstreamAPIKey is the API key to use when proxying to the Amp upstream. + UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"` + + // APIKeys are the client API keys (from top-level api-keys) that map to this upstream key. + APIKeys []string `yaml:"api-keys" json:"api-keys"` +} + // PayloadConfig defines default and override parameter rules applied to provider payloads. type PayloadConfig struct { // Default defines rules that only set parameters when they are missing in the payload. diff --git a/internal/util/gemini_schema_test.go b/internal/util/gemini_schema_test.go index 01c8f12fd..69adbcdb7 100644 --- a/internal/util/gemini_schema_test.go +++ b/internal/util/gemini_schema_test.go @@ -614,71 +614,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) { - // propertyNames is used to validate object property names (e.g., must match a pattern) - // Gemini doesn't support this keyword and will reject requests containing it - input := `{ - "type": "object", - "properties": { - "metadata": { - "type": "object", - "propertyNames": { - "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$" - }, - "additionalProperties": { - "type": "string" - } - } - } - }` - - expected := `{ - "type": "object", - "properties": { - "metadata": { - "type": "object" - } - } - }` - - result := CleanJSONSchemaForGemini(input) - compareJSON(t, expected, result) - - // Verify propertyNames is completely removed - if strings.Contains(result, "propertyNames") { - t.Errorf("propertyNames keyword should be removed, got: %s", result) - } -} - -func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) { - // Test deeply nested propertyNames (as seen in real Claude tool schemas) - input := `{ - "type": "object", - "properties": { - "items": { - "type": "array", - "items": { - "type": "object", - "properties": { - "config": { - "type": "object", - "propertyNames": { - "type": "string" - } - } - } - } - } - } - }` - - result := CleanJSONSchemaForGemini(input) - - if strings.Contains(result, "propertyNames") { - t.Errorf("Nested propertyNames should be removed, got: %s", result) - } -} - func compareJSON(t *testing.T, expectedJSON, actualJSON string) { var expMap, actMap map[string]interface{} errExp := json.Unmarshal([]byte(expectedJSON), &expMap) diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index ecc15b391..1ce601516 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -185,6 +185,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings { changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings)) } + oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys) + newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys) + if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) { + changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount)) + } if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { changes = append(changes, entries...) @@ -301,3 +306,43 @@ func formatProxyURL(raw string) string { } return scheme + "://" + host } + +func equalStringSet(a, b []string) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + aSet := make(map[string]struct{}, len(a)) + for _, k := range a { + aSet[strings.TrimSpace(k)] = struct{}{} + } + bSet := make(map[string]struct{}, len(b)) + for _, k := range b { + bSet[strings.TrimSpace(k)] = struct{}{} + } + if len(aSet) != len(bSet) { + return false + } + for k := range aSet { + if _, ok := bSet[k]; !ok { + return false + } + } + return true +} + +// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality. +// Comparison is done by count and content (upstream key and client keys). +func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) { + return false + } + if !equalStringSet(a[i].APIKeys, b[i].APIKeys) { + return false + } + } + return true +} diff --git a/test/amp_management_test.go b/test/amp_management_test.go index 19450dbff..e384ef0e8 100644 --- a/test/amp_management_test.go +++ b/test/amp_management_test.go @@ -56,6 +56,10 @@ func setupAmpRouter(h *management.Handler) *gin.Engine { mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey) mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey) mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey) + mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys) + mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys) + mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys) + mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys) mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost) mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost) mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings) @@ -188,6 +192,90 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) { } } +func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) { + h, configPath := newAmpTestHandler(t) + r := setupAmpRouter(h) + + body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } + + // Verify it was persisted to disk + loaded, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("failed to load config from disk: %v", err) + } + if len(loaded.AmpCode.UpstreamAPIKeys) != 1 { + t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys)) + } + entry := loaded.AmpCode.UpstreamAPIKeys[0] + if entry.UpstreamAPIKey != "u1" { + t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey) + } + if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" { + t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys) + } + + // Verify it is returned by GET /ampcode + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + var resp map[string]config.AmpCode + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" { + t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got) + } +} + +func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + // Seed with one entry + putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}` + req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String()) + } + + deleteBody := `{"value":[]}` + req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + var resp map[string][]config.AmpUpstreamAPIKeyEntry + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if resp["upstream-api-keys"] != nil && len(resp["upstream-api-keys"]) != 0 { + t.Fatalf("expected cleared list, got %#v", resp["upstream-api-keys"]) + } +} + // TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key. func TestDeleteAmpUpstreamAPIKey(t *testing.T) { h, _ := newAmpTestHandler(t) From 6b2550de7ed9a062a4a9cf2c6395b34c6242122d Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 29 Dec 2025 16:34:16 +0800 Subject: [PATCH 15/17] feat(api): add id token claims extraction for codex auth entries --- .../api/handlers/management/auth_files.go | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 41a4fde40..d98c759ce 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -427,9 +427,59 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H { log.WithError(err).Warnf("failed to stat auth file %s", path) } } + if claims := extractCodexIDTokenClaims(auth); claims != nil { + entry["id_token"] = claims + } return entry } +func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H { + if auth == nil || auth.Metadata == nil { + return nil + } + if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") { + return nil + } + idTokenRaw, ok := auth.Metadata["id_token"].(string) + if !ok { + return nil + } + idToken := strings.TrimSpace(idTokenRaw) + if idToken == "" { + return nil + } + claims, err := codex.ParseJWTToken(idToken) + if err != nil || claims == nil { + return nil + } + + result := gin.H{} + if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" { + result["plan_type"] = v + } + if v := strings.TrimSpace(claims.CodexAuthInfo.UserID); v != "" { + result["user_id"] = v + } + + if len(claims.CodexAuthInfo.Organizations) > 0 { + orgs := make([]gin.H, 0, len(claims.CodexAuthInfo.Organizations)) + for _, org := range claims.CodexAuthInfo.Organizations { + orgs = append(orgs, gin.H{ + "id": strings.TrimSpace(org.ID), + "title": strings.TrimSpace(org.Title), + "role": strings.TrimSpace(org.Role), + "is_default": org.IsDefault, + }) + } + result["organizations"] = orgs + } + + if len(result) == 0 { + return nil + } + return result +} + func authEmail(auth *coreauth.Auth) string { if auth == nil { return "" From 42f573182b113e2fcfcad59134bf50ef553c4074 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 29 Dec 2025 19:48:02 +0800 Subject: [PATCH 16/17] refactor(api): simplify codex id token claims extraction --- .../api/handlers/management/auth_files.go | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index d98c759ce..e0904ab6e 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -454,25 +454,12 @@ func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H { } result := gin.H{} + if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" { + result["chatgpt_account_id"] = v + } if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" { result["plan_type"] = v } - if v := strings.TrimSpace(claims.CodexAuthInfo.UserID); v != "" { - result["user_id"] = v - } - - if len(claims.CodexAuthInfo.Organizations) > 0 { - orgs := make([]gin.H, 0, len(claims.CodexAuthInfo.Organizations)) - for _, org := range claims.CodexAuthInfo.Organizations { - orgs = append(orgs, gin.H{ - "id": strings.TrimSpace(org.ID), - "title": strings.TrimSpace(org.Title), - "role": strings.TrimSpace(org.Role), - "is_default": org.IsDefault, - }) - } - result["organizations"] = orgs - } if len(result) == 0 { return nil From 6c9b01968033d454d9a3d4ecb18c8326581ffed6 Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Mon, 29 Dec 2025 23:55:59 +0800 Subject: [PATCH 17/17] fix(antigravity): inject required placeholder when properties exist without required --- internal/util/gemini_schema.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/internal/util/gemini_schema.go b/internal/util/gemini_schema.go index 2daf0a79b..33df61f91 100644 --- a/internal/util/gemini_schema.go +++ b/internal/util/gemini_schema.go @@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string { } // addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas. -// Claude VALIDATED mode requires at least one property in tool schemas. +// Claude VALIDATED mode requires at least one required property in tool schemas. func addEmptySchemaPlaceholder(jsonStr string) string { // Find all "type" fields paths := findPaths(jsonStr, "type") @@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string { // Check if properties exists and is empty or missing propsPath := joinPath(parentPath, "properties") propsVal := gjson.Get(jsonStr, propsPath) + reqPath := joinPath(parentPath, "required") + reqVal := gjson.Get(jsonStr, reqPath) + hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0 needsPlaceholder := false if !propsVal.Exists() { @@ -381,8 +384,17 @@ func addEmptySchemaPlaceholder(jsonStr string) string { jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool") // Add to required array - reqPath := joinPath(parentPath, "required") jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"}) + continue + } + + // If schema has properties but none are required, add a minimal placeholder. + if propsVal.IsObject() && !hasRequiredProperties { + placeholderPath := joinPath(propsPath, "_") + if !gjson.Get(jsonStr, placeholderPath).Exists() { + jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean") + } + jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"}) } }