From 2cd59806e2c730964c6b6d2bc42eea71140ac412 Mon Sep 17 00:00:00 2001 From: Trung Nguyen Date: Sat, 29 Nov 2025 12:44:09 +0700 Subject: [PATCH 1/6] feat(amp): add model mapping support for routing unavailable models to alternatives - Add AmpModelMapping config to route models like 'claude-opus-4.5' to 'claude-sonnet-4' - Add ModelMapper interface and DefaultModelMapper implementation with hot-reload support - Enhance FallbackHandler to apply model mappings before falling back to ampcode.com - Add structured logging for routing decisions (local provider, mapping, amp credits) - Update config.example.yaml with amp-model-mappings documentation --- config.example.yaml | 22 +++ internal/api/modules/amp/amp.go | 17 +- internal/api/modules/amp/fallback_handlers.go | 142 ++++++++++++- internal/api/modules/amp/model_mapping.go | 113 +++++++++++ .../api/modules/amp/model_mapping_test.go | 186 ++++++++++++++++++ internal/api/modules/amp/routes.go | 5 +- internal/config/config.go | 18 ++ 7 files changed, 495 insertions(+), 8 deletions(-) create mode 100644 internal/api/modules/amp/model_mapping.go create mode 100644 internal/api/modules/amp/model_mapping_test.go diff --git a/config.example.yaml b/config.example.yaml index 8457e1037..71863d492 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -55,6 +55,28 @@ quota-exceeded: # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false +# Amp CLI Integration +# Configure upstream URL for Amp CLI OAuth and management features +#amp-upstream-url: "https://ampcode.com" + +# Optional: Override API key for Amp upstream (otherwise uses env or file) +#amp-upstream-api-key: "" + +# Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended) +#amp-restrict-management-to-localhost: true + +# Amp Model Mappings +# Route unavailable Amp models to alternative models available in your local proxy. +# Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) +# but you have a similar model available (e.g., Claude Sonnet 4). +#amp-model-mappings: +# - from: "claude-opus-4.5" # Model requested by Amp CLI +# to: "claude-sonnet-4" # Route to this available model instead +# - from: "gpt-5" +# to: "gemini-2.5-pro" +# - from: "claude-3-opus-20240229" +# to: "claude-3-5-sonnet-20241022" + # Gemini API keys (preferred) #gemini-api-key: # - api-key: "AIzaSy...01" diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index 0086d1798..b5a139f6e 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -23,11 +23,13 @@ type Option func(*AmpModule) // - Reverse proxy to Amp control plane for OAuth/management // - Provider-specific route aliases (/api/provider/{provider}/...) // - Automatic gzip decompression for misconfigured upstreams +// - Model mapping for routing unavailable models to alternatives type AmpModule struct { secretSource SecretSource proxy *httputil.ReverseProxy accessManager *sdkaccess.Manager authMiddleware_ gin.HandlerFunc + modelMapper *DefaultModelMapper enabled bool registerOnce sync.Once } @@ -101,6 +103,9 @@ func (m *AmpModule) Register(ctx modules.Context) error { // Use registerOnce to ensure routes are only registered once var regErr error m.registerOnce.Do(func() { + // Initialize model mapper from config (for routing unavailable models to alternatives) + m.modelMapper = NewModelMapper(ctx.Config.AmpModelMappings) + // Always register provider aliases - these work without an upstream m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) @@ -159,8 +164,13 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { // OnConfigUpdated handles configuration updates. // Currently requires restart for URL changes (could be enhanced for dynamic updates). func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { + // Update model mappings (hot-reload supported) + if m.modelMapper != nil { + m.modelMapper.UpdateMappings(cfg.AmpModelMappings) + } + if !m.enabled { - log.Debug("Amp routing not enabled, skipping config update") + log.Debug("Amp routing not enabled, skipping other config updates") return nil } @@ -181,3 +191,8 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { log.Debug("Amp config updated (restart required for URL changes)") return nil } + +// GetModelMapper returns the model mapper instance (for testing/debugging). +func (m *AmpModule) GetModelMapper() *DefaultModelMapper { + return m.modelMapper +} diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index e7b28986d..17c60708c 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -6,16 +6,75 @@ import ( "io" "net/http/httputil" "strings" + "time" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) +// AmpRouteType represents the type of routing decision made for an Amp request +type AmpRouteType string + +const ( + // RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free) + RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER" + // RouteTypeModelMapping indicates the request was remapped to another available model (free) + RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING" + // RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits) + RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS" + // RouteTypeNoProvider indicates no provider or fallback available + RouteTypeNoProvider AmpRouteType = "NO_PROVIDER" +) + +// logAmpRouting logs the routing decision for an Amp request with structured fields +func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { + fields := log.Fields{ + "component": "amp-routing", + "route_type": string(routeType), + "requested_model": requestedModel, + "path": path, + "timestamp": time.Now().Format(time.RFC3339), + } + + if resolvedModel != "" && resolvedModel != requestedModel { + fields["resolved_model"] = resolvedModel + } + if provider != "" { + fields["provider"] = provider + } + + switch routeType { + case RouteTypeLocalProvider: + fields["cost"] = "free" + fields["source"] = "local_oauth" + log.WithFields(fields).Infof("[AMP] Using local provider for model: %s", requestedModel) + + case RouteTypeModelMapping: + fields["cost"] = "free" + fields["source"] = "local_oauth" + fields["mapping"] = requestedModel + " -> " + resolvedModel + log.WithFields(fields).Infof("[AMP] Model mapped: %s -> %s", requestedModel, resolvedModel) + + case RouteTypeAmpCredits: + fields["cost"] = "amp_credits" + fields["source"] = "ampcode.com" + fields["model_id"] = requestedModel // Explicit model_id for easy config reference + log.WithFields(fields).Warnf("[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"\"}]", requestedModel, requestedModel) + + case RouteTypeNoProvider: + fields["cost"] = "none" + fields["source"] = "error" + fields["model_id"] = requestedModel // Explicit model_id for easy config reference + log.WithFields(fields).Warnf("[AMP] No provider available for model_id: %s", requestedModel) + } +} + // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy + getProxy func() *httputil.ReverseProxy + modelMapper ModelMapper } // NewFallbackHandler creates a new fallback handler wrapper @@ -26,10 +85,25 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler } } +// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support +func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler { + return &FallbackHandler{ + getProxy: getProxy, + modelMapper: mapper, + } +} + +// SetModelMapper sets the model mapper for this handler (allows late binding) +func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { + fh.modelMapper = mapper +} + // WrapHandler wraps a gin.HandlerFunc with fallback logic // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { + requestPath := c.Request.URL.Path + // Read the request body to extract the model name bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { @@ -55,12 +129,33 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Check if we have providers for this model providers := util.GetProviderName(normalizedModel) + // Track resolved model for logging (may change if mapping is applied) + resolvedModel := normalizedModel + usedMapping := false + if len(providers) == 0 { - // No providers configured - check if we have a proxy for fallback + // No providers configured - check if we have a model mapping + if fh.modelMapper != nil { + if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { + // Mapping found - rewrite the model in request body + bodyBytes = rewriteModelInBody(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + resolvedModel = mappedModel + usedMapping = true + + // Get providers for the mapped model + providers = util.GetProviderName(mappedModel) + + // Continue to handler with remapped model + goto handleRequest + } + } + + // No mapping found - check if we have a proxy for fallback proxy := fh.getProxy() if proxy != nil { - // Fallback to ampcode.com - log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName) + // Log: Forwarding to ampcode.com (uses Amp credits) + logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath) // Restore body again for the proxy c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) @@ -71,7 +166,23 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } // No proxy available, let the normal handler return the error - log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName) + logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) + } + + handleRequest: + + // Log the routing decision + providerName := "" + if len(providers) > 0 { + providerName = providers[0] + } + + if usedMapping { + // Log: Model was mapped to another model + logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) + } else if len(providers) > 0 { + // Log: Using local provider (free) + logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) } // Providers available or no proxy for fallback, restore body and use normal handler @@ -91,6 +202,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } } +// rewriteModelInBody replaces the model name in a JSON request body +func rewriteModelInBody(body []byte, newModel string) []byte { + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err) + return body + } + + if _, exists := payload["model"]; exists { + payload["model"] = newModel + newBody, err := json.Marshal(payload) + if err != nil { + log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err) + return body + } + return newBody + } + + return body +} + // extractModelFromRequest attempts to extract the model name from various request formats func extractModelFromRequest(body []byte, c *gin.Context) string { // First try to parse from JSON body (OpenAI, Claude, etc.) diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go new file mode 100644 index 000000000..c07f41c40 --- /dev/null +++ b/internal/api/modules/amp/model_mapping.go @@ -0,0 +1,113 @@ +// Package amp provides model mapping functionality for routing Amp CLI requests +// to alternative models when the requested model is not available locally. +package amp + +import ( + "strings" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +// ModelMapper provides model name mapping/aliasing for Amp CLI requests. +// When an Amp request comes in for a model that isn't available locally, +// this mapper can redirect it to an alternative model that IS available. +type ModelMapper interface { + // MapModel returns the target model name if a mapping exists and the target + // model has available providers. Returns empty string if no mapping applies. + MapModel(requestedModel string) string + + // UpdateMappings refreshes the mapping configuration (for hot-reload). + UpdateMappings(mappings []config.AmpModelMapping) +} + +// DefaultModelMapper implements ModelMapper with thread-safe mapping storage. +type DefaultModelMapper struct { + mu sync.RWMutex + mappings map[string]string // from -> to (normalized lowercase keys) +} + +// NewModelMapper creates a new model mapper with the given initial mappings. +func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { + m := &DefaultModelMapper{ + mappings: make(map[string]string), + } + m.UpdateMappings(mappings) + return m +} + +// MapModel checks if a mapping exists for the requested model and if the +// target model has available local providers. Returns the mapped model name +// or empty string if no valid mapping exists. +func (m *DefaultModelMapper) MapModel(requestedModel string) string { + if requestedModel == "" { + return "" + } + + m.mu.RLock() + defer m.mu.RUnlock() + + // Normalize the requested model for lookup + normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel)) + + // Check for direct mapping + targetModel, exists := m.mappings[normalizedRequest] + if !exists { + return "" + } + + // Verify target model has available providers + providers := util.GetProviderName(targetModel) + if len(providers) == 0 { + log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) + return "" + } + + // Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go + log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel) + return targetModel +} + +// UpdateMappings refreshes the mapping configuration from config. +// This is called during initialization and on config hot-reload. +func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { + m.mu.Lock() + defer m.mu.Unlock() + + // Clear and rebuild mappings + m.mappings = make(map[string]string, len(mappings)) + + for _, mapping := range mappings { + from := strings.TrimSpace(mapping.From) + to := strings.TrimSpace(mapping.To) + + if from == "" || to == "" { + log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to) + continue + } + + // Store with normalized lowercase key for case-insensitive lookup + normalizedFrom := strings.ToLower(from) + m.mappings[normalizedFrom] = to + + log.Debugf("amp model mapping registered: %s -> %s", from, to) + } + + if len(m.mappings) > 0 { + log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) + } +} + +// GetMappings returns a copy of current mappings (for debugging/status). +func (m *DefaultModelMapper) GetMappings() map[string]string { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]string, len(m.mappings)) + for k, v := range m.mappings { + result[k] = v + } + return result +} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go new file mode 100644 index 000000000..c11d61bd7 --- /dev/null +++ b/internal/api/modules/amp/model_mapping_test.go @@ -0,0 +1,186 @@ +package amp + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func TestNewModelMapper(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + {From: "gpt-5", To: "gemini-2.5-pro"}, + } + + mapper := NewModelMapper(mappings) + if mapper == nil { + t.Fatal("Expected non-nil mapper") + } + + result := mapper.GetMappings() + if len(result) != 2 { + t.Errorf("Expected 2 mappings, got %d", len(result)) + } +} + +func TestNewModelMapper_Empty(t *testing.T) { + mapper := NewModelMapper(nil) + if mapper == nil { + t.Fatal("Expected non-nil mapper") + } + + result := mapper.GetMappings() + if len(result) != 0 { + t.Errorf("Expected 0 mappings, got %d", len(result)) + } +} + +func TestModelMapper_MapModel_NoProvider(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Without a registered provider for the target, mapping should return empty + result := mapper.MapModel("claude-opus-4.5") + if result != "" { + t.Errorf("Expected empty result when target has no provider, got %s", result) + } +} + +func TestModelMapper_MapModel_WithProvider(t *testing.T) { + // Register a mock provider for the target model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client") + + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // With a registered provider, mapping should work + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client2") + + mappings := []config.AmpModelMapping{ + {From: "Claude-Opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Should match case-insensitively + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_MapModel_NotFound(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + // Unknown model should return empty + result := mapper.MapModel("unknown-model") + if result != "" { + t.Errorf("Expected empty for unknown model, got %s", result) + } +} + +func TestModelMapper_MapModel_EmptyInput(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "claude-opus-4.5", To: "claude-sonnet-4"}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("") + if result != "" { + t.Errorf("Expected empty for empty input, got %s", result) + } +} + +func TestModelMapper_UpdateMappings(t *testing.T) { + mapper := NewModelMapper(nil) + + // Initially empty + if len(mapper.GetMappings()) != 0 { + t.Error("Expected 0 initial mappings") + } + + // Update with new mappings + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "model-a", To: "model-b"}, + {From: "model-c", To: "model-d"}, + }) + + result := mapper.GetMappings() + if len(result) != 2 { + t.Errorf("Expected 2 mappings after update, got %d", len(result)) + } + + // Update again should replace, not append + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "model-x", To: "model-y"}, + }) + + result = mapper.GetMappings() + if len(result) != 1 { + t.Errorf("Expected 1 mapping after second update, got %d", len(result)) + } +} + +func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) { + mapper := NewModelMapper(nil) + + mapper.UpdateMappings([]config.AmpModelMapping{ + {From: "", To: "model-b"}, // Invalid: empty from + {From: "model-a", To: ""}, // Invalid: empty to + {From: " ", To: "model-b"}, // Invalid: whitespace from + {From: "model-c", To: "model-d"}, // Valid + }) + + result := mapper.GetMappings() + if len(result) != 1 { + t.Errorf("Expected 1 valid mapping, got %d", len(result)) + } +} + +func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) { + mappings := []config.AmpModelMapping{ + {From: "model-a", To: "model-b"}, + } + + mapper := NewModelMapper(mappings) + + // Get mappings and modify the returned map + result := mapper.GetMappings() + result["new-key"] = "new-value" + + // Original should be unchanged + original := mapper.GetMappings() + if len(original) != 1 { + t.Errorf("Expected original to have 1 mapping, got %d", len(original)) + } + if _, exists := original["new-key"]; exists { + t.Error("Original map was modified") + } +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 8e5189adb..8bd739bb9 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -162,9 +162,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han // Create fallback handler wrapper that forwards to ampcode.com when provider not found // Uses lazy evaluation to access proxy (which is created after routes are registered) - fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy { + // Also includes model mapping support for routing unavailable models to alternatives + fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.proxy - }) + }, m.modelMapper) // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") diff --git a/internal/config/config.go b/internal/config/config.go index 319200752..8612b3e5f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,6 +37,12 @@ type Config struct { // browser attacks and remote access to management endpoints. Default: true (recommended). AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"` + // AmpModelMappings defines model name mappings for Amp CLI requests. + // When Amp requests a model that isn't available locally, these mappings + // allow routing to an alternative model that IS available. + // Example: Map "claude-opus-4.5" -> "claude-sonnet-4" when opus isn't available. + AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings" json:"amp-model-mappings"` + // AuthDir is the directory where authentication token files are stored. AuthDir string `yaml:"auth-dir" json:"-"` @@ -115,6 +121,18 @@ type QuotaExceeded struct { SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` } +// AmpModelMapping defines a model name mapping for Amp CLI requests. +// When Amp requests a model that isn't available locally, this mapping +// allows routing to an alternative model that IS available. +type AmpModelMapping struct { + // From is the model name that Amp CLI requests (e.g., "claude-opus-4.5"). + From string `yaml:"from" json:"from"` + + // To is the target model name to route to (e.g., "claude-sonnet-4"). + // The target model must have available providers in the registry. + To string `yaml:"to" json:"to"` +} + // PayloadConfig defines default and override parameter rules applied to provider payloads. type PayloadConfig struct { // Default defines rules that only set parameters when they are missing in the payload. From 33a5656235842975e0f53bd55f14d342dfb5604a Mon Sep 17 00:00:00 2001 From: Trung Nguyen Date: Sat, 29 Nov 2025 12:51:03 +0700 Subject: [PATCH 2/6] docs: add model mapping documentation for Amp CLI integration - Add model mapping feature to README.md Amp CLI section - Add detailed Model Mapping Configuration section to amp-cli-integration.md - Update architecture diagram to show model mapping flow - Update Model Fallback Behavior to include mapping step - Add Table of Contents entry for model mapping --- README.md | 1 + docs/amp-cli-integration.md | 57 +++++++++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 90d5d4650..7a9b05909 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and A - Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`) - Management proxy for OAuth authentication and account features - Smart model fallback with automatic routing +- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`) - Security-first design with localhost-only management endpoints **→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)** diff --git a/docs/amp-cli-integration.md b/docs/amp-cli-integration.md index 41cea66a2..feac51fc4 100644 --- a/docs/amp-cli-integration.md +++ b/docs/amp-cli-integration.md @@ -8,6 +8,7 @@ This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, - [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate) - [Architecture](#architecture) - [Configuration](#configuration) + - [Model Mapping Configuration](#model-mapping-configuration) - [Setup](#setup) - [Usage](#usage) - [Troubleshooting](#troubleshooting) @@ -21,6 +22,7 @@ The Amp CLI integration adds specialized routing to support Amp's API patterns w - **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers - **Management proxy**: Forwards OAuth and account management requests to Amp's control plane - **Smart fallback**: Automatically routes unconfigured models to ampcode.com +- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`) - **Secret management**: Configurable precedence (config > env > file) with 5-minute caching - **Security-first**: Management routes restricted to localhost by default - **Automatic gzip handling**: Decompresses responses from Amp upstream @@ -75,7 +77,10 @@ Amp CLI/IDE │ ↓ │ ├─ Model configured locally? │ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers) - │ │ NO → Forward to ampcode.com (reverse proxy) + │ │ NO ↓ + │ │ ├─ Model mapping configured? + │ │ │ YES → Rewrite model → Use local handler (free) + │ │ │ NO → Forward to ampcode.com (uses Amp credits) │ ↓ │ Response │ @@ -115,6 +120,49 @@ amp-upstream-url: "https://ampcode.com" amp-restrict-management-to-localhost: true ``` +### Model Mapping Configuration + +When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally. + +```yaml +# Route unavailable models to alternatives +amp-model-mappings: + # Example: Route Claude Opus 4.5 requests to Claude Sonnet 4 + - from: "claude-opus-4.5" + to: "claude-sonnet-4" + + # Example: Route GPT-5 requests to Gemini 2.5 Pro + - from: "gpt-5" + to: "gemini-2.5-pro" + + # Example: Map older model names to newer versions + - from: "claude-3-opus-20240229" + to: "claude-3-5-sonnet-20241022" +``` + +**How it works:** + +1. Amp CLI requests a model (e.g., `claude-opus-4.5`) +2. CLIProxyAPI checks if a local provider is available for that model +3. If not available, it checks the model mappings +4. If a mapping exists, the request is rewritten to use the target model +5. The request is then handled locally (free, using your OAuth subscription) + +**Benefits:** +- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com +- **Hot-reload**: Mappings can be updated without restarting the proxy +- **Structured logging**: Clear logs show when mappings are applied + +**Routing Decision Logs:** + +The proxy logs each routing decision with structured fields: + +``` +[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free) +[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free) +[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits) +``` + ### Secret Resolution Precedence The Amp module resolves API keys using this precedence order: @@ -301,11 +349,14 @@ When Amp requests a model: 1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider? 2. **If YES**: Route to local handler (use your OAuth subscription) -3. **If NO**: Forward to ampcode.com (use Amp's default routing) +3. **If NO**: Check if a model mapping exists +4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free) +5. **If no mapping**: Forward to ampcode.com (uses Amp credits) This enables seamless mixed usage: - Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions -- Models you haven't configured → Amp's default providers +- Models with mappings configured → Routed to alternative local models (free) +- Models you haven't configured and have no mapping → Amp's default providers (uses credits) ### Example API Calls From 3409f4e336803eaf211346b22d8e0d835b1ffe8a Mon Sep 17 00:00:00 2001 From: NguyenSiTrung Date: Mon, 1 Dec 2025 13:34:49 +0700 Subject: [PATCH 3/6] fix: enable hot reload for amp-model-mappings config - Store ampModule in Server struct to access it during config updates - Call ampModule.OnConfigUpdated() in UpdateClients() for hot reload - Watch config directory instead of file to handle atomic saves (vim, VSCode, etc.) - Improve config file event detection with basename matching - Add diagnostic logging for config reload tracing --- internal/api/modules/amp/amp.go | 3 +++ internal/api/server.go | 17 +++++++++++++++-- internal/watcher/watcher.go | 28 +++++++++++++++++++++++----- 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index b5a139f6e..281fda656 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -166,7 +166,10 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { // Update model mappings (hot-reload supported) if m.modelMapper != nil { + log.Infof("amp config updated: reloading %d model mapping(s)", len(cfg.AmpModelMappings)) m.modelMapper.UpdateMappings(cfg.AmpModelMappings) + } else { + log.Warnf("amp model mapper not initialized, skipping model mapping update") } if !m.enabled { diff --git a/internal/api/server.go b/internal/api/server.go index ab9c03548..fb3610c2b 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -150,6 +150,9 @@ type Server struct { // management handler mgmt *managementHandlers.Handler + // ampModule is the Amp routing module for model mapping hot-reload + ampModule *ampmodule.AmpModule + // managementRoutesRegistered tracks whether the management routes have been attached to the engine. managementRoutesRegistered atomic.Bool // managementRoutesEnabled controls whether management endpoints serve real handlers. @@ -268,14 +271,14 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk s.setupRoutes() // Register Amp module using V2 interface with Context - ampModule := ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager)) + s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager)) ctx := modules.Context{ Engine: engine, BaseHandler: s.handlers, Config: cfg, AuthMiddleware: AuthMiddleware(accessManager), } - if err := modules.RegisterModule(ctx, ampModule); err != nil { + if err := modules.RegisterModule(ctx, s.ampModule); err != nil { log.Errorf("Failed to register Amp module: %v", err) } @@ -916,6 +919,16 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.mgmt.SetAuthManager(s.handlers.AuthManager) } + // Notify Amp module of config changes (for model mapping hot-reload) + if s.ampModule != nil { + log.Debugf("triggering amp module config update") + if err := s.ampModule.OnConfigUpdated(cfg); err != nil { + log.Errorf("failed to update Amp module config: %v", err) + } + } else { + log.Warnf("amp module is nil, skipping config update") + } + // Count client sources from configuration and auth directory authFiles := util.CountAuthFiles(cfg.AuthDir) geminiAPIKeyCount := len(cfg.GeminiKey) diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index a284541ae..af72212d0 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -162,12 +162,14 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config)) // Start begins watching the configuration file and authentication directory func (w *Watcher) Start(ctx context.Context) error { - // Watch the config file - if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil { - log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig) + // Watch the config file's parent directory instead of the file itself. + // This handles editors that use atomic save (write to temp, then rename). + configDir := filepath.Dir(w.configPath) + if errAddConfig := w.watcher.Add(configDir); errAddConfig != nil { + log.Errorf("failed to watch config directory %s: %v", configDir, errAddConfig) return errAddConfig } - log.Debugf("watching config file: %s", w.configPath) + log.Debugf("watching config directory: %s (for file: %s)", configDir, filepath.Base(w.configPath)) // Watch the auth directory if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil { @@ -700,7 +702,23 @@ func (w *Watcher) isKnownAuthFile(path string) bool { func (w *Watcher) handleEvent(event fsnotify.Event) { // Filter only relevant events: config file or auth-dir JSON files. configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename - isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0 + // Check if this event is for our config file (handle both exact match and basename match for directory watching) + isConfigEvent := false + if event.Op&configOps != 0 { + // Exact path match + if event.Name == w.configPath { + isConfigEvent = true + } else { + // Check if basename matches and it's in the config directory (for atomic save detection) + configDir := filepath.Dir(w.configPath) + configBase := filepath.Base(w.configPath) + eventDir := filepath.Dir(event.Name) + eventBase := filepath.Base(event.Name) + if eventDir == configDir && eventBase == configBase { + isConfigEvent = true + } + } + } authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0 if !isConfigEvent && !isAuthJSON { From 0ebb6540190efbf0da99c4d6308302527960dbe7 Mon Sep 17 00:00:00 2001 From: Aero Date: Tue, 2 Dec 2025 08:14:22 +0800 Subject: [PATCH 4/6] feat: Add support for VertexAI compatible service (#375) feat: consolidate Vertex AI compatibility with API key support in Gemini --- .gitignore | 5 + internal/api/server.go | 6 +- internal/config/config.go | 8 + internal/config/vertex_compat.go | 84 +++ .../executor/gemini_vertex_executor.go | 481 ++++++++++++++++-- internal/watcher/watcher.go | 66 ++- sdk/cliproxy/providers.go | 11 +- sdk/cliproxy/service.go | 35 +- sdk/cliproxy/types.go | 14 +- 9 files changed, 633 insertions(+), 77 deletions(-) create mode 100644 internal/config/vertex_compat.go diff --git a/.gitignore b/.gitignore index ef2d935ae..9e730c987 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ pgstore/* gitstore/* objectstore/* static/* +refs/* # Authentication data auths/* @@ -30,3 +31,7 @@ GEMINI.md .vscode/* .claude/* .serena/* + +# macOS +.DS_Store +._* diff --git a/internal/api/server.go b/internal/api/server.go index fb3610c2b..119df8482 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -934,6 +934,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { geminiAPIKeyCount := len(cfg.GeminiKey) claudeAPIKeyCount := len(cfg.ClaudeKey) codexAPIKeyCount := len(cfg.CodexKey) + vertexAICompatCount := len(cfg.VertexCompatAPIKey) openAICompatCount := 0 for i := range cfg.OpenAICompatibility { entry := cfg.OpenAICompatibility[i] @@ -944,13 +945,14 @@ func (s *Server) UpdateClients(cfg *config.Config) { openAICompatCount += len(entry.APIKeys) } - total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount - fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)\n", + total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount + fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n", total, authFiles, geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, + vertexAICompatCount, openAICompatCount, ) } diff --git a/internal/config/config.go b/internal/config/config.go index 9111a8a8a..11610462b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -70,6 +70,10 @@ type Config struct { // GeminiKey defines Gemini API key configurations with optional routing overrides. GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"` + // VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers. + // Used for services that use Vertex AI-style paths but with simple API key authentication. + VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"` + // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. @@ -343,6 +347,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Gemini API key configuration and migrate legacy entries. cfg.SanitizeGeminiKeys() + // Sanitize Vertex-compatible API keys: drop entries without base-url + cfg.SanitizeVertexCompatKeys() + // Sanitize Codex keys: drop entries without base-url cfg.SanitizeCodexKeys() @@ -831,6 +838,7 @@ func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool { switch key { case "generative-language-api-key", "gemini-api-key", + "vertex-api-key", "claude-api-key", "codex-api-key", "openai-compatibility": diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go new file mode 100644 index 000000000..a8d94ccb8 --- /dev/null +++ b/internal/config/vertex_compat.go @@ -0,0 +1,84 @@ +package config + +import "strings" + +// VertexCompatKey represents the configuration for Vertex AI-compatible API keys. +// This supports third-party services that use Vertex AI-style endpoint paths +// (/publishers/google/models/{model}:streamGenerateContent) but authenticate +// with simple API keys instead of Google Cloud service account credentials. +// +// Example services: zenmux.ai and similar Vertex-compatible providers. +type VertexCompatKey struct { + // APIKey is the authentication key for accessing the Vertex-compatible API. + // Maps to the x-goog-api-key header. + APIKey string `yaml:"api-key" json:"api-key"` + + // BaseURL is the base URL for the Vertex-compatible API endpoint. + // The executor will append "/v1/publishers/google/models/{model}:action" to this. + // Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..." + BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"` + + // ProxyURL optionally overrides the global proxy for this API key. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // Headers optionally adds extra HTTP headers for requests sent with this key. + // Commonly used for cookies, user-agent, and other authentication headers. + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // Models defines the model configurations including aliases for routing. + Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"` +} + +// VertexCompatModel represents a model configuration for Vertex compatibility, +// including the actual model name and its alias for API routing. +type VertexCompatModel struct { + // Name is the actual model name used by the external provider. + Name string `yaml:"name" json:"name"` + + // Alias is the model name alias that clients will use to reference this model. + Alias string `yaml:"alias" json:"alias"` +} + +// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials. +func (cfg *Config) SanitizeVertexCompatKeys() { + if cfg == nil { + return + } + + seen := make(map[string]struct{}, len(cfg.VertexCompatAPIKey)) + out := cfg.VertexCompatAPIKey[:0] + for i := range cfg.VertexCompatAPIKey { + entry := cfg.VertexCompatAPIKey[i] + entry.APIKey = strings.TrimSpace(entry.APIKey) + if entry.APIKey == "" { + continue + } + entry.BaseURL = strings.TrimSpace(entry.BaseURL) + if entry.BaseURL == "" { + // BaseURL is required for vertex-compat keys + continue + } + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.Headers = NormalizeHeaders(entry.Headers) + + // Sanitize models: remove entries without valid alias + sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models)) + for _, model := range entry.Models { + model.Alias = strings.TrimSpace(model.Alias) + model.Name = strings.TrimSpace(model.Name) + if model.Alias != "" && model.Name != "" { + sanitizedModels = append(sanitizedModels, model) + } + } + entry.Models = sanitizedModels + + // Use API key + base URL as uniqueness key + uniqueKey := entry.APIKey + "|" + entry.BaseURL + if _, exists := seen[uniqueKey]; exists { + continue + } + seen[uniqueKey] = struct{}{} + out = append(out, entry) + } + cfg.VertexCompatAPIKey = out +} diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index bd4242a11..eeb7356e9 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -44,6 +44,22 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor { // Identifier returns provider key for manager routing. func (e *GeminiVertexExecutor) Identifier() string { return "vertex" } +// GeminiVertexCompatExecutor is a thin wrapper around GeminiVertexExecutor +// that provides the correct identifier for vertex-compat routing. +type GeminiVertexCompatExecutor struct { + *GeminiVertexExecutor +} + +// NewGeminiVertexCompatExecutor constructs the Vertex-compatible executor. +func NewGeminiVertexCompatExecutor(cfg *config.Config) *GeminiVertexCompatExecutor { + return &GeminiVertexCompatExecutor{ + GeminiVertexExecutor: NewGeminiVertexExecutor(cfg), + } +} + +// Identifier returns provider key for manager routing. +func (e *GeminiVertexCompatExecutor) Identifier() string { return "vertex-compat" } + // PrepareRequest is a no-op for Vertex. func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil @@ -51,11 +67,238 @@ func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.A // Execute handles non-streaming requests. func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return resp, errCreds + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return resp, errCreds + } + return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// ExecuteStream handles SSE streaming for Vertex. +func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return nil, errCreds + } + return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// CountTokens calls Vertex countTokens endpoint. +func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return cliproxyexecutor.Response{}, errCreds + } + return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// countTokensWithServiceAccount handles token counting using service account credentials. +func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + baseURL := vertexBaseURL(location) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { + httpReq.Header.Set("Authorization", "Bearer "+token) + } else if errTok != nil { + log.Errorf("vertex executor: access token error: %v", errTok) + return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// countTokensWithAPIKey handles token counting using API key credentials. +func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// Refresh is a no-op for service account based credentials. +func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + return auth, nil +} + +// executeWithServiceAccount handles authentication using service account credentials. +// This method contains the original service account authentication logic. +func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -149,13 +392,105 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A return resp, nil } -// ExecuteStream handles SSE streaming for Vertex. -func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return nil, errCreds +// executeWithAPIKey handles authentication using API key credentials. +// This method follows the vertex-compat pattern for API key authentication. +func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) + } + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + + action := "generateContent" + if req.Metadata != nil { + if a, _ := req.Metadata["action"].(string); a == "countTokens" { + action = "countTokens" + } + } + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, action) + if opts.Alt != "" && action != "countTokens" { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") + + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if errNewReq != nil { + return resp, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return resp, errDo } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return resp, err + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return resp, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + reporter.publish(ctx, parseGeminiUsage(data)) + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil +} +// executeStreamWithServiceAccount handles streaming authentication using service account credentials. +func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -266,42 +601,44 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return stream, nil } -// CountTokens calls Vertex countTokens endpoint. -func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return cliproxyexecutor.Response{}, errCreds - } +// executeStreamWithAPIKey handles streaming authentication using API key credentials. +func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride != nil { norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) budgetOverride = &norm } - translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) - baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "streamGenerateContent") + if opts.Alt == "" { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") - httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if errNewReq != nil { - return cliproxyexecutor.Response{}, errNewReq + return nil, errNewReq } httpReq.Header.Set("Content-Type", "application/json") - if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { - httpReq.Header.Set("Authorization", "Bearer "+token) - } else if errTok != nil { - log.Errorf("vertex executor: access token error: %v", errTok) - return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) } applyGeminiHeaders(httpReq, auth) @@ -315,7 +652,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau URL: url, Method: http.MethodPost, Headers: httpReq.Header.Clone(), - Body: translatedReq, + Body: body, Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -327,38 +664,53 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { recordAPIResponseError(ctx, e.cfg, errDo) - return cliproxyexecutor.Response{}, errDo + return nil, errDo } - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("vertex executor: close response body error: %v", errClose) - } - }() recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} - } - data, errRead := io.ReadAll(httpResp.Body) - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - return cliproxyexecutor.Response{}, errRead - } - appendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} } - count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil -} -// Refresh is a no-op for service account based credentials. -func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - return auth, nil + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 20_971_520) + var param any + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + if detail, ok := parseGeminiStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } + }() + return stream, nil } // vertexCreds extracts project, location and raw service account JSON from auth metadata. @@ -401,6 +753,23 @@ func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccou return projectID, location, saJSON, nil } +// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern. +func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { + if a == nil { + return "", "" + } + if a.Attributes != nil { + apiKey = a.Attributes["api_key"] + baseURL = a.Attributes["base_url"] + } + if apiKey == "" && a.Metadata != nil { + if v, ok := a.Metadata["access_token"].(string); ok { + apiKey = v + } + } + return +} + func vertexBaseURL(location string) string { loc := strings.TrimSpace(location) if loc == "" { diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index af72212d0..2798b724f 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -498,6 +498,18 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str return hex.EncodeToString(sum[:]) } +func computeVertexCompatModelsHash(models []config.VertexCompatModel) string { + if len(models) == 0 { + return "" + } + data, err := json.Marshal(models) + if err != nil || len(data) == 0 { + return "" + } + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + // computeClaudeModelsHash returns a stable hash for Claude model aliases. func computeClaudeModelsHash(models []config.ClaudeModel) string { if len(models) == 0 { @@ -920,8 +932,8 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string // no legacy clients to unregister // Create new API key clients based on the new config - geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) - totalAPIKeyClients := geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) + totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount log.Debugf("loaded %d API key clients", totalAPIKeyClients) var authFileCount int @@ -964,7 +976,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.clientsMutex.Unlock() } - totalNewClients := authFileCount + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount // Ensure consumers observe the new configuration before auth updates dispatch. if w.reloadCallback != nil { @@ -974,10 +986,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.refreshAuthState() - log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", + log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex-compat keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", totalNewClients, authFileCount, geminiAPIKeyCount, + vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount, @@ -1092,6 +1105,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey") out = append(out, a) } + // Claude API keys -> synthesize auths for i := range cfg.ClaudeKey { ck := cfg.ClaudeKey[i] @@ -1258,6 +1272,42 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } } } + + // Process Vertex compatibility providers + for i := range cfg.VertexCompatAPIKey { + compat := &cfg.VertexCompatAPIKey[i] + providerName := "vertex-compat" + base := strings.TrimSpace(compat.BaseURL) + + key := strings.TrimSpace(compat.APIKey) + proxyURL := strings.TrimSpace(compat.ProxyURL) + idKind := fmt.Sprintf("vertex-compatibility:%s", base) + id, token := idGen.next(idKind, key, base, proxyURL) + attrs := map[string]string{ + "source": fmt.Sprintf("config:vertex-compatibility[%s]", token), + "base_url": base, + "provider_key": providerName, + } + if key != "" { + attrs["api_key"] = key + } + if hash := computeVertexCompatModelsHash(compat.Models); hash != "" { + attrs["models_hash"] = hash + } + addConfigHeadersToAttrs(compat.Headers, attrs) + a := &coreauth.Auth{ + ID: id, + Provider: providerName, + Label: "Vertex Compatibility", + Status: coreauth.StatusActive, + ProxyURL: proxyURL, + Attributes: attrs, + CreatedAt: now, + UpdatedAt: now, + } + out = append(out, a) + } + // Also synthesize auth entries directly from auth files (for OAuth/file-backed providers) entries, _ := os.ReadDir(w.authDir) for _, e := range entries { @@ -1474,8 +1524,9 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int { return authFileCount } -func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { +func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) { geminiAPIKeyCount := 0 + vertexCompatAPIKeyCount := 0 claudeAPIKeyCount := 0 codexAPIKeyCount := 0 openAICompatCount := 0 @@ -1484,6 +1535,9 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { // Stateless executor handles Gemini API keys; avoid constructing legacy clients. geminiAPIKeyCount += len(cfg.GeminiKey) } + if len(cfg.VertexCompatAPIKey) > 0 { + vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey) + } if len(cfg.ClaudeKey) > 0 { claudeAPIKeyCount += len(cfg.ClaudeKey) } @@ -1501,7 +1555,7 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { } } } - return geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount + return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount } func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string { diff --git a/sdk/cliproxy/providers.go b/sdk/cliproxy/providers.go index a5810336a..401885f5c 100644 --- a/sdk/cliproxy/providers.go +++ b/sdk/cliproxy/providers.go @@ -29,7 +29,7 @@ func NewAPIKeyClientProvider() APIKeyClientProvider { type apiKeyClientProvider struct{} func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) { - geminiCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) + geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) if ctx != nil { select { case <-ctx.Done(): @@ -38,9 +38,10 @@ func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*A } } return &APIKeyClientResult{ - GeminiKeyCount: geminiCount, - ClaudeKeyCount: claudeCount, - CodexKeyCount: codexCount, - OpenAICompatCount: openAICompat, + GeminiKeyCount: geminiCount, + VertexCompatKeyCount: vertexCompatCount, + ClaudeKeyCount: claudeCount, + CodexKeyCount: codexCount, + OpenAICompatCount: openAICompat, }, nil } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index c2ebba8d2..f0b6bf537 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -324,7 +324,7 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName if len(a.Attributes) > 0 { providerKey = strings.TrimSpace(a.Attributes["provider_key"]) compatName = strings.TrimSpace(a.Attributes["compat_name"]) - if providerKey != "" || compatName != "" { + if compatName != "" { if providerKey == "" { providerKey = compatName } @@ -362,6 +362,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg)) case "vertex": s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg)) + case "vertex-compat": + s.coreManager.RegisterExecutor(executor.NewGeminiVertexCompatExecutor(s.cfg)) case "gemini-cli": s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg)) case "aistudio": @@ -498,7 +500,7 @@ func (s *Service) Run(ctx context.Context) error { }() time.Sleep(100 * time.Millisecond) - fmt.Println("API server started successfully") + fmt.Printf("API server started successfully on: %d\n", s.cfg.Port) if s.hooks.OnAfterStart != nil { s.hooks.OnAfterStart(s) @@ -680,6 +682,35 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() models = applyExcludedModels(models, excluded) + case "vertex-compat": + // Handle Vertex AI compatibility providers with custom model definitions + if s.cfg != nil && len(s.cfg.VertexCompatAPIKey) > 0 { + // Create models for all Vertex compatibility providers + allModels := make([]*ModelInfo, 0) + for i := range s.cfg.VertexCompatAPIKey { + compat := &s.cfg.VertexCompatAPIKey[i] + for j := range compat.Models { + m := compat.Models[j] + // Use alias as model ID, fallback to name if alias is empty + modelID := m.Alias + if modelID == "" { + modelID = m.Name + } + if modelID != "" { + allModels = append(allModels, &ModelInfo{ + ID: modelID, + Object: "model", + Created: time.Now().Unix(), + OwnedBy: "vertex-compat", + Type: "vertex-compat", + DisplayName: m.Name, + }) + } + } + } + models = allModels + } + case "gemini-cli": models = registry.GetGeminiCLIModels() models = applyExcludedModels(models, excluded) diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index b44185d17..42c7c4881 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -49,19 +49,21 @@ type APIKeyClientProvider interface { Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) } -// APIKeyClientResult contains API key based clients along with type counts. -// It provides metadata about the number of clients loaded for each provider type. +// APIKeyClientResult is returned by APIKeyClientProvider.Load() type APIKeyClientResult struct { - // GeminiKeyCount is the number of Gemini API key clients loaded. + // GeminiKeyCount is the number of Gemini API keys loaded GeminiKeyCount int - // ClaudeKeyCount is the number of Claude API key clients loaded. + // VertexCompatKeyCount is the number of Vertex-compatible API keys loaded + VertexCompatKeyCount int + + // ClaudeKeyCount is the number of Claude API keys loaded ClaudeKeyCount int - // CodexKeyCount is the number of Codex API key clients loaded. + // CodexKeyCount is the number of Codex API keys loaded CodexKeyCount int - // OpenAICompatCount is the number of OpenAI-compatible API key clients loaded. + // OpenAICompatCount is the number of OpenAI compatibility API keys loaded OpenAICompatCount int } From 0fd2abbc3bb026ba4fb124d71051a13cd4ccb0d0 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 2 Dec 2025 09:12:53 +0800 Subject: [PATCH 5/6] **refactor(cliproxy, config): remove vertex-compat flow, streamline Vertex API key handling** - Removed `vertex-compat` executor and related configuration. - Consolidated Vertex compatibility checks into `vertex` handling with `apikey`-based model resolution. - Streamlined model generation logic for Vertex API key entries. --- config.example.yaml | 13 +++ internal/config/vertex_compat.go | 2 +- .../executor/gemini_vertex_executor.go | 17 --- internal/watcher/watcher.go | 13 ++- sdk/cliproxy/service.go | 106 +++++++++++++----- 5 files changed, 97 insertions(+), 54 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 0685a83af..3086ca783 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -145,6 +145,19 @@ ws-auth: false # - name: "moonshotai/kimi-k2:free" # The actual model name. # alias: "kimi-k2" # The alias used in the API. +# Vertex API keys (Vertex-compatible endpoints, use API key + base URL) +#vertex-api-key: +# - api-key: "vk-123..." # x-goog-api-key header +# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api +# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override +# headers: +# X-Custom-Header: "custom-value" +# models: # optional: map aliases to upstream model names +# - name: "gemini-2.0-flash" # upstream model name +# alias: "vertex-flash" # client-visible alias +# - name: "gemini-1.5-pro" +# alias: "vertex-pro" + #payload: # Optional payload configuration # default: # Default rules only set parameters when they are missing in the payload. # - models: diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go index a8d94ccb8..1257dd62b 100644 --- a/internal/config/vertex_compat.go +++ b/internal/config/vertex_compat.go @@ -55,7 +55,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() { } entry.BaseURL = strings.TrimSpace(entry.BaseURL) if entry.BaseURL == "" { - // BaseURL is required for vertex-compat keys + // BaseURL is required for Vertex API key entries continue } entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index eeb7356e9..de4ba072c 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -44,22 +44,6 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor { // Identifier returns provider key for manager routing. func (e *GeminiVertexExecutor) Identifier() string { return "vertex" } -// GeminiVertexCompatExecutor is a thin wrapper around GeminiVertexExecutor -// that provides the correct identifier for vertex-compat routing. -type GeminiVertexCompatExecutor struct { - *GeminiVertexExecutor -} - -// NewGeminiVertexCompatExecutor constructs the Vertex-compatible executor. -func NewGeminiVertexCompatExecutor(cfg *config.Config) *GeminiVertexCompatExecutor { - return &GeminiVertexCompatExecutor{ - GeminiVertexExecutor: NewGeminiVertexExecutor(cfg), - } -} - -// Identifier returns provider key for manager routing. -func (e *GeminiVertexCompatExecutor) Identifier() string { return "vertex-compat" } - // PrepareRequest is a no-op for Vertex. func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil @@ -393,7 +377,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } // executeWithAPIKey handles authentication using API key credentials. -// This method follows the vertex-compat pattern for API key authentication. func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 2798b724f..02abb7433 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -986,7 +986,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.refreshAuthState() - log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex-compat keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", + log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", totalNewClients, authFileCount, geminiAPIKeyCount, @@ -1273,18 +1273,18 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } } - // Process Vertex compatibility providers + // Process Vertex API key providers (Vertex-compatible endpoints) for i := range cfg.VertexCompatAPIKey { compat := &cfg.VertexCompatAPIKey[i] - providerName := "vertex-compat" + providerName := "vertex" base := strings.TrimSpace(compat.BaseURL) key := strings.TrimSpace(compat.APIKey) proxyURL := strings.TrimSpace(compat.ProxyURL) - idKind := fmt.Sprintf("vertex-compatibility:%s", base) + idKind := fmt.Sprintf("vertex:apikey:%s", base) id, token := idGen.next(idKind, key, base, proxyURL) attrs := map[string]string{ - "source": fmt.Sprintf("config:vertex-compatibility[%s]", token), + "source": fmt.Sprintf("config:vertex-apikey[%s]", token), "base_url": base, "provider_key": providerName, } @@ -1298,13 +1298,14 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { a := &coreauth.Auth{ ID: id, Provider: providerName, - Label: "Vertex Compatibility", + Label: "vertex-apikey", Status: coreauth.StatusActive, ProxyURL: proxyURL, Attributes: attrs, CreatedAt: now, UpdatedAt: now, } + applyAuthExcludedModelsMeta(a, cfg, nil, "apikey") out = append(out, a) } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index f0b6bf537..8b9a66398 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -362,8 +362,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg)) case "vertex": s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg)) - case "vertex-compat": - s.coreManager.RegisterExecutor(executor.NewGeminiVertexCompatExecutor(s.cfg)) case "gemini-cli": s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg)) case "aistudio": @@ -681,36 +679,12 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "vertex": // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() - models = applyExcludedModels(models, excluded) - case "vertex-compat": - // Handle Vertex AI compatibility providers with custom model definitions - if s.cfg != nil && len(s.cfg.VertexCompatAPIKey) > 0 { - // Create models for all Vertex compatibility providers - allModels := make([]*ModelInfo, 0) - for i := range s.cfg.VertexCompatAPIKey { - compat := &s.cfg.VertexCompatAPIKey[i] - for j := range compat.Models { - m := compat.Models[j] - // Use alias as model ID, fallback to name if alias is empty - modelID := m.Alias - if modelID == "" { - modelID = m.Name - } - if modelID != "" { - allModels = append(allModels, &ModelInfo{ - ID: modelID, - Object: "model", - Created: time.Now().Unix(), - OwnedBy: "vertex-compat", - Type: "vertex-compat", - DisplayName: m.Name, - }) - } - } + if authKind == "apikey" { + if entry := s.resolveConfigVertexCompatKey(a); entry != nil && len(entry.Models) > 0 { + models = buildVertexCompatConfigModels(entry) } - models = allModels } - + models = applyExcludedModels(models, excluded) case "gemini-cli": models = registry.GetGeminiCLIModels() models = applyExcludedModels(models, excluded) @@ -905,6 +879,40 @@ func (s *Service) resolveConfigGeminiKey(auth *coreauth.Auth) *config.GeminiKey return nil } +func (s *Service) resolveConfigVertexCompatKey(auth *coreauth.Auth) *config.VertexCompatKey { + if auth == nil || s.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range s.cfg.VertexCompatAPIKey { + entry := &s.cfg.VertexCompatAPIKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range s.cfg.VertexCompatAPIKey { + entry := &s.cfg.VertexCompatAPIKey[i] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { + return entry + } + } + } + return nil +} + func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey { if auth == nil || s.cfg == nil { return nil @@ -1023,6 +1031,44 @@ func matchWildcard(pattern, value string) bool { return true } +func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { + if entry == nil || len(entry.Models) == 0 { + return nil + } + now := time.Now().Unix() + out := make([]*ModelInfo, 0, len(entry.Models)) + seen := make(map[string]struct{}, len(entry.Models)) + for i := range entry.Models { + model := entry.Models[i] + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if alias == "" { + alias = name + } + if alias == "" { + continue + } + key := strings.ToLower(alias) + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + display := name + if display == "" { + display = alias + } + out = append(out, &ModelInfo{ + ID: alias, + Object: "model", + Created: now, + OwnedBy: "vertex", + Type: "vertex", + DisplayName: display, + }) + } + return out +} + func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { if entry == nil || len(entry.Models) == 0 { return nil From 1434bc38e58b5b33296b3b8ab231c499098890af Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 2 Dec 2025 11:34:38 +0800 Subject: [PATCH 6/6] **refactor(registry): remove Qwen3-Coder from model definitions** --- internal/registry/model_definitions.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index fd4bd4284..21e5eb604 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -954,7 +954,6 @@ func GetIFlowModels() []*ModelInfo { }{ {ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600}, {ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800}, - {ID: "qwen3-coder", DisplayName: "Qwen3-Coder-480B-A35B", Description: "Qwen3 Coder 480B A35B", Created: 1753228800}, {ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000}, {ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000}, {ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},