diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 000000000..309ff5cb0 --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +golang 1.25.5 diff --git a/README.md b/README.md index fc75e57fc..cc7fe7b3e 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB - iFlow multi-account load balancing - OpenAI Codex multi-account load balancing - OpenAI-compatible upstream providers via config (e.g., OpenRouter) +- **Global model aliases** to map any model name to another (e.g., `claude-haiku` → `claude-3-5-haiku-20241022`) - Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`) ## Getting Started diff --git a/README_CN.md b/README_CN.md index dbb9609d6..e7b02320b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -54,6 +54,7 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元 - 支持 iFlow 多账户轮询 - 支持 OpenAI Codex 多账户轮询 - 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter) +- **全局模型别名**,可将任何模型名称映射到另一个名称(例如 `claude-haiku` → `claude-3-5-haiku-20241022`) - 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`) ## 新手入门 diff --git a/config.example.yaml b/config.example.yaml index f6390d2ff..735252674 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -199,16 +199,37 @@ ws-auth: false # - "tstars2.0" # Optional payload configuration -# payload: -# default: # Default rules only set parameters when they are missing in the payload. -# - models: -# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*") -# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex -# params: # JSON path (gjson/sjson syntax) -> value -# "generationConfig.thinkingConfig.thinkingBudget": 32768 -# override: # Override rules always set parameters, overwriting any existing values. -# - models: -# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*") -# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex -# params: # JSON path (gjson/sjson syntax) -> value -# "reasoning.effort": "high" +payload: + default: # Default rules only set parameters when they are missing in the payload. + - models: + - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*") + protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex + params: # JSON path (gjson/sjson syntax) -> value + "generationConfig.thinkingConfig.thinkingBudget": 32768 + override: # Override rules always set parameters, overwriting any existing values. + - models: + - name: "gpt-*" # Supports wildcards (e.g., "gpt-*") + protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex + params: # JSON path (gjson/sjson syntax) -> value + "reasoning.effort": "high" + +# Global model aliases +# Define custom aliases that can be used across all providers to map to specific model IDs. +# This eliminates the need for per-provider alias configuration. +# Supports both single-target strings and multi-target lists for provider aggregation. +aliases: + # Single target alias + "claude-haiku-4-5": "claude-haiku-4-5-20251001" + + # Multi-target alias (aggregates providers from all listed models) + "claude-sonnet-4-5": + - "claude-sonnet-4-5-20250929" + - "gemini-claude-sonnet-4-5" + + "claude-opus-4-5": + - "claude-opus-4-5-20251101" + - "gemini-claude-opus-4-5-thinking" + + # Custom aliases - any name you prefer + "best-model": "claude-sonnet-4-5-20250929" + "fast-model": "claude-haiku-4-5-20251001" diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index ae2929822..315a72199 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -160,6 +160,12 @@ func (h *Handler) PutConfigYAML(c *gin.Context) { return } h.cfg = newCfg + + // Update global aliases from the new config + if newCfg.Aliases != nil { + util.SetAliases(newCfg.Aliases) + } + c.JSON(http.StatusOK, gin.H{"ok": true, "changed": []string{"config"}}) } @@ -241,3 +247,81 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.cfg.ProxyURL = "" h.persist(c) } + +// Aliases +func (h *Handler) GetAliases(c *gin.Context) { + aliases := h.cfg.Aliases + if aliases == nil { + aliases = make(map[string][]string) + } + c.JSON(200, gin.H{"aliases": aliases}) +} + +func (h *Handler) PutAliases(c *gin.Context) { + var body struct { + Aliases *map[string][]string `json:"aliases"` + } + if err := c.ShouldBindJSON(&body); err != nil || body.Aliases == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + h.cfg.Aliases = *body.Aliases + h.persist(c) + // Update global aliases after persisting + if h.cfg.Aliases != nil { + util.SetAliases(h.cfg.Aliases) + } +} + +func (h *Handler) PutAlias(c *gin.Context) { + alias := c.Param("alias") + if alias == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing alias"}) + return + } + var body struct { + Target *string `json:"target"` + Targets *[]string `json:"targets"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + + var targets []string + if body.Targets != nil { + targets = *body.Targets + } else if body.Target != nil { + targets = []string{*body.Target} + } else { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing target or targets"}) + return + } + + if h.cfg.Aliases == nil { + h.cfg.Aliases = make(map[string][]string) + } + h.cfg.Aliases[alias] = targets + if !h.persist(c) { + return + } + // Update global aliases after persisting + util.SetAliases(h.cfg.Aliases) + c.JSON(200, gin.H{"alias": alias, "targets": targets}) +} + +func (h *Handler) DeleteAlias(c *gin.Context) { + alias := c.Param("alias") + if alias == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing alias"}) + return + } + if h.cfg.Aliases != nil { + delete(h.cfg.Aliases, alias) + } + util.SetAliases(h.cfg.Aliases) + if !h.persist(c) { + return + } + c.JSON(200, gin.H{"deleted": alias}) +} diff --git a/internal/api/server.go b/internal/api/server.go index 80c30ebc4..429cdd633 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -496,6 +496,11 @@ func (s *Server) registerManagementRoutes() { mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL) mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL) + mgmt.GET("/aliases", s.mgmt.GetAliases) + mgmt.PUT("/aliases", s.mgmt.PutAliases) + mgmt.PUT("/aliases/:alias", s.mgmt.PutAlias) + mgmt.DELETE("/aliases/:alias", s.mgmt.DeleteAlias) + mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject) mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) @@ -821,6 +826,11 @@ func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) { // - clients: The new slice of AI service clients // - cfg: The new application configuration func (s *Server) UpdateClients(cfg *config.Config) { + // Set global aliases from config + if cfg.Aliases != nil { + util.SetAliases(cfg.Aliases) + } + // Reconstruct old config from YAML snapshot to avoid reference sharing issues var oldCfg *config.Config if len(s.oldConfigYaml) > 0 { diff --git a/internal/config/config.go b/internal/config/config.go index bc6ae9d8d..9c4764c7d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -91,9 +91,74 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + // Aliases defines global model aliases that can be used across all providers. + // Map key is the alias name, value is the target model ID. + Aliases Aliases `yaml:"aliases" json:"aliases"` + legacyMigrationPending bool `yaml:"-" json:"-"` } +// Aliases defines global model aliases that can be used across all providers. +// Map key is the alias name, value is a list of target model IDs. +type Aliases map[string][]string + +// UnmarshalYAML implements custom YAML unmarshalling to accept either a mapping +// or a sequence of single-key mapping entries. It also supports both single string +// and sequence of strings for alias targets. +func (a *Aliases) UnmarshalYAML(node *yaml.Node) error { + if node == nil { + return nil + } + + m := make(map[string][]string) + + processMapping := func(n *yaml.Node) error { + if n.Kind != yaml.MappingNode { + return fmt.Errorf("expected mapping node, got kind=%d", n.Kind) + } + for i := 0; i+1 < len(n.Content); i += 2 { + keyNode := n.Content[i] + valNode := n.Content[i+1] + key := keyNode.Value + + switch valNode.Kind { + case yaml.ScalarNode: + m[key] = append(m[key], valNode.Value) + case yaml.SequenceNode: + var targets []string + if err := valNode.Decode(&targets); err != nil { + return err + } + m[key] = append(m[key], targets...) + default: + return fmt.Errorf("alias target for %s must be a string or sequence, got kind=%d", key, valNode.Kind) + } + } + return nil + } + + switch node.Kind { + case yaml.MappingNode: + if err := processMapping(node); err != nil { + return err + } + case yaml.SequenceNode: + for _, elm := range node.Content { + if elm == nil { + continue + } + if err := processMapping(elm); err != nil { + return err + } + } + default: + return fmt.Errorf("aliases must be a mapping or sequence, got kind=%d", node.Kind) + } + + *a = m + return nil +} + // TLSConfig holds HTTPS server settings. type TLSConfig struct { // Enable toggles HTTPS server mode. diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index d4f84481c..d7e5377a5 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -95,6 +95,8 @@ type ModelRegistry struct { clientModelInfos map[string]map[string]*ModelInfo // clientProviders maps client ID to its provider identifier clientProviders map[string]string + // aliases maps model ID to its target model IDs + aliases map[string][]string // mutex ensures thread-safe access to the registry mutex *sync.RWMutex } @@ -111,12 +113,35 @@ func GetGlobalRegistry() *ModelRegistry { clientModels: make(map[string][]string), clientModelInfos: make(map[string]map[string]*ModelInfo), clientProviders: make(map[string]string), + aliases: make(map[string][]string), mutex: &sync.RWMutex{}, } }) return globalRegistry } +// SetAliases sets the global model aliases in the registry. +func (r *ModelRegistry) SetAliases(aliases map[string][]string) { + r.mutex.Lock() + defer r.mutex.Unlock() + if aliases == nil { + r.aliases = make(map[string][]string) + } else { + r.aliases = aliases + } +} + +// ResolveAlias resolves a model ID through the registry's alias map. +// It returns all target model IDs for the given alias. +func (r *ModelRegistry) ResolveAlias(modelID string) []string { + r.mutex.RLock() + defer r.mutex.RUnlock() + if targets, ok := r.aliases[modelID]; ok { + return targets + } + return nil +} + // RegisterClient registers a client and its supported models // Parameters: // - clientID: Unique identifier for the client @@ -545,6 +570,47 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) { log.Debugf("Resumed client %s for model %s", clientID, modelID) } +// ResolveModelForClient returns the model ID that the client actually supports for a given requested model ID. +// It prioritizes direct matches over alias matches. +func (r *ModelRegistry) ResolveModelForClient(clientID, modelID string) string { + clientID = strings.TrimSpace(clientID) + modelID = strings.TrimSpace(modelID) + if clientID == "" || modelID == "" { + return modelID + } + + r.mutex.RLock() + defer r.mutex.RUnlock() + + models, exists := r.clientModels[clientID] + if !exists || len(models) == 0 { + return modelID + } + + // 1. Check for direct match + for _, id := range models { + if strings.EqualFold(strings.TrimSpace(id), modelID) { + return modelID + } + } + + // 2. Check for alias matches + if targets, ok := r.aliases[modelID]; ok { + for _, target := range targets { + if target == modelID { + continue + } + for _, id := range models { + if strings.EqualFold(strings.TrimSpace(id), target) { + return target + } + } + } + } + + return modelID +} + // ClientSupportsModel reports whether the client registered support for modelID. func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool { clientID = strings.TrimSpace(clientID) @@ -561,32 +627,74 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool { return false } + targets := r.aliases[modelID] + for _, id := range models { - if strings.EqualFold(strings.TrimSpace(id), modelID) { + trimmedID := strings.TrimSpace(id) + if strings.EqualFold(trimmedID, modelID) { return true } + for _, target := range targets { + if strings.EqualFold(trimmedID, target) { + return true + } + } } return false } -// GetAvailableModels returns all models that have at least one available client -// Parameters: -// - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini") -// -// Returns: -// - []map[string]any: List of available models in the requested format -func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any { - r.mutex.RLock() - defer r.mutex.RUnlock() +// getModelCountLocked returns the number of available clients for a specific model. +// It assumes the mutex is already held for reading. +func (r *ModelRegistry) getModelCountLocked(modelID string) int { + countFor := func(id string) int { + registration, exists := r.models[id] + if !exists || registration == nil { + return 0 + } + now := time.Now() + quotaExpiredDuration := 5 * time.Minute - models := make([]map[string]any, 0) - quotaExpiredDuration := 5 * time.Minute + // Count clients that have exceeded quota but haven't recovered yet + expiredClients := 0 + for _, quotaTime := range registration.QuotaExceededClients { + if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { + expiredClients++ + } + } + suspendedClients := 0 + if registration.SuspendedClients != nil { + suspendedClients = len(registration.SuspendedClients) + } + result := registration.Count - expiredClients - suspendedClients + if result < 0 { + return 0 + } + return result + } + + total := countFor(modelID) + if targets, ok := r.aliases[modelID]; ok { + for _, target := range targets { + if target != modelID { + total += countFor(target) + } + } + } + return total +} - for _, registration := range r.models { - // Check if model has any non-quota-exceeded clients +// isModelAvailableLocked checks if a model has any available clients (including through aliases). +// It assumes the mutex is already held for reading. +func (r *ModelRegistry) isModelAvailableLocked(modelID string) bool { + check := func(id string) bool { + registration, exists := r.models[id] + if !exists || registration == nil { + return false + } availableClients := registration.Count now := time.Now() + quotaExpiredDuration := 5 * time.Minute // Count clients that have exceeded quota but haven't recovered yet expiredClients := 0 @@ -614,10 +722,74 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any } // Include models that have available clients, or those solely cooling down. - if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) { + return effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) + } + + if check(modelID) { + return true + } + if targets, ok := r.aliases[modelID]; ok { + for _, target := range targets { + if target != modelID && check(target) { + return true + } + } + } + return false +} + +// GetAvailableModels returns all models that have at least one available client +// Parameters: +// - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini") +// +// Returns: +// - []map[string]any: List of available models in the requested format +func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any { + r.mutex.RLock() + defer r.mutex.RUnlock() + + models := make([]map[string]any, 0) + processedIDs := make(map[string]struct{}) + + // 1. Process all models in r.models + for modelID, registration := range r.models { + if r.isModelAvailableLocked(modelID) { model := r.convertModelToMap(registration.Info, handlerType) if model != nil { models = append(models, model) + processedIDs[modelID] = struct{}{} + } + } + } + + // 2. Process all aliases that aren't in r.models + for aliasID, targetIDs := range r.aliases { + if _, processed := processedIDs[aliasID]; processed { + continue + } + + if r.isModelAvailableLocked(aliasID) { + // We need ModelInfo for the alias. + // If it's not in r.models, we use the first available target's ModelInfo but change the ID. + var info *ModelInfo + if reg, ok := r.models[aliasID]; ok { + info = reg.Info + } else { + for _, targetID := range targetIDs { + if targetReg, ok := r.models[targetID]; ok { + info = cloneModelInfo(targetReg.Info) + info.ID = aliasID + break + } + } + } + + if info != nil { + model := r.convertModelToMap(info, handlerType) + if model != nil { + models = append(models, model) + processedIDs[aliasID] = struct{}{} + } } } } @@ -634,29 +806,7 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any func (r *ModelRegistry) GetModelCount(modelID string) int { r.mutex.RLock() defer r.mutex.RUnlock() - - if registration, exists := r.models[modelID]; exists { - now := time.Now() - quotaExpiredDuration := 5 * time.Minute - - // Count clients that have exceeded quota but haven't recovered yet - expiredClients := 0 - for _, quotaTime := range registration.QuotaExceededClients { - if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { - expiredClients++ - } - } - suspendedClients := 0 - if registration.SuspendedClients != nil { - suspendedClients = len(registration.SuspendedClients) - } - result := registration.Count - expiredClients - suspendedClients - if result < 0 { - return 0 - } - return result - } - return 0 + return r.getModelCountLocked(modelID) } // GetModelProviders returns provider identifiers that currently supply the given model @@ -669,8 +819,34 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string { r.mutex.RLock() defer r.mutex.RUnlock() - registration, exists := r.models[modelID] - if !exists || registration == nil || len(registration.Providers) == 0 { + providersMap := make(map[string]int) + + // Helper to add providers from a registration + addProviders := func(id string) { + registration, exists := r.models[id] + if !exists || registration == nil { + return + } + for name, count := range registration.Providers { + if count > 0 { + providersMap[name] += count + } + } + } + + // Add providers for the direct model ID + addProviders(modelID) + + // Add providers for the aliased model IDs + if targets, ok := r.aliases[modelID]; ok { + for _, target := range targets { + if target != modelID { + addProviders(target) + } + } + } + + if len(providersMap) == 0 { return nil } @@ -678,29 +854,10 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string { name string count int } - providers := make([]providerCount, 0, len(registration.Providers)) - // suspendedByProvider := make(map[string]int) - // if registration.SuspendedClients != nil { - // for clientID := range registration.SuspendedClients { - // if provider, ok := r.clientProviders[clientID]; ok && provider != "" { - // suspendedByProvider[provider]++ - // } - // } - // } - for name, count := range registration.Providers { - if count <= 0 { - continue - } - // adjusted := count - suspendedByProvider[name] - // if adjusted <= 0 { - // continue - // } - // providers = append(providers, providerCount{name: name, count: adjusted}) + providers := make([]providerCount, 0, len(providersMap)) + for name, count := range providersMap { providers = append(providers, providerCount{name: name, count: count}) } - if len(providers) == 0 { - return nil - } sort.Slice(providers, func(i, j int) bool { if providers[i].count == providers[j].count { @@ -717,13 +874,32 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string { } // GetModelInfo returns the registered ModelInfo for the given model ID, if present. +// It resolves aliases if the direct model ID is not found in the registry. // Returns nil if the model is unknown to the registry. func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo { r.mutex.RLock() defer r.mutex.RUnlock() + + // 1. Try direct match if reg, ok := r.models[modelID]; ok && reg != nil { return reg.Info } + + // 2. Try alias resolution + if targets, ok := r.aliases[modelID]; ok { + for _, target := range targets { + if target == modelID { + continue + } + if reg, ok := r.models[target]; ok && reg != nil { + // Return a clone with the alias ID to match user expectation + info := cloneModelInfo(reg.Info) + info.ID = modelID + return info + } + } + } + return nil } diff --git a/internal/registry/model_registry_test.go b/internal/registry/model_registry_test.go new file mode 100644 index 000000000..e218d6c10 --- /dev/null +++ b/internal/registry/model_registry_test.go @@ -0,0 +1,219 @@ +package registry + +import ( + "sync" + "testing" +) + +func TestModelRegistry_Aliases(t *testing.T) { + r := &ModelRegistry{ + models: make(map[string]*ModelRegistration), + clientModels: make(map[string][]string), + clientModelInfos: make(map[string]map[string]*ModelInfo), + clientProviders: make(map[string]string), + aliases: make(map[string][]string), + mutex: &sync.RWMutex{}, + } + + // 1. Setup: Register a client with a real model + realModelID := "gpt-4-real" + r.RegisterClient("client1", "openai", []*ModelInfo{ + {ID: realModelID, OwnedBy: "openai"}, + }) + + // 2. Setup: Register another client with a model that will be an alias target + targetModelID := "claude-3-target" + r.RegisterClient("client2", "anthropic", []*ModelInfo{ + {ID: targetModelID, OwnedBy: "anthropic"}, + }) + + // 3. Setup: Define aliases + aliases := map[string][]string{ + "gpt-4-alias": {realModelID}, // Alias to a model that exists + "claude-alias": {targetModelID}, // Alias to another model that exists + "missing-alias": {"non-existent"}, + "multi-alias": {realModelID, targetModelID}, // Alias to multiple models + } + r.SetAliases(aliases) + + // Test GetModelProviders (Union behavior) + t.Run("GetModelProviders_Union", func(t *testing.T) { + // Direct model + providers := r.GetModelProviders(realModelID) + if len(providers) != 1 || providers[0] != "openai" { + t.Errorf("Expected [openai], got %v", providers) + } + + // Alias model + providers = r.GetModelProviders("gpt-4-alias") + if len(providers) != 1 || providers[0] != "openai" { + t.Errorf("Expected [openai] for alias, got %v", providers) + } + + // Multi alias + providers = r.GetModelProviders("multi-alias") + if len(providers) != 2 { + t.Errorf("Expected 2 providers for multi-alias, got %v", providers) + } + }) + + // Test GetModelCount (Union behavior) + t.Run("GetModelCount_Union", func(t *testing.T) { + count := r.GetModelCount(realModelID) + if count != 1 { + t.Errorf("Expected count 1 for real model, got %d", count) + } + + count = r.GetModelCount("gpt-4-alias") + if count != 1 { + t.Errorf("Expected count 1 for alias, got %d", count) + } + + count = r.GetModelCount("multi-alias") + if count != 2 { + t.Errorf("Expected count 2 for multi-alias, got %d", count) + } + }) + + // Test ClientSupportsModel + t.Run("ClientSupportsModel", func(t *testing.T) { + if !r.ClientSupportsModel("client1", realModelID) { + t.Error("client1 should support real model") + } + if !r.ClientSupportsModel("client1", "gpt-4-alias") { + t.Error("client1 should support alias model") + } + if !r.ClientSupportsModel("client1", "multi-alias") { + t.Error("client1 should support multi-alias") + } + if !r.ClientSupportsModel("client2", "multi-alias") { + t.Error("client2 should support multi-alias") + } + if r.ClientSupportsModel("client2", realModelID) { + t.Error("client2 should NOT support real model") + } + }) + + // Test ResolveModelForClient + t.Run("ResolveModelForClient", func(t *testing.T) { + // Direct match + resolved := r.ResolveModelForClient("client1", realModelID) + if resolved != realModelID { + t.Errorf("Expected %s, got %s", realModelID, resolved) + } + + // Alias match + resolved = r.ResolveModelForClient("client1", "gpt-4-alias") + if resolved != realModelID { + t.Errorf("Expected %s for alias, got %s", realModelID, resolved) + } + + // Multi alias match (client1 supports realModelID) + resolved = r.ResolveModelForClient("client1", "multi-alias") + if resolved != realModelID { + t.Errorf("Expected %s for multi-alias, got %s", realModelID, resolved) + } + + // Multi alias match (client2 supports targetModelID) + resolved = r.ResolveModelForClient("client2", "multi-alias") + if resolved != targetModelID { + t.Errorf("Expected %s for multi-alias, got %s", targetModelID, resolved) + } + + // No match + resolved = r.ResolveModelForClient("client2", "gpt-4-alias") + if resolved != "gpt-4-alias" { + t.Errorf("Expected gpt-4-alias (no resolution), got %s", resolved) + } + }) + + // Test GetAvailableModels + t.Run("GetAvailableModels", func(t *testing.T) { + models := r.GetAvailableModels("openai") + foundReal := false + foundAlias := false + foundMulti := false + for _, m := range models { + if m["id"] == realModelID { + foundReal = true + } + if m["id"] == "gpt-4-alias" { + foundAlias = true + } + if m["id"] == "multi-alias" { + foundMulti = true + } + } + if !foundReal { + t.Error("real model not found in available models") + } + if !foundAlias { + t.Error("alias model not found in available models") + } + if !foundMulti { + t.Error("multi-alias not found in available models") + } + }) + + // Test GetModelInfo + t.Run("GetModelInfo", func(t *testing.T) { + info := r.GetModelInfo(realModelID) + if info == nil || info.ID != realModelID { + t.Errorf("Expected info for %s", realModelID) + } + + info = r.GetModelInfo("gpt-4-alias") + if info == nil || info.ID != "gpt-4-alias" { + t.Errorf("Expected info for alias gpt-4-alias, got %v", info) + } + + info = r.GetModelInfo("multi-alias") + if info == nil || info.ID != "multi-alias" { + t.Errorf("Expected info for multi-alias, got %v", info) + } + }) + + // 4. Test Conflict: Model is both real and an alias + t.Run("Conflict_RealAndAlias", func(t *testing.T) { + conflictID := "conflict-model" + // Register conflict-model as a real model on client3 + r.RegisterClient("client3", "google", []*ModelInfo{ + {ID: conflictID, OwnedBy: "google"}, + }) + + // Set conflict-model as an alias to realModelID (gpt-4-real) + newAliases := make(map[string][]string) + for k, v := range r.aliases { + newAliases[k] = v + } + newAliases[conflictID] = []string{realModelID} + r.SetAliases(newAliases) + + // Now conflict-model should have providers from BOTH client3 (google) and client1 (openai) + providers := r.GetModelProviders(conflictID) + if len(providers) != 2 { + t.Errorf("Expected 2 providers for conflict model, got %v", providers) + } + + // Check if both are present + hasGoogle := false + hasOpenAI := false + for _, p := range providers { + if p == "google" { + hasGoogle = true + } + if p == "openai" { + hasOpenAI = true + } + } + if !hasGoogle || !hasOpenAI { + t.Errorf("Expected both google and openai providers, got %v", providers) + } + + // Count should be 2 + count := r.GetModelCount(conflictID) + if count != 2 { + t.Errorf("Expected count 2 for conflict model, got %d", count) + } + }) +} diff --git a/internal/util/provider.go b/internal/util/provider.go index 153513547..ae30f2efe 100644 --- a/internal/util/provider.go +++ b/internal/util/provider.go @@ -6,58 +6,52 @@ package util import ( "net/url" "strings" + "sync/atomic" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" log "github.com/sirupsen/logrus" ) +var ( + aliasValue atomic.Value // stores map[string][]string +) + // GetProviderName determines all AI service providers capable of serving a registered model. // It first queries the global model registry to retrieve the providers backing the supplied model name. -// When the model has not been registered yet, it falls back to legacy string heuristics to infer -// potential providers. -// -// Supported providers include (but are not limited to): -// - "gemini" for Google's Gemini family -// - "codex" for OpenAI GPT-compatible providers -// - "claude" for Anthropic models -// - "qwen" for Alibaba's Qwen models -// - "openai-compatibility" for external OpenAI-compatible providers // // Parameters: // - modelName: The name of the model to identify providers for. -// - cfg: The application configuration containing OpenAI compatibility settings. // // Returns: // - []string: All provider identifiers capable of serving the model, ordered by preference. func GetProviderName(modelName string) []string { - if modelName == "" { - return nil - } - - providers := make([]string, 0, 4) - seen := make(map[string]struct{}) - - appendProvider := func(name string) { - if name == "" { - return - } - if _, exists := seen[name]; exists { - return - } - seen[name] = struct{}{} - providers = append(providers, name) - } + providers, _ := GetProviderNameAndModel(modelName) + return providers +} - for _, provider := range registry.GetGlobalRegistry().GetModelProviders(modelName) { - appendProvider(provider) +// GetProviderNameAndModel determines all AI service providers capable of serving a model, +// and returns the resolved model name if an alias was used. +// It considers all possible sources: direct model name and global aliases. +// +// Parameters: +// - modelName: The name of the model to identify providers for. +// +// Returns: +// - []string: All provider identifiers capable of serving the model. +// - string: The resolved model name (original or alias target). +func GetProviderNameAndModel(modelName string) ([]string, string) { + if modelName == "" { + return nil, "" } - if len(providers) > 0 { - return providers - } + // GetModelProviders now handles the union of direct and aliased providers + providers := getProvidersFromRegistry(modelName) + return providers, modelName +} - return providers +func getProvidersFromRegistry(modelName string) []string { + return registry.GetGlobalRegistry().GetModelProviders(modelName) } // ResolveAutoModel resolves the "auto" model name to an actual available model. @@ -267,3 +261,37 @@ func shouldMaskQueryParam(key string) bool { } return false } + +// SetAliases sets the global alias map from the configuration. +// This should be called when the configuration is updated. +func SetAliases(aliases map[string][]string) { + if aliases == nil { + aliases = make(map[string][]string) + } + aliasValue.Store(aliases) + registry.GetGlobalRegistry().SetAliases(aliases) +} + +// ResolveAlias resolves a model name through the global alias map. +// If the model name is an alias, it returns the target model IDs. +// Otherwise, it returns nil. +// +// Parameters: +// - modelName: The model name or alias to resolve +// +// Returns: +// - []string: The resolved model IDs, or nil if not an alias +func ResolveAlias(modelName string) []string { + if modelName == "" { + return nil + } + val := aliasValue.Load() + if val == nil { + return nil + } + aliases := val.(map[string][]string) + if targets, exists := aliases[modelName]; exists { + return targets + } + return nil +} diff --git a/internal/util/provider_test.go b/internal/util/provider_test.go new file mode 100644 index 000000000..33e65c5ec --- /dev/null +++ b/internal/util/provider_test.go @@ -0,0 +1,75 @@ +package util + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func TestGetProviderNameAndModel(t *testing.T) { + reg := registry.GetGlobalRegistry() + + // Setup registry with models + modelID := "real-model" + aliasID := "alias-to-real" + + reg.RegisterClient("client1", "claude", []*registry.ModelInfo{ + {ID: modelID}, + }) + reg.RegisterClient("client2", "gemini", []*registry.ModelInfo{ + {ID: aliasID}, + }) + defer reg.UnregisterClient("client1") + defer reg.UnregisterClient("client2") + + // Setup aliases + aliases := map[string][]string{ + aliasID: {modelID}, + } + SetAliases(aliases) + defer SetAliases(nil) + + tests := []struct { + name string + input string + wantProviders []string + wantResolvedName string + }{ + { + name: "Union of direct and aliased providers", + input: aliasID, + wantProviders: []string{"claude", "gemini"}, // Order might vary but both should be there + wantResolvedName: aliasID, + }, + { + name: "Direct model only", + input: modelID, + wantProviders: []string{"claude"}, + wantResolvedName: modelID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + providers, resolvedName := GetProviderNameAndModel(tt.input) + if len(providers) != len(tt.wantProviders) { + t.Errorf("GetProviderNameAndModel() got %d providers, want %d", len(providers), len(tt.wantProviders)) + } + for _, want := range tt.wantProviders { + found := false + for _, got := range providers { + if got == want { + found = true + break + } + } + if !found { + t.Errorf("GetProviderNameAndModel() missing provider %s", want) + } + } + if resolvedName != tt.wantResolvedName { + t.Errorf("GetProviderNameAndModel() resolvedName = %v, want %v", resolvedName, tt.wantResolvedName) + } + }) + } +} diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index a544ef0c1..e21056611 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -541,8 +541,8 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string // Normalize the model name to handle dynamic thinking suffixes before determining the provider. normalizedModel, metadata = normalizeModelMetadata(resolvedModelName) - // Use the normalizedModel to get the provider name. - providers = util.GetProviderName(normalizedModel) + // Use the normalizedModel to get the provider name and resolved model name. + providers, normalizedModel = util.GetProviderNameAndModel(normalizedModel) if len(providers) == 0 && metadata != nil { if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok { if originalModel, okStr := originalRaw.(string); okStr { diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 0d648d812..c3349de44 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -576,16 +576,23 @@ func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (str if auth == nil || model == "" { return model, metadata } + + // 1. Handle prefix stripping prefix := strings.TrimSpace(auth.Prefix) - if prefix == "" { - return model, metadata - } - needle := prefix + "/" - if !strings.HasPrefix(model, needle) { - return model, metadata + rewritten := model + if prefix != "" { + needle := prefix + "/" + if strings.HasPrefix(model, needle) { + rewritten = strings.TrimPrefix(model, needle) + metadata = stripPrefixFromMetadata(metadata, needle) + } } - rewritten := strings.TrimPrefix(model, needle) - return rewritten, stripPrefixFromMetadata(metadata, needle) + + // 2. Handle alias resolution for the client + // If the client was picked because it supports an alias target, we must use that target name. + rewritten = registry.GetGlobalRegistry().ResolveModelForClient(auth.ID, rewritten) + + return rewritten, metadata } func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {