diff --git a/.gitignore b/.gitignore index ef2d935ae..9e730c987 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ pgstore/* gitstore/* objectstore/* static/* +refs/* # Authentication data auths/* @@ -30,3 +31,7 @@ GEMINI.md .vscode/* .claude/* .serena/* + +# macOS +.DS_Store +._* diff --git a/config.example.yaml b/config.example.yaml index 9dfca5bca..3086ca783 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -55,6 +55,28 @@ quota-exceeded: # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false +# Amp CLI Integration +# Configure upstream URL for Amp CLI OAuth and management features +#amp-upstream-url: "https://ampcode.com" + +# Optional: Override API key for Amp upstream (otherwise uses env or file) +#amp-upstream-api-key: "" + +# Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended) +#amp-restrict-management-to-localhost: true + +# Amp Model Mappings +# Route unavailable Amp models to alternative models available in your local proxy. +# Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) +# but you have a similar model available (e.g., Claude Sonnet 4). +#amp-model-mappings: +# - from: "claude-opus-4.5" # Model requested by Amp CLI +# to: "claude-sonnet-4" # Route to this available model instead +# - from: "gpt-5" +# to: "gemini-2.5-pro" +# - from: "claude-3-opus-20240229" +# to: "claude-3-5-sonnet-20241022" + # Gemini API keys (preferred) #gemini-api-key: # - api-key: "AIzaSy...01" @@ -123,6 +145,19 @@ ws-auth: false # - name: "moonshotai/kimi-k2:free" # The actual model name. # alias: "kimi-k2" # The alias used in the API. +# Vertex API keys (Vertex-compatible endpoints, use API key + base URL) +#vertex-api-key: +# - api-key: "vk-123..." # x-goog-api-key header +# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api +# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override +# headers: +# X-Custom-Header: "custom-value" +# models: # optional: map aliases to upstream model names +# - name: "gemini-2.0-flash" # upstream model name +# alias: "vertex-flash" # client-visible alias +# - name: "gemini-1.5-pro" +# alias: "vertex-pro" + #payload: # Optional payload configuration # default: # Default rules only set parameters when they are missing in the payload. # - models: diff --git a/docs/amp-cli-integration.md b/docs/amp-cli-integration.md index 41cea66a2..feac51fc4 100644 --- a/docs/amp-cli-integration.md +++ b/docs/amp-cli-integration.md @@ -8,6 +8,7 @@ This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, - [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate) - [Architecture](#architecture) - [Configuration](#configuration) + - [Model Mapping Configuration](#model-mapping-configuration) - [Setup](#setup) - [Usage](#usage) - [Troubleshooting](#troubleshooting) @@ -21,6 +22,7 @@ The Amp CLI integration adds specialized routing to support Amp's API patterns w - **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers - **Management proxy**: Forwards OAuth and account management requests to Amp's control plane - **Smart fallback**: Automatically routes unconfigured models to ampcode.com +- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`) - **Secret management**: Configurable precedence (config > env > file) with 5-minute caching - **Security-first**: Management routes restricted to localhost by default - **Automatic gzip handling**: Decompresses responses from Amp upstream @@ -75,7 +77,10 @@ Amp CLI/IDE │ ↓ │ ├─ Model configured locally? │ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers) - │ │ NO → Forward to ampcode.com (reverse proxy) + │ │ NO ↓ + │ │ ├─ Model mapping configured? + │ │ │ YES → Rewrite model → Use local handler (free) + │ │ │ NO → Forward to ampcode.com (uses Amp credits) │ ↓ │ Response │ @@ -115,6 +120,49 @@ amp-upstream-url: "https://ampcode.com" amp-restrict-management-to-localhost: true ``` +### Model Mapping Configuration + +When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally. + +```yaml +# Route unavailable models to alternatives +amp-model-mappings: + # Example: Route Claude Opus 4.5 requests to Claude Sonnet 4 + - from: "claude-opus-4.5" + to: "claude-sonnet-4" + + # Example: Route GPT-5 requests to Gemini 2.5 Pro + - from: "gpt-5" + to: "gemini-2.5-pro" + + # Example: Map older model names to newer versions + - from: "claude-3-opus-20240229" + to: "claude-3-5-sonnet-20241022" +``` + +**How it works:** + +1. Amp CLI requests a model (e.g., `claude-opus-4.5`) +2. CLIProxyAPI checks if a local provider is available for that model +3. If not available, it checks the model mappings +4. If a mapping exists, the request is rewritten to use the target model +5. The request is then handled locally (free, using your OAuth subscription) + +**Benefits:** +- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com +- **Hot-reload**: Mappings can be updated without restarting the proxy +- **Structured logging**: Clear logs show when mappings are applied + +**Routing Decision Logs:** + +The proxy logs each routing decision with structured fields: + +``` +[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free) +[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free) +[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits) +``` + ### Secret Resolution Precedence The Amp module resolves API keys using this precedence order: @@ -301,11 +349,14 @@ When Amp requests a model: 1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider? 2. **If YES**: Route to local handler (use your OAuth subscription) -3. **If NO**: Forward to ampcode.com (use Amp's default routing) +3. **If NO**: Check if a model mapping exists +4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free) +5. **If no mapping**: Forward to ampcode.com (uses Amp credits) This enables seamless mixed usage: - Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions -- Models you haven't configured → Amp's default providers +- Models with mappings configured → Routed to alternative local models (free) +- Models you haven't configured and have no mapping → Amp's default providers (uses credits) ### Example API Calls diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index 0086d1798..281fda656 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -23,11 +23,13 @@ type Option func(*AmpModule) // - Reverse proxy to Amp control plane for OAuth/management // - Provider-specific route aliases (/api/provider/{provider}/...) // - Automatic gzip decompression for misconfigured upstreams +// - Model mapping for routing unavailable models to alternatives type AmpModule struct { secretSource SecretSource proxy *httputil.ReverseProxy accessManager *sdkaccess.Manager authMiddleware_ gin.HandlerFunc + modelMapper *DefaultModelMapper enabled bool registerOnce sync.Once } @@ -101,6 +103,9 @@ func (m *AmpModule) Register(ctx modules.Context) error { // Use registerOnce to ensure routes are only registered once var regErr error m.registerOnce.Do(func() { + // Initialize model mapper from config (for routing unavailable models to alternatives) + m.modelMapper = NewModelMapper(ctx.Config.AmpModelMappings) + // Always register provider aliases - these work without an upstream m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) @@ -159,8 +164,16 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { // OnConfigUpdated handles configuration updates. // Currently requires restart for URL changes (could be enhanced for dynamic updates). func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { + // Update model mappings (hot-reload supported) + if m.modelMapper != nil { + log.Infof("amp config updated: reloading %d model mapping(s)", len(cfg.AmpModelMappings)) + m.modelMapper.UpdateMappings(cfg.AmpModelMappings) + } else { + log.Warnf("amp model mapper not initialized, skipping model mapping update") + } + if !m.enabled { - log.Debug("Amp routing not enabled, skipping config update") + log.Debug("Amp routing not enabled, skipping other config updates") return nil } @@ -181,3 +194,8 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { log.Debug("Amp config updated (restart required for URL changes)") return nil } + +// GetModelMapper returns the model mapper instance (for testing/debugging). +func (m *AmpModule) GetModelMapper() *DefaultModelMapper { + return m.modelMapper +} diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index e7b28986d..17c60708c 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -6,16 +6,75 @@ import ( "io" "net/http/httputil" "strings" + "time" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) +// AmpRouteType represents the type of routing decision made for an Amp request +type AmpRouteType string + +const ( + // RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free) + RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER" + // RouteTypeModelMapping indicates the request was remapped to another available model (free) + RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING" + // RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits) + RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS" + // RouteTypeNoProvider indicates no provider or fallback available + RouteTypeNoProvider AmpRouteType = "NO_PROVIDER" +) + +// logAmpRouting logs the routing decision for an Amp request with structured fields +func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { + fields := log.Fields{ + "component": "amp-routing", + "route_type": string(routeType), + "requested_model": requestedModel, + "path": path, + "timestamp": time.Now().Format(time.RFC3339), + } + + if resolvedModel != "" && resolvedModel != requestedModel { + fields["resolved_model"] = resolvedModel + } + if provider != "" { + fields["provider"] = provider + } + + switch routeType { + case RouteTypeLocalProvider: + fields["cost"] = "free" + fields["source"] = "local_oauth" + log.WithFields(fields).Infof("[AMP] Using local provider for model: %s", requestedModel) + + case RouteTypeModelMapping: + fields["cost"] = "free" + fields["source"] = "local_oauth" + fields["mapping"] = requestedModel + " -> " + resolvedModel + log.WithFields(fields).Infof("[AMP] Model mapped: %s -> %s", requestedModel, resolvedModel) + + case RouteTypeAmpCredits: + fields["cost"] = "amp_credits" + fields["source"] = "ampcode.com" + fields["model_id"] = requestedModel // Explicit model_id for easy config reference + log.WithFields(fields).Warnf("[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"\"}]", requestedModel, requestedModel) + + case RouteTypeNoProvider: + fields["cost"] = "none" + fields["source"] = "error" + fields["model_id"] = requestedModel // Explicit model_id for easy config reference + log.WithFields(fields).Warnf("[AMP] No provider available for model_id: %s", requestedModel) + } +} + // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy + getProxy func() *httputil.ReverseProxy + modelMapper ModelMapper } // NewFallbackHandler creates a new fallback handler wrapper @@ -26,10 +85,25 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler } } +// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support +func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler { + return &FallbackHandler{ + getProxy: getProxy, + modelMapper: mapper, + } +} + +// SetModelMapper sets the model mapper for this handler (allows late binding) +func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { + fh.modelMapper = mapper +} + // WrapHandler wraps a gin.HandlerFunc with fallback logic // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { + requestPath := c.Request.URL.Path + // Read the request body to extract the model name bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { @@ -55,12 +129,33 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Check if we have providers for this model providers := util.GetProviderName(normalizedModel) + // Track resolved model for logging (may change if mapping is applied) + resolvedModel := normalizedModel + usedMapping := false + if len(providers) == 0 { - // No providers configured - check if we have a proxy for fallback + // No providers configured - check if we have a model mapping + if fh.modelMapper != nil { + if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { + // Mapping found - rewrite the model in request body + bodyBytes = rewriteModelInBody(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + resolvedModel = mappedModel + usedMapping = true + + // Get providers for the mapped model + providers = util.GetProviderName(mappedModel) + + // Continue to handler with remapped model + goto handleRequest + } + } + + // No mapping found - check if we have a proxy for fallback proxy := fh.getProxy() if proxy != nil { - // Fallback to ampcode.com - log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName) + // Log: Forwarding to ampcode.com (uses Amp credits) + logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath) // Restore body again for the proxy c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) @@ -71,7 +166,23 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } // No proxy available, let the normal handler return the error - log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName) + logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) + } + + handleRequest: + + // Log the routing decision + providerName := "" + if len(providers) > 0 { + providerName = providers[0] + } + + if usedMapping { + // Log: Model was mapped to another model + logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) + } else if len(providers) > 0 { + // Log: Using local provider (free) + logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) } // Providers available or no proxy for fallback, restore body and use normal handler @@ -91,6 +202,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } } +// rewriteModelInBody replaces the model name in a JSON request body +func rewriteModelInBody(body []byte, newModel string) []byte { + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err) + return body + } + + if _, exists := payload["model"]; exists { + payload["model"] = newModel + newBody, err := json.Marshal(payload) + if err != nil { + log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err) + return body + } + return newBody + } + + return body +} + // extractModelFromRequest attempts to extract the model name from various request formats func extractModelFromRequest(body []byte, c *gin.Context) string { // First try to parse from JSON body (OpenAI, Claude, etc.) diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go new file mode 100644 index 000000000..c07f41c40 --- /dev/null +++ b/internal/api/modules/amp/model_mapping.go @@ -0,0 +1,113 @@ +// Package amp provides model mapping functionality for routing Amp CLI requests +// to alternative models when the requested model is not available locally. +package amp + +import ( + "strings" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +// ModelMapper provides model name mapping/aliasing for Amp CLI requests. +// When an Amp request comes in for a model that isn't available locally, +// this mapper can redirect it to an alternative model that IS available. +type ModelMapper interface { + // MapModel returns the target model name if a mapping exists and the target + // model has available providers. Returns empty string if no mapping applies. + MapModel(requestedModel string) string + + // UpdateMappings refreshes the mapping configuration (for hot-reload). + UpdateMappings(mappings []config.AmpModelMapping) +} + +// DefaultModelMapper implements ModelMapper with thread-safe mapping storage. +type DefaultModelMapper struct { + mu sync.RWMutex + mappings map[string]string // from -> to (normalized lowercase keys) +} + +// NewModelMapper creates a new model mapper with the given initial mappings. +func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { + m := &DefaultModelMapper{ + mappings: make(map[string]string), + } + m.UpdateMappings(mappings) + return m +} + +// MapModel checks if a mapping exists for the requested model and if the +// target model has available local providers. Returns the mapped model name +// or empty string if no valid mapping exists. +func (m *DefaultModelMapper) MapModel(requestedModel string) string { + if requestedModel == "" { + return "" + } + + m.mu.RLock() + defer m.mu.RUnlock() + + // Normalize the requested model for lookup + normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel)) + + // Check for direct mapping + targetModel, exists := m.mappings[normalizedRequest] + if !exists { + return "" + } + + // Verify target model has available providers + providers := util.GetProviderName(targetModel) + if len(providers) == 0 { + log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) + return "" + } + + // Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go + log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel) + return targetModel +} + +// UpdateMappings refreshes the mapping configuration from config. +// This is called during initialization and on config hot-reload. +func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { + m.mu.Lock() + defer m.mu.Unlock() + + // Clear and rebuild mappings + m.mappings = make(map[string]string, len(mappings)) + + for _, mapping := range mappings { + from := strings.TrimSpace(mapping.From) + to := strings.TrimSpace(mapping.To) + + if from == "" || to == "" { + log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to) + continue + } + + // Store with normalized lowercase key for case-insensitive lookup + normalizedFrom := strings.ToLower(from) + m.mappings[normalizedFrom] = to + + log.Debugf("amp model mapping registered: %s -> %s", from, to) + } + + if len(m.mappings) > 0 { + log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) + } +} + +// GetMappings returns a copy of current mappings (for debugging/status). +func (m *DefaultModelMapper) GetMappings() map[string]string { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]string, len(m.mappings)) + for k, v := range m.mappings { + result[k] = v + } + return result +} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go new file mode 100644 index 000000000..c11d61bd7 --- /dev/null +++ b/internal/api/modules/amp/model_mapping_test.go @@ -0,0 +1,186 @@ +package amp + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func TestNewModelMapper(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + {From: "gpt-5", To: "gemini-2.5-pro"}, + } + + mapper := NewModelMapper(mappings) + if mapper == nil { + t.Fatal("Expected non-nil mapper") + } + + result := mapper.GetMappings() + if len(result) != 2 { + t.Errorf("Expected 2 mappings, got %d", len(result)) + } +} + +func TestNewModelMapper_Empty(t *testing.T) { + mapper := NewModelMapper(nil) + if mapper == nil { + t.Fatal("Expected non-nil mapper") + } + + result := mapper.GetMappings() + if len(result) != 0 { + t.Errorf("Expected 0 mappings, got %d", len(result)) + } +} + +func TestModelMapper_MapModel_NoProvider(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Without a registered provider for the target, mapping should return empty + result := mapper.MapModel("claude-opus-4.5") + if result != "" { + t.Errorf("Expected empty result when target has no provider, got %s", result) + } +} + +func TestModelMapper_MapModel_WithProvider(t *testing.T) { + // Register a mock provider for the target model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client") + + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // With a registered provider, mapping should work + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client2") + + mappings := []config.AmpModelMapping{ + {From: "Claude-Opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Should match case-insensitively + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_MapModel_NotFound(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Unknown model should return empty + result := mapper.MapModel("unknown-model") + if result != "" { + t.Errorf("Expected empty for unknown model, got %s", result) + } +} + +func TestModelMapper_MapModel_EmptyInput(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("") + if result != "" { + t.Errorf("Expected empty for empty input, got %s", result) + } +} + +func TestModelMapper_UpdateMappings(t *testing.T) { + mapper := NewModelMapper(nil) + + // Initially empty + if len(mapper.GetMappings()) != 0 { + t.Error("Expected 0 initial mappings") + } + + // Update with new mappings + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "model-a", To: "model-b"}, + {From: "model-c", To: "model-d"}, + }) + + result := mapper.GetMappings() + if len(result) != 2 { + t.Errorf("Expected 2 mappings after update, got %d", len(result)) + } + + // Update again should replace, not append + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "model-x", To: "model-y"}, + }) + + result = mapper.GetMappings() + if len(result) != 1 { + t.Errorf("Expected 1 mapping after second update, got %d", len(result)) + } +} + +func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) { + mapper := NewModelMapper(nil) + + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "", To: "model-b"}, // Invalid: empty from + {From: "model-a", To: ""}, // Invalid: empty to + {From: " ", To: "model-b"}, // Invalid: whitespace from + {From: "model-c", To: "model-d"}, // Valid + }) + + result := mapper.GetMappings() + if len(result) != 1 { + t.Errorf("Expected 1 valid mapping, got %d", len(result)) + } +} + +func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "model-a", To: "model-b"}, + } + + mapper := NewModelMapper(mappings) + + // Get mappings and modify the returned map + result := mapper.GetMappings() + result["new-key"] = "new-value" + + // Original should be unchanged + original := mapper.GetMappings() + if len(original) != 1 { + t.Errorf("Expected original to have 1 mapping, got %d", len(original)) + } + if _, exists := original["new-key"]; exists { + t.Error("Original map was modified") + } +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 5bc0dc25d..13a5d959d 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -170,9 +170,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han // Create fallback handler wrapper that forwards to ampcode.com when provider not found // Uses lazy evaluation to access proxy (which is created after routes are registered) - fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy { + // Also includes model mapping support for routing unavailable models to alternatives + fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.proxy - }) + }, m.modelMapper) // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") diff --git a/internal/api/server.go b/internal/api/server.go index ab9c03548..119df8482 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -150,6 +150,9 @@ type Server struct { // management handler mgmt *managementHandlers.Handler + // ampModule is the Amp routing module for model mapping hot-reload + ampModule *ampmodule.AmpModule + // managementRoutesRegistered tracks whether the management routes have been attached to the engine. managementRoutesRegistered atomic.Bool // managementRoutesEnabled controls whether management endpoints serve real handlers. @@ -268,14 +271,14 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk s.setupRoutes() // Register Amp module using V2 interface with Context - ampModule := ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager)) + s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager)) ctx := modules.Context{ Engine: engine, BaseHandler: s.handlers, Config: cfg, AuthMiddleware: AuthMiddleware(accessManager), } - if err := modules.RegisterModule(ctx, ampModule); err != nil { + if err := modules.RegisterModule(ctx, s.ampModule); err != nil { log.Errorf("Failed to register Amp module: %v", err) } @@ -916,11 +919,22 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.mgmt.SetAuthManager(s.handlers.AuthManager) } + // Notify Amp module of config changes (for model mapping hot-reload) + if s.ampModule != nil { + log.Debugf("triggering amp module config update") + if err := s.ampModule.OnConfigUpdated(cfg); err != nil { + log.Errorf("failed to update Amp module config: %v", err) + } + } else { + log.Warnf("amp module is nil, skipping config update") + } + // Count client sources from configuration and auth directory authFiles := util.CountAuthFiles(cfg.AuthDir) geminiAPIKeyCount := len(cfg.GeminiKey) claudeAPIKeyCount := len(cfg.ClaudeKey) codexAPIKeyCount := len(cfg.CodexKey) + vertexAICompatCount := len(cfg.VertexCompatAPIKey) openAICompatCount := 0 for i := range cfg.OpenAICompatibility { entry := cfg.OpenAICompatibility[i] @@ -931,13 +945,14 @@ func (s *Server) UpdateClients(cfg *config.Config) { openAICompatCount += len(entry.APIKeys) } - total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount - fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)\n", + total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount + fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n", total, authFiles, geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, + vertexAICompatCount, openAICompatCount, ) } diff --git a/internal/config/config.go b/internal/config/config.go index 97b5a0c23..11610462b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,6 +37,12 @@ type Config struct { // browser attacks and remote access to management endpoints. Default: true (recommended). AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"` + // AmpModelMappings defines model name mappings for Amp CLI requests. + // When Amp requests a model that isn't available locally, these mappings + // allow routing to an alternative model that IS available. + // Example: Map "claude-opus-4.5" -> "claude-sonnet-4" when opus isn't available. + AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings" json:"amp-model-mappings"` + // AuthDir is the directory where authentication token files are stored. AuthDir string `yaml:"auth-dir" json:"-"` @@ -64,6 +70,10 @@ type Config struct { // GeminiKey defines Gemini API key configurations with optional routing overrides. GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"` + // VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers. + // Used for services that use Vertex AI-style paths but with simple API key authentication. + VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"` + // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. @@ -118,6 +128,18 @@ type QuotaExceeded struct { SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` } +// AmpModelMapping defines a model name mapping for Amp CLI requests. +// When Amp requests a model that isn't available locally, this mapping +// allows routing to an alternative model that IS available. +type AmpModelMapping struct { + // From is the model name that Amp CLI requests (e.g., "claude-opus-4.5"). + From string `yaml:"from" json:"from"` + + // To is the target model name to route to (e.g., "claude-sonnet-4"). + // The target model must have available providers in the registry. + To string `yaml:"to" json:"to"` +} + // PayloadConfig defines default and override parameter rules applied to provider payloads. type PayloadConfig struct { // Default defines rules that only set parameters when they are missing in the payload. @@ -325,6 +347,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Gemini API key configuration and migrate legacy entries. cfg.SanitizeGeminiKeys() + // Sanitize Vertex-compatible API keys: drop entries without base-url + cfg.SanitizeVertexCompatKeys() + // Sanitize Codex keys: drop entries without base-url cfg.SanitizeCodexKeys() @@ -813,6 +838,7 @@ func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool { switch key { case "generative-language-api-key", "gemini-api-key", + "vertex-api-key", "claude-api-key", "codex-api-key", "openai-compatibility": diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go new file mode 100644 index 000000000..1257dd62b --- /dev/null +++ b/internal/config/vertex_compat.go @@ -0,0 +1,84 @@ +package config + +import "strings" + +// VertexCompatKey represents the configuration for Vertex AI-compatible API keys. +// This supports third-party services that use Vertex AI-style endpoint paths +// (/publishers/google/models/{model}:streamGenerateContent) but authenticate +// with simple API keys instead of Google Cloud service account credentials. +// +// Example services: zenmux.ai and similar Vertex-compatible providers. +type VertexCompatKey struct { + // APIKey is the authentication key for accessing the Vertex-compatible API. + // Maps to the x-goog-api-key header. + APIKey string `yaml:"api-key" json:"api-key"` + + // BaseURL is the base URL for the Vertex-compatible API endpoint. + // The executor will append "/v1/publishers/google/models/{model}:action" to this. + // Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..." + BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"` + + // ProxyURL optionally overrides the global proxy for this API key. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // Headers optionally adds extra HTTP headers for requests sent with this key. + // Commonly used for cookies, user-agent, and other authentication headers. + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // Models defines the model configurations including aliases for routing. + Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"` +} + +// VertexCompatModel represents a model configuration for Vertex compatibility, +// including the actual model name and its alias for API routing. +type VertexCompatModel struct { + // Name is the actual model name used by the external provider. + Name string `yaml:"name" json:"name"` + + // Alias is the model name alias that clients will use to reference this model. + Alias string `yaml:"alias" json:"alias"` +} + +// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials. +func (cfg *Config) SanitizeVertexCompatKeys() { + if cfg == nil { + return + } + + seen := make(map[string]struct{}, len(cfg.VertexCompatAPIKey)) + out := cfg.VertexCompatAPIKey[:0] + for i := range cfg.VertexCompatAPIKey { + entry := cfg.VertexCompatAPIKey[i] + entry.APIKey = strings.TrimSpace(entry.APIKey) + if entry.APIKey == "" { + continue + } + entry.BaseURL = strings.TrimSpace(entry.BaseURL) + if entry.BaseURL == "" { + // BaseURL is required for Vertex API key entries + continue + } + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.Headers = NormalizeHeaders(entry.Headers) + + // Sanitize models: remove entries without valid alias + sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models)) + for _, model := range entry.Models { + model.Alias = strings.TrimSpace(model.Alias) + model.Name = strings.TrimSpace(model.Name) + if model.Alias != "" && model.Name != "" { + sanitizedModels = append(sanitizedModels, model) + } + } + entry.Models = sanitizedModels + + // Use API key + base URL as uniqueness key + uniqueKey := entry.APIKey + "|" + entry.BaseURL + if _, exists := seen[uniqueKey]; exists { + continue + } + seen[uniqueKey] = struct{}{} + out = append(out, entry) + } + cfg.VertexCompatAPIKey = out +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index fd06d44c3..f2718c78d 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -954,7 +954,6 @@ func GetIFlowModels() []*ModelInfo { }{ {ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600}, {ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800}, - {ID: "qwen3-coder", DisplayName: "Qwen3-Coder-480B-A35B", Description: "Qwen3 Coder 480B A35B", Created: 1753228800}, {ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000}, {ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000}, {ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400}, diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index bd4242a11..de4ba072c 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -51,11 +51,238 @@ func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.A // Execute handles non-streaming requests. func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return resp, errCreds + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return resp, errCreds + } + return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// ExecuteStream handles SSE streaming for Vertex. +func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return nil, errCreds + } + return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// CountTokens calls Vertex countTokens endpoint. +func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return cliproxyexecutor.Response{}, errCreds + } + return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// countTokensWithServiceAccount handles token counting using service account credentials. +func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + baseURL := vertexBaseURL(location) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { + httpReq.Header.Set("Authorization", "Bearer "+token) + } else if errTok != nil { + log.Errorf("vertex executor: access token error: %v", errTok) + return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// countTokensWithAPIKey handles token counting using API key credentials. +func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// Refresh is a no-op for service account based credentials. +func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + return auth, nil +} +// executeWithServiceAccount handles authentication using service account credentials. +// This method contains the original service account authentication logic. +func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -149,13 +376,104 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A return resp, nil } -// ExecuteStream handles SSE streaming for Vertex. -func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return nil, errCreds +// executeWithAPIKey handles authentication using API key credentials. +func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) + } + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + + action := "generateContent" + if req.Metadata != nil { + if a, _ := req.Metadata["action"].(string); a == "countTokens" { + action = "countTokens" + } + } + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, action) + if opts.Alt != "" && action != "countTokens" { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") + + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if errNewReq != nil { + return resp, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return resp, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return resp, err + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return resp, errRead } + appendAPIResponseChunk(ctx, e.cfg, data) + reporter.publish(ctx, parseGeminiUsage(data)) + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil +} +// executeStreamWithServiceAccount handles streaming authentication using service account credentials. +func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -266,42 +584,44 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return stream, nil } -// CountTokens calls Vertex countTokens endpoint. -func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return cliproxyexecutor.Response{}, errCreds - } +// executeStreamWithAPIKey handles streaming authentication using API key credentials. +func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride != nil { norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) budgetOverride = &norm } - translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) - baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "streamGenerateContent") + if opts.Alt == "" { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") - httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if errNewReq != nil { - return cliproxyexecutor.Response{}, errNewReq + return nil, errNewReq } httpReq.Header.Set("Content-Type", "application/json") - if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { - httpReq.Header.Set("Authorization", "Bearer "+token) - } else if errTok != nil { - log.Errorf("vertex executor: access token error: %v", errTok) - return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) } applyGeminiHeaders(httpReq, auth) @@ -315,7 +635,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau URL: url, Method: http.MethodPost, Headers: httpReq.Header.Clone(), - Body: translatedReq, + Body: body, Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -327,38 +647,53 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { recordAPIResponseError(ctx, e.cfg, errDo) - return cliproxyexecutor.Response{}, errDo + return nil, errDo } - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("vertex executor: close response body error: %v", errClose) - } - }() recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} - } - data, errRead := io.ReadAll(httpResp.Body) - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - return cliproxyexecutor.Response{}, errRead - } - appendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} } - count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil -} -// Refresh is a no-op for service account based credentials. -func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - return auth, nil + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 20_971_520) + var param any + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + if detail, ok := parseGeminiStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } + }() + return stream, nil } // vertexCreds extracts project, location and raw service account JSON from auth metadata. @@ -401,6 +736,23 @@ func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccou return projectID, location, saJSON, nil } +// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern. +func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { + if a == nil { + return "", "" + } + if a.Attributes != nil { + apiKey = a.Attributes["api_key"] + baseURL = a.Attributes["base_url"] + } + if apiKey == "" && a.Metadata != nil { + if v, ok := a.Metadata["access_token"].(string); ok { + apiKey = v + } + } + return +} + func vertexBaseURL(location string) string { loc := strings.TrimSpace(location) if loc == "" { diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index a284541ae..02abb7433 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -162,12 +162,14 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config)) // Start begins watching the configuration file and authentication directory func (w *Watcher) Start(ctx context.Context) error { - // Watch the config file - if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil { - log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig) + // Watch the config file's parent directory instead of the file itself. + // This handles editors that use atomic save (write to temp, then rename). + configDir := filepath.Dir(w.configPath) + if errAddConfig := w.watcher.Add(configDir); errAddConfig != nil { + log.Errorf("failed to watch config directory %s: %v", configDir, errAddConfig) return errAddConfig } - log.Debugf("watching config file: %s", w.configPath) + log.Debugf("watching config directory: %s (for file: %s)", configDir, filepath.Base(w.configPath)) // Watch the auth directory if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil { @@ -496,6 +498,18 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str return hex.EncodeToString(sum[:]) } +func computeVertexCompatModelsHash(models []config.VertexCompatModel) string { + if len(models) == 0 { + return "" + } + data, err := json.Marshal(models) + if err != nil || len(data) == 0 { + return "" + } + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + // computeClaudeModelsHash returns a stable hash for Claude model aliases. func computeClaudeModelsHash(models []config.ClaudeModel) string { if len(models) == 0 { @@ -700,7 +714,23 @@ func (w *Watcher) isKnownAuthFile(path string) bool { func (w *Watcher) handleEvent(event fsnotify.Event) { // Filter only relevant events: config file or auth-dir JSON files. configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename - isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0 + // Check if this event is for our config file (handle both exact match and basename match for directory watching) + isConfigEvent := false + if event.Op&configOps != 0 { + // Exact path match + if event.Name == w.configPath { + isConfigEvent = true + } else { + // Check if basename matches and it's in the config directory (for atomic save detection) + configDir := filepath.Dir(w.configPath) + configBase := filepath.Base(w.configPath) + eventDir := filepath.Dir(event.Name) + eventBase := filepath.Base(event.Name) + if eventDir == configDir && eventBase == configBase { + isConfigEvent = true + } + } + } authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0 if !isConfigEvent && !isAuthJSON { @@ -902,8 +932,8 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string // no legacy clients to unregister // Create new API key clients based on the new config - geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) - totalAPIKeyClients := geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) + totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount log.Debugf("loaded %d API key clients", totalAPIKeyClients) var authFileCount int @@ -946,7 +976,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.clientsMutex.Unlock() } - totalNewClients := authFileCount + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount // Ensure consumers observe the new configuration before auth updates dispatch. if w.reloadCallback != nil { @@ -956,10 +986,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.refreshAuthState() - log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", + log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", totalNewClients, authFileCount, geminiAPIKeyCount, + vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount, @@ -1074,6 +1105,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey") out = append(out, a) } + // Claude API keys -> synthesize auths for i := range cfg.ClaudeKey { ck := cfg.ClaudeKey[i] @@ -1240,6 +1272,43 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } } } + + // Process Vertex API key providers (Vertex-compatible endpoints) + for i := range cfg.VertexCompatAPIKey { + compat := &cfg.VertexCompatAPIKey[i] + providerName := "vertex" + base := strings.TrimSpace(compat.BaseURL) + + key := strings.TrimSpace(compat.APIKey) + proxyURL := strings.TrimSpace(compat.ProxyURL) + idKind := fmt.Sprintf("vertex:apikey:%s", base) + id, token := idGen.next(idKind, key, base, proxyURL) + attrs := map[string]string{ + "source": fmt.Sprintf("config:vertex-apikey[%s]", token), + "base_url": base, + "provider_key": providerName, + } + if key != "" { + attrs["api_key"] = key + } + if hash := computeVertexCompatModelsHash(compat.Models); hash != "" { + attrs["models_hash"] = hash + } + addConfigHeadersToAttrs(compat.Headers, attrs) + a := &coreauth.Auth{ + ID: id, + Provider: providerName, + Label: "vertex-apikey", + Status: coreauth.StatusActive, + ProxyURL: proxyURL, + Attributes: attrs, + CreatedAt: now, + UpdatedAt: now, + } + applyAuthExcludedModelsMeta(a, cfg, nil, "apikey") + out = append(out, a) + } + // Also synthesize auth entries directly from auth files (for OAuth/file-backed providers) entries, _ := os.ReadDir(w.authDir) for _, e := range entries { @@ -1456,8 +1525,9 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int { return authFileCount } -func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { +func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) { geminiAPIKeyCount := 0 + vertexCompatAPIKeyCount := 0 claudeAPIKeyCount := 0 codexAPIKeyCount := 0 openAICompatCount := 0 @@ -1466,6 +1536,9 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { // Stateless executor handles Gemini API keys; avoid constructing legacy clients. geminiAPIKeyCount += len(cfg.GeminiKey) } + if len(cfg.VertexCompatAPIKey) > 0 { + vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey) + } if len(cfg.ClaudeKey) > 0 { claudeAPIKeyCount += len(cfg.ClaudeKey) } @@ -1483,7 +1556,7 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { } } } - return geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount + return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount } func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string { diff --git a/sdk/cliproxy/providers.go b/sdk/cliproxy/providers.go index a5810336a..401885f5c 100644 --- a/sdk/cliproxy/providers.go +++ b/sdk/cliproxy/providers.go @@ -29,7 +29,7 @@ func NewAPIKeyClientProvider() APIKeyClientProvider { type apiKeyClientProvider struct{} func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) { - geminiCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) + geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) if ctx != nil { select { case <-ctx.Done(): @@ -38,9 +38,10 @@ func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*A } } return &APIKeyClientResult{ - GeminiKeyCount: geminiCount, - ClaudeKeyCount: claudeCount, - CodexKeyCount: codexCount, - OpenAICompatCount: openAICompat, + GeminiKeyCount: geminiCount, + VertexCompatKeyCount: vertexCompatCount, + ClaudeKeyCount: claudeCount, + CodexKeyCount: codexCount, + OpenAICompatCount: openAICompat, }, nil } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 4001c49c6..549f35c24 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -324,7 +324,7 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName if len(a.Attributes) > 0 { providerKey = strings.TrimSpace(a.Attributes["provider_key"]) compatName = strings.TrimSpace(a.Attributes["compat_name"]) - if providerKey != "" || compatName != "" { + if compatName != "" { if providerKey == "" { providerKey = compatName } @@ -500,7 +500,7 @@ func (s *Service) Run(ctx context.Context) error { }() time.Sleep(100 * time.Millisecond) - fmt.Println("API server started successfully") + fmt.Printf("API server started successfully on: %d\n", s.cfg.Port) if s.hooks.OnAfterStart != nil { s.hooks.OnAfterStart(s) @@ -681,6 +681,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "vertex": // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() + if authKind == "apikey" { + if entry := s.resolveConfigVertexCompatKey(a); entry != nil && len(entry.Models) > 0 { + models = buildVertexCompatConfigModels(entry) + } + } models = applyExcludedModels(models, excluded) case "gemini-cli": models = registry.GetGeminiCLIModels() @@ -878,6 +883,40 @@ func (s *Service) resolveConfigGeminiKey(auth *coreauth.Auth) *config.GeminiKey return nil } +func (s *Service) resolveConfigVertexCompatKey(auth *coreauth.Auth) *config.VertexCompatKey { + if auth == nil || s.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range s.cfg.VertexCompatAPIKey { + entry := &s.cfg.VertexCompatAPIKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range s.cfg.VertexCompatAPIKey { + entry := &s.cfg.VertexCompatAPIKey[i] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { + return entry + } + } + } + return nil +} + func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey { if auth == nil || s.cfg == nil { return nil @@ -996,6 +1035,44 @@ func matchWildcard(pattern, value string) bool { return true } +func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { + if entry == nil || len(entry.Models) == 0 { + return nil + } + now := time.Now().Unix() + out := make([]*ModelInfo, 0, len(entry.Models)) + seen := make(map[string]struct{}, len(entry.Models)) + for i := range entry.Models { + model := entry.Models[i] + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if alias == "" { + alias = name + } + if alias == "" { + continue + } + key := strings.ToLower(alias) + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + display := name + if display == "" { + display = alias + } + out = append(out, &ModelInfo{ + ID: alias, + Object: "model", + Created: now, + OwnedBy: "vertex", + Type: "vertex", + DisplayName: display, + }) + } + return out +} + func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { if entry == nil || len(entry.Models) == 0 { return nil diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index b44185d17..42c7c4881 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -49,19 +49,21 @@ type APIKeyClientProvider interface { Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) } -// APIKeyClientResult contains API key based clients along with type counts. -// It provides metadata about the number of clients loaded for each provider type. +// APIKeyClientResult is returned by APIKeyClientProvider.Load() type APIKeyClientResult struct { - // GeminiKeyCount is the number of Gemini API key clients loaded. + // GeminiKeyCount is the number of Gemini API keys loaded GeminiKeyCount int - // ClaudeKeyCount is the number of Claude API key clients loaded. + // VertexCompatKeyCount is the number of Vertex-compatible API keys loaded + VertexCompatKeyCount int + + // ClaudeKeyCount is the number of Claude API keys loaded ClaudeKeyCount int - // CodexKeyCount is the number of Codex API key clients loaded. + // CodexKeyCount is the number of Codex API keys loaded CodexKeyCount int - // OpenAICompatCount is the number of OpenAI-compatible API key clients loaded. + // OpenAICompatCount is the number of OpenAI compatibility API keys loaded OpenAICompatCount int }