diff --git a/config.example.yaml b/config.example.yaml index bbde75b62..67d40629c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -95,6 +95,9 @@ ws-auth: false # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" +# models: +# - name: "gemini-2.5-flash" # upstream model name +# alias: "gemini-flash" # client alias mapped to the upstream model # excluded-models: # - "gemini-2.5-pro" # exclude specific models from this provider (exact match) # - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro) @@ -111,7 +114,7 @@ ws-auth: false # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # models: -# - name: "gpt-5-codex" # upstream model name +# - name: "gpt-5-codex" # upstream model name # alias: "codex-latest" # client alias mapped to the upstream model # excluded-models: # - "gpt-5.1" # exclude specific models (exact match) @@ -130,7 +133,7 @@ ws-auth: false # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # models: # - name: "claude-3-5-sonnet-20241022" # upstream model name -# alias: "claude-sonnet-latest" # client alias mapped to the upstream model +# alias: "claude-sonnet-latest" # client alias mapped to the upstream model # excluded-models: # - "claude-opus-4-5-20251101" # exclude specific models (exact match) # - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219) diff --git a/internal/config/config.go b/internal/config/config.go index 6beba5cde..7c30c4f9f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -280,6 +280,9 @@ type ClaudeModel struct { Alias string `yaml:"alias" json:"alias"` } +func (m ClaudeModel) GetName() string { return m.Name } +func (m ClaudeModel) GetAlias() string { return m.Alias } + // CodexKey represents the configuration for a Codex API key, // including the API key itself and an optional base URL for the API endpoint. type CodexKey struct { @@ -315,6 +318,9 @@ type CodexModel struct { Alias string `yaml:"alias" json:"alias"` } +func (m CodexModel) GetName() string { return m.Name } +func (m CodexModel) GetAlias() string { return m.Alias } + // GeminiKey represents the configuration for a Gemini API key, // including optional overrides for upstream base URL, proxy routing, and headers. type GeminiKey struct { @@ -330,6 +336,9 @@ type GeminiKey struct { // ProxyURL optionally overrides the global proxy for this API key. ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + // Models defines upstream model names and aliases for request routing. + Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"` + // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` @@ -337,6 +346,18 @@ type GeminiKey struct { ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } +// GeminiModel describes a mapping between an alias and the actual upstream model name. +type GeminiModel struct { + // Name is the upstream model identifier used when issuing requests. + Name string `yaml:"name" json:"name"` + + // Alias is the client-facing model name that maps to Name. + Alias string `yaml:"alias" json:"alias"` +} + +func (m GeminiModel) GetName() string { return m.Name } +func (m GeminiModel) GetAlias() string { return m.Alias } + // KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication. type KiroKey struct { // TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json) diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go index a14f75bc8..94e162b7a 100644 --- a/internal/config/vertex_compat.go +++ b/internal/config/vertex_compat.go @@ -42,6 +42,9 @@ type VertexCompatModel struct { Alias string `yaml:"alias" json:"alias"` } +func (m VertexCompatModel) GetName() string { return m.Name } +func (m VertexCompatModel) GetAlias() string { return m.Alias } + // SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials. func (cfg *Config) SanitizeVertexCompatKeys() { if cfg == nil { diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 93a539f20..f7e32bf6e 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -782,6 +782,32 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { } } +// LookupStaticModelInfo searches all static model definitions for a model by ID. +// Returns nil if no matching model is found. +func LookupStaticModelInfo(modelID string) *ModelInfo { + if modelID == "" { + return nil + } + allModels := [][]*ModelInfo{ + GetClaudeModels(), + GetGeminiModels(), + GetGeminiVertexModels(), + GetGeminiCLIModels(), + GetAIStudioModels(), + GetOpenAIModels(), + GetQwenModels(), + GetIFlowModels(), + } + for _, models := range allModels { + for _, m := range models { + if m != nil && m.ID == modelID { + return m + } + } + } + return nil +} + // GetGitHubCopilotModels returns the available models for GitHub Copilot. // These models are available through the GitHub Copilot API at api.githubcopilot.com. func GetGitHubCopilotModels() []*ModelInfo { diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 17c8170fc..38c348f28 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, if err != nil { return resp, err } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, @@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth if err != nil { return nil, err } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 9ade4fbbc..950141f04 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -76,11 +76,7 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au // Execute performs a non-streaming request to the Antigravity API. func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") if isClaude { return e.executeClaudeNonStream(ctx, auth, req, opts) } @@ -114,7 +110,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -195,11 +191,6 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) @@ -214,7 +205,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -530,16 +521,12 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") + from := opts.SourceFormat to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") - translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) @@ -554,7 +541,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return nil, err @@ -692,11 +679,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("antigravity") respCtx := context.WithValue(ctx, "alt", opts.Alt) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 2fbb235b3..f74dc1e04 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -49,36 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. stream := from != to - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } - } - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream) + body, _ = sjson.SetBytes(body, "model", model) // Inject thinking config based on model metadata for thinking variants - body = e.injectThinkingConfig(req.Model, req.Metadata, body) + body = e.injectThinkingConfig(model, req.Metadata, body) - if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { + if !strings.HasPrefix(model, "claude-3-5-haiku") { body = checkSystemInstructions(body) } - body = applyPayloadConfig(e.cfg, req.Model, body) + body = applyPayloadConfig(e.cfg, model, body) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) // Ensure max_tokens > thinking.budget_tokens when thinking is enabled - body = ensureMaxTokensForThinking(req.Model, body) + body = ensureMaxTokensForThinking(model, body) // Extract betas from body and convert to header var extraBetas []string @@ -170,29 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("claude") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + body, _ = sjson.SetBytes(body, "model", model) // Inject thinking config based on model metadata for thinking variants - body = e.injectThinkingConfig(req.Model, req.Metadata, body) + body = e.injectThinkingConfig(model, req.Metadata, body) body = checkSystemInstructions(body) - body = applyPayloadConfig(e.cfg, req.Model, body) + body = applyPayloadConfig(e.cfg, model, body) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) // Ensure max_tokens > thinking.budget_tokens when thinking is enabled - body = ensureMaxTokensForThinking(req.Model, body) + body = ensureMaxTokensForThinking(model, body) // Extract betas from body and convert to header var extraBetas []string @@ -316,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. stream := from != to - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream) + body, _ = sjson.SetBytes(body, "model", model) - if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { + if !strings.HasPrefix(model, "claude-3-5-haiku") { body = checkSystemInstructions(body) } diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index f4d837512..78882541d 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -49,28 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body = NormalizeThinkingConfig(body, model, false) + if errValidate := ValidateThinkingConfig(body, model); errValidate != nil { return resp, errValidate } - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.DeleteBytes(body, "previous_response_id") @@ -156,30 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) - body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body = NormalizeThinkingConfig(body, model, false) + if errValidate := ValidateThinkingConfig(body, model); errValidate != nil { return nil, errValidate } - body = applyPayloadConfig(e.cfg, req.Model, body) + body = applyPayloadConfig(e.cfg, model, body) body, _ = sjson.DeleteBytes(body, "previous_response_id") - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body, _ = sjson.SetBytes(body, "model", model) url := strings.TrimSuffix(baseURL, "/") + "/responses" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -266,30 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - - modelForCounting := upstreamModel + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) - body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.SetBytes(body, "stream", false) - enc, err := tokenizerForCodexModel(modelForCounting) + enc, err := tokenizerForCodexModel(model) if err != nil { return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err) } diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index b171041ad..a3b758399 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -318,7 +318,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut out := make(chan cliproxyexecutor.StreamChunk) stream = out - go func(resp *http.Response, reqBody []byte, attempt string) { + go func(resp *http.Response, reqBody []byte, attemptModel string) { defer close(out) defer func() { if errClose := resp.Body.Close(); errClose != nil { @@ -336,14 +336,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut reporter.publish(ctx, detail) } if bytes.HasPrefix(line, dataTag) { - segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } } } - segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } @@ -365,12 +365,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut appendAPIResponseChunk(ctx, e.cfg, data) reporter.publish(ctx, parseGeminiCLIUsage(data)) var param any - segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } - segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) + segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } @@ -417,6 +417,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. var lastStatus int var lastBody []byte + // The loop variable attemptModel is only used as the concrete model id sent to the upstream + // Gemini CLI endpoint when iterating fallback variants. for _, attemptModel := range models { payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) @@ -425,7 +427,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "request.safetySettings") payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) - payload = fixGeminiCLIImageAspectRatio(attemptModel, payload) + payload = fixGeminiCLIImageAspectRatio(req.Model, payload) tok, errTok := tokenSource.Token() if errTok != nil { diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index f211ba62a..d69044b89 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -77,19 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override + } // Official Gemini API via API key or OAuth bearer from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - body = ApplyThinkingMetadata(body, req.Metadata, req.Model) - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + body = ApplyThinkingMetadata(body, req.Metadata, model) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) action := "generateContent" if req.Metadata != nil { @@ -98,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } } baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action) + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -173,21 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - body = ApplyThinkingMetadata(body, req.Metadata, req.Model) - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + body = ApplyThinkingMetadata(body, req.Metadata, model) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -287,19 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { apiKey, bearer := geminiCreds(auth) + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override + } + from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model) - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model) + translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(model, translatedReq) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + translatedReq, _ = sjson.SetBytes(translatedReq, "model", model) baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens") + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens") requestBody := bytes.NewReader(translatedReq) @@ -398,6 +410,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string { return base } +func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string { + trimmed := strings.TrimSpace(alias) + if trimmed == "" { + return "" + } + + entry := e.resolveGeminiConfig(auth) + if entry == nil { + return "" + } + + normalizedModel, metadata := util.NormalizeThinkingModel(trimmed) + + // Candidate names to match against configured aliases/names. + candidates := []string{strings.TrimSpace(normalizedModel)} + if !strings.EqualFold(normalizedModel, trimmed) { + candidates = append(candidates, trimmed) + } + if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) { + candidates = append(candidates, original) + } + + for i := range entry.Models { + model := entry.Models[i] + name := strings.TrimSpace(model.Name) + modelAlias := strings.TrimSpace(model.Alias) + + for _, candidate := range candidates { + if candidate == "" { + continue + } + if modelAlias != "" && strings.EqualFold(modelAlias, candidate) { + if name != "" { + return name + } + return candidate + } + if name != "" && strings.EqualFold(name, candidate) { + return name + } + } + } + return "" +} + +func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey { + if auth == nil || e.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range e.cfg.GeminiKey { + entry := &e.cfg.GeminiKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range e.cfg.GeminiKey { + entry := &e.cfg.GeminiKey[i] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { + return entry + } + } + } + return nil +} + func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) { var attrs map[string]string if auth != nil { diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index df8ee506b..f8f4a63a5 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -120,8 +120,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - from := opts.SourceFormat to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) @@ -137,7 +135,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body, _ = sjson.SetBytes(body, "model", req.Model) action := "generateContent" if req.Metadata != nil { @@ -146,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } } baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -220,24 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) action := "generateContent" if req.Metadata != nil { @@ -250,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action) + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -321,8 +322,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - from := opts.SourceFormat to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) @@ -338,10 +337,10 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body, _ = sjson.SetBytes(body, "model", req.Model) baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -438,30 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) // For API key auth, use simpler URL format without project/location if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -552,8 +554,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth // countTokensWithServiceAccount counts tokens using service account credentials. func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - from := opts.SourceFormat to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) @@ -566,14 +566,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context } translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) if errNewReq != nil { @@ -641,21 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context // countTokensWithAPIKey handles token counting using API key credentials. func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(model, translatedReq) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", model) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") @@ -665,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens") httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) if errNewReq != nil { @@ -808,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau } return tok.AccessToken, nil } + +// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration. +// It matches the requested model alias against configured models and returns the actual upstream name. +func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string { + trimmed := strings.TrimSpace(alias) + if trimmed == "" { + return "" + } + + entry := e.resolveVertexConfig(auth) + if entry == nil { + return "" + } + + normalizedModel, metadata := util.NormalizeThinkingModel(trimmed) + + // Candidate names to match against configured aliases/names. + candidates := []string{strings.TrimSpace(normalizedModel)} + if !strings.EqualFold(normalizedModel, trimmed) { + candidates = append(candidates, trimmed) + } + if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) { + candidates = append(candidates, original) + } + + for i := range entry.Models { + model := entry.Models[i] + name := strings.TrimSpace(model.Name) + modelAlias := strings.TrimSpace(model.Alias) + + for _, candidate := range candidates { + if candidate == "" { + continue + } + if modelAlias != "" && strings.EqualFold(modelAlias, candidate) { + if name != "" { + return name + } + return candidate + } + if name != "" && strings.EqualFold(name, candidate) { + return name + } + } + } + return "" +} + +// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth. +func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey { + if auth == nil || e.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range e.cfg.VertexCompatAPIKey { + entry := &e.cfg.VertexCompatAPIKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range e.cfg.VertexCompatAPIKey { + entry := &e.cfg.VertexCompatAPIKey[i] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { + return entry + } + } + } + return nil +} diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 124a984e6..49fd4eb7a 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -58,12 +58,9 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re to := sdktranslator.FromString("openai") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return resp, errValidate } body = applyIFlowThinkingConfig(body) @@ -151,12 +148,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return nil, errValidate } body = applyIFlowThinkingConfig(body) diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 1c57c9b7b..81fc31a15 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) allowCompat := e.allowCompatReasoningEffort(req.Model, auth) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" && modelOverride == "" { - translated, _ = sjson.SetBytes(translated, "model", upstreamModel) - } - translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat) - if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil { + translated = NormalizeThinkingConfig(translated, req.Model, allowCompat) + if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil { return resp, errValidate } @@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) allowCompat := e.allowCompatReasoningEffort(req.Model, auth) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" && modelOverride == "" { - translated, _ = sjson.SetBytes(translated, "model", upstreamModel) - } - translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat) - if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil { + translated = NormalizeThinkingConfig(translated, req.Model, allowCompat) + if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil { return nil, errValidate } diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 1d4ef52d5..ff6fa414f 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -12,7 +12,6 @@ import ( qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -52,12 +51,9 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req to := sdktranslator.FromString("openai") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return resp, errValidate } body = applyPayloadConfig(e.cfg, req.Model, body) @@ -132,12 +128,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return nil, errValidate } toolsResult := gjson.GetBytes(body, "tools") diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go index 27d2f9b66..5529d52a3 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go @@ -23,6 +23,7 @@ type geminiToResponsesState struct { MsgIndex int CurrentMsgID string TextBuf strings.Builder + ItemTextBuf strings.Builder // reasoning aggregation ReasoningOpened bool @@ -189,6 +190,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID) partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex) out = append(out, emitEvent("response.content_part.added", partAdded)) + st.ItemTextBuf.Reset() + st.ItemTextBuf.WriteString(t.String()) } st.TextBuf.WriteString(t.String()) msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}` @@ -250,20 +253,24 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, finalizeReasoning() // Close message output if opened if st.MsgOpened { + fullText := st.ItemTextBuf.String() done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}` done, _ = sjson.Set(done, "sequence_number", nextSeq()) done, _ = sjson.Set(done, "item_id", st.CurrentMsgID) done, _ = sjson.Set(done, "output_index", st.MsgIndex) + done, _ = sjson.Set(done, "text", fullText) out = append(out, emitEvent("response.output_text.done", done)) partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID) partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex) + partDone, _ = sjson.Set(partDone, "part.text", fullText) out = append(out, emitEvent("response.content_part.done", partDone)) final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}` final, _ = sjson.Set(final, "sequence_number", nextSeq()) final, _ = sjson.Set(final, "output_index", st.MsgIndex) final, _ = sjson.Set(final, "item.id", st.CurrentMsgID) + final, _ = sjson.Set(final, "item.content.0.text", fullText) out = append(out, emitEvent("response.output_item.done", final)) } diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index 1ce601516..e24fc893d 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -90,6 +90,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i)) } + oldModels := SummarizeGeminiModels(o.Models) + newModels := SummarizeGeminiModels(n.Models) + if oldModels.hash != newModels.hash { + changes = append(changes, fmt.Sprintf("gemini[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count)) + } oldExcluded := SummarizeExcludedModels(o.ExcludedModels) newExcluded := SummarizeExcludedModels(n.ExcludedModels) if oldExcluded.hash != newExcluded.hash { @@ -120,6 +125,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i)) } + oldModels := SummarizeClaudeModels(o.Models) + newModels := SummarizeClaudeModels(n.Models) + if oldModels.hash != newModels.hash { + changes = append(changes, fmt.Sprintf("claude[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count)) + } oldExcluded := SummarizeExcludedModels(o.ExcludedModels) newExcluded := SummarizeExcludedModels(n.ExcludedModels) if oldExcluded.hash != newExcluded.hash { @@ -150,6 +160,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i)) } + oldModels := SummarizeCodexModels(o.Models) + newModels := SummarizeCodexModels(n.Models) + if oldModels.hash != newModels.hash { + changes = append(changes, fmt.Sprintf("codex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count)) + } oldExcluded := SummarizeExcludedModels(o.ExcludedModels) newExcluded := SummarizeExcludedModels(n.ExcludedModels) if oldExcluded.hash != newExcluded.hash { @@ -194,6 +209,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { changes = append(changes, entries...) } + if entries, _ := DiffOAuthModelMappingChanges(oldCfg.OAuthModelMappings, newCfg.OAuthModelMappings); len(entries) > 0 { + changes = append(changes, entries...) + } // Remote management (never print the key) if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote { diff --git a/internal/watcher/diff/model_hash.go b/internal/watcher/diff/model_hash.go index a224bdcaa..5779faccd 100644 --- a/internal/watcher/diff/model_hash.go +++ b/internal/watcher/diff/model_hash.go @@ -71,6 +71,21 @@ func ComputeCodexModelsHash(models []config.CodexModel) string { return hashJoined(keys) } +// ComputeGeminiModelsHash returns a stable hash for Gemini model aliases. +func ComputeGeminiModelsHash(models []config.GeminiModel) string { + keys := normalizeModelPairs(func(out func(key string)) { + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + out(strings.ToLower(name) + "|" + strings.ToLower(alias)) + } + }) + return hashJoined(keys) +} + // ComputeExcludedModelsHash returns a normalized hash for excluded model lists. func ComputeExcludedModelsHash(excluded []string) string { if len(excluded) == 0 { diff --git a/internal/watcher/diff/models_summary.go b/internal/watcher/diff/models_summary.go new file mode 100644 index 000000000..9c2aa91ac --- /dev/null +++ b/internal/watcher/diff/models_summary.go @@ -0,0 +1,121 @@ +package diff + +import ( + "crypto/sha256" + "encoding/hex" + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +type GeminiModelsSummary struct { + hash string + count int +} + +type ClaudeModelsSummary struct { + hash string + count int +} + +type CodexModelsSummary struct { + hash string + count int +} + +type VertexModelsSummary struct { + hash string + count int +} + +// SummarizeGeminiModels hashes Gemini model aliases for change detection. +func SummarizeGeminiModels(models []config.GeminiModel) GeminiModelsSummary { + if len(models) == 0 { + return GeminiModelsSummary{} + } + keys := normalizeModelPairs(func(out func(key string)) { + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + out(strings.ToLower(name) + "|" + strings.ToLower(alias)) + } + }) + return GeminiModelsSummary{ + hash: hashJoined(keys), + count: len(keys), + } +} + +// SummarizeClaudeModels hashes Claude model aliases for change detection. +func SummarizeClaudeModels(models []config.ClaudeModel) ClaudeModelsSummary { + if len(models) == 0 { + return ClaudeModelsSummary{} + } + keys := normalizeModelPairs(func(out func(key string)) { + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + out(strings.ToLower(name) + "|" + strings.ToLower(alias)) + } + }) + return ClaudeModelsSummary{ + hash: hashJoined(keys), + count: len(keys), + } +} + +// SummarizeCodexModels hashes Codex model aliases for change detection. +func SummarizeCodexModels(models []config.CodexModel) CodexModelsSummary { + if len(models) == 0 { + return CodexModelsSummary{} + } + keys := normalizeModelPairs(func(out func(key string)) { + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + out(strings.ToLower(name) + "|" + strings.ToLower(alias)) + } + }) + return CodexModelsSummary{ + hash: hashJoined(keys), + count: len(keys), + } +} + +// SummarizeVertexModels hashes Vertex-compatible model aliases for change detection. +func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary { + if len(models) == 0 { + return VertexModelsSummary{} + } + names := make([]string, 0, len(models)) + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + if alias != "" { + name = alias + } + names = append(names, name) + } + if len(names) == 0 { + return VertexModelsSummary{} + } + sort.Strings(names) + sum := sha256.Sum256([]byte(strings.Join(names, "|"))) + return VertexModelsSummary{ + hash: hex.EncodeToString(sum[:]), + count: len(names), + } +} diff --git a/internal/watcher/diff/oauth_excluded.go b/internal/watcher/diff/oauth_excluded.go index 4f08c4d64..2039cf489 100644 --- a/internal/watcher/diff/oauth_excluded.go +++ b/internal/watcher/diff/oauth_excluded.go @@ -116,36 +116,3 @@ func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappin count: len(entries), } } - -type VertexModelsSummary struct { - hash string - count int -} - -// SummarizeVertexModels hashes vertex-compatible models for change detection. -func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary { - if len(models) == 0 { - return VertexModelsSummary{} - } - names := make([]string, 0, len(models)) - for _, m := range models { - name := strings.TrimSpace(m.Name) - alias := strings.TrimSpace(m.Alias) - if name == "" && alias == "" { - continue - } - if alias != "" { - name = alias - } - names = append(names, name) - } - if len(names) == 0 { - return VertexModelsSummary{} - } - sort.Strings(names) - sum := sha256.Sum256([]byte(strings.Join(names, "|"))) - return VertexModelsSummary{ - hash: hex.EncodeToString(sum[:]), - count: len(names), - } -} diff --git a/internal/watcher/diff/oauth_model_mappings.go b/internal/watcher/diff/oauth_model_mappings.go new file mode 100644 index 000000000..9228dbab6 --- /dev/null +++ b/internal/watcher/diff/oauth_model_mappings.go @@ -0,0 +1,98 @@ +package diff + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +type OAuthModelMappingsSummary struct { + hash string + count int +} + +// SummarizeOAuthModelMappings summarizes OAuth model mappings per channel. +func SummarizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string]OAuthModelMappingsSummary { + if len(entries) == 0 { + return nil + } + out := make(map[string]OAuthModelMappingsSummary, len(entries)) + for k, v := range entries { + key := strings.ToLower(strings.TrimSpace(k)) + if key == "" { + continue + } + out[key] = summarizeOAuthModelMappingList(v) + } + if len(out) == 0 { + return nil + } + return out +} + +// DiffOAuthModelMappingChanges compares OAuth model mappings maps. +func DiffOAuthModelMappingChanges(oldMap, newMap map[string][]config.ModelNameMapping) ([]string, []string) { + oldSummary := SummarizeOAuthModelMappings(oldMap) + newSummary := SummarizeOAuthModelMappings(newMap) + keys := make(map[string]struct{}, len(oldSummary)+len(newSummary)) + for k := range oldSummary { + keys[k] = struct{}{} + } + for k := range newSummary { + keys[k] = struct{}{} + } + changes := make([]string, 0, len(keys)) + affected := make([]string, 0, len(keys)) + for key := range keys { + oldInfo, okOld := oldSummary[key] + newInfo, okNew := newSummary[key] + switch { + case okOld && !okNew: + changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: removed", key)) + affected = append(affected, key) + case !okOld && okNew: + changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: added (%d entries)", key, newInfo.count)) + affected = append(affected, key) + case okOld && okNew && oldInfo.hash != newInfo.hash: + changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count)) + affected = append(affected, key) + } + } + sort.Strings(changes) + sort.Strings(affected) + return changes, affected +} + +func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMappingsSummary { + if len(list) == 0 { + return OAuthModelMappingsSummary{} + } + seen := make(map[string]struct{}, len(list)) + normalized := make([]string, 0, len(list)) + for _, mapping := range list { + name := strings.ToLower(strings.TrimSpace(mapping.Name)) + alias := strings.ToLower(strings.TrimSpace(mapping.Alias)) + if name == "" || alias == "" { + continue + } + key := name + "->" + alias + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + normalized = append(normalized, key) + } + if len(normalized) == 0 { + return OAuthModelMappingsSummary{} + } + sort.Strings(normalized) + sum := sha256.Sum256([]byte(strings.Join(normalized, "|"))) + return OAuthModelMappingsSummary{ + hash: hex.EncodeToString(sum[:]), + count: len(normalized), + } +} diff --git a/internal/watcher/synthesizer/config.go b/internal/watcher/synthesizer/config.go index ece7c29c8..e976af4e1 100644 --- a/internal/watcher/synthesizer/config.go +++ b/internal/watcher/synthesizer/config.go @@ -66,6 +66,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea if base != "" { attrs["base_url"] = base } + if hash := diff.ComputeGeminiModelsHash(entry.Models); hash != "" { + attrs["models_hash"] = hash + } addConfigHeadersToAttrs(entry.Headers, attrs) a := &coreauth.Auth{ ID: id, diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 27e940e8c..b150d80f6 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -413,7 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -475,7 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -537,7 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { rerr := &Error{Message: errStream.Error()} diff --git a/sdk/cliproxy/auth/model_name_mappings.go b/sdk/cliproxy/auth/model_name_mappings.go index 483cb9c90..f1b31aa51 100644 --- a/sdk/cliproxy/auth/model_name_mappings.go +++ b/sdk/cliproxy/auth/model_name_mappings.go @@ -65,17 +65,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod m.modelNameMappings.Store(table) } -func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { - original := m.resolveOAuthUpstreamModel(auth, requestedModel) - if original == "" { - return metadata - } - if metadata != nil { - if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { - if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { - return metadata - } - } +// applyOAuthModelMapping resolves the upstream model from OAuth model mappings +// and returns the resolved model along with updated metadata. If a mapping exists, +// the returned model is the upstream model and metadata contains the original +// requested model for response translation. +func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) { + upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel) + if upstreamModel == "" { + return requestedModel, metadata } out := make(map[string]any, 1) if len(metadata) > 0 { @@ -84,8 +81,8 @@ func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel stri out[k] = v } } - out[util.ModelMappingOriginalModelMetadataKey] = original - return out + out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel + return upstreamModel, out } func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 0927eaa64..f249e95c8 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -714,6 +714,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "gemini": models = registry.GetGeminiModels() if entry := s.resolveConfigGeminiKey(a); entry != nil { + if len(entry.Models) > 0 { + models = buildGeminiConfigModels(entry) + } if authKind == "apikey" { excluded = entry.ExcludedModels } @@ -1125,17 +1128,22 @@ func matchWildcard(pattern, value string) bool { return true } -func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { - if entry == nil || len(entry.Models) == 0 { +type modelEntry interface { + GetName() string + GetAlias() string +} + +func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo { + if len(models) == 0 { return nil } now := time.Now().Unix() - out := make([]*ModelInfo, 0, len(entry.Models)) - seen := make(map[string]struct{}, len(entry.Models)) - for i := range entry.Models { - model := entry.Models[i] - name := strings.TrimSpace(model.Name) - alias := strings.TrimSpace(model.Alias) + out := make([]*ModelInfo, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for i := range models { + model := models[i] + name := strings.TrimSpace(model.GetName()) + alias := strings.TrimSpace(model.GetAlias()) if alias == "" { alias = name } @@ -1151,18 +1159,52 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { if display == "" { display = alias } - out = append(out, &ModelInfo{ + info := &ModelInfo{ ID: alias, Object: "model", Created: now, - OwnedBy: "vertex", - Type: "vertex", + OwnedBy: ownedBy, + Type: modelType, DisplayName: display, - }) + } + if name != "" { + if upstream := registry.LookupStaticModelInfo(name); upstream != nil && upstream.Thinking != nil { + info.Thinking = upstream.Thinking + } + } + out = append(out, info) } return out } +func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "google", "vertex") +} + +func buildGeminiConfigModels(entry *config.GeminiKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "google", "gemini") +} + +func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "anthropic", "claude") +} + +func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "openai", "openai") +} + func rewriteModelInfoName(name, oldID, newID string) string { trimmed := strings.TrimSpace(name) if trimmed == "" { @@ -1249,79 +1291,3 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode } return out } - -func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { - if entry == nil || len(entry.Models) == 0 { - return nil - } - now := time.Now().Unix() - out := make([]*ModelInfo, 0, len(entry.Models)) - seen := make(map[string]struct{}, len(entry.Models)) - for i := range entry.Models { - model := entry.Models[i] - name := strings.TrimSpace(model.Name) - alias := strings.TrimSpace(model.Alias) - if alias == "" { - alias = name - } - if alias == "" { - continue - } - key := strings.ToLower(alias) - if _, exists := seen[key]; exists { - continue - } - seen[key] = struct{}{} - display := name - if display == "" { - display = alias - } - out = append(out, &ModelInfo{ - ID: alias, - Object: "model", - Created: now, - OwnedBy: "claude", - Type: "claude", - DisplayName: display, - }) - } - return out -} - -func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo { - if entry == nil || len(entry.Models) == 0 { - return nil - } - now := time.Now().Unix() - out := make([]*ModelInfo, 0, len(entry.Models)) - seen := make(map[string]struct{}, len(entry.Models)) - for i := range entry.Models { - model := entry.Models[i] - name := strings.TrimSpace(model.Name) - alias := strings.TrimSpace(model.Alias) - if alias == "" { - alias = name - } - if alias == "" { - continue - } - key := strings.ToLower(alias) - if _, exists := seen[key]; exists { - continue - } - seen[key] = struct{}{} - display := name - if display == "" { - display = alias - } - out = append(out, &ModelInfo{ - ID: alias, - Object: "model", - Created: now, - OwnedBy: "openai", - Type: "openai", - DisplayName: display, - }) - } - return out -}