diff --git a/config.example.yaml b/config.example.yaml index 67d40629c..f6f84c6ee 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -221,6 +221,7 @@ ws-auth: false # gemini-cli: # - name: "gemini-2.5-pro" # original model name under this channel # alias: "g2.5p" # client-visible alias +# fork: true # when true, keep original and also add the alias as an extra model (default: false) # vertex: # - name: "gemini-2.5-pro" # alias: "g2.5p" diff --git a/internal/config/config.go b/internal/config/config.go index 7c30c4f9f..0cd89dc4a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -157,11 +157,14 @@ type RoutingConfig struct { Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` } -// ModelNameMapping defines a model ID rename mapping for a specific channel. -// It maps the original model name (Name) to the client-visible alias (Alias). +// ModelNameMapping defines a model ID mapping for a specific channel. +// It maps the upstream model name (Name) to the client-visible alias (Alias). +// When Fork is true, the alias is added as an additional model in listings while +// keeping the original model ID available. type ModelNameMapping struct { Name string `yaml:"name" json:"name"` Alias string `yaml:"alias" json:"alias"` + Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"` } // AmpModelMapping defines a model name mapping for Amp CLI requests. @@ -596,7 +599,7 @@ func (cfg *Config) SanitizeOAuthModelMappings() { } seenName[nameKey] = struct{}{} seenAlias[aliasKey] = struct{}{} - clean = append(clean, ModelNameMapping{Name: name, Alias: alias}) + clean = append(clean, ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork}) } if len(clean) > 0 { out[channel] = clean diff --git a/internal/config/oauth_model_mappings_test.go b/internal/config/oauth_model_mappings_test.go new file mode 100644 index 000000000..7b801a792 --- /dev/null +++ b/internal/config/oauth_model_mappings_test.go @@ -0,0 +1,27 @@ +package config + +import "testing" + +func TestSanitizeOAuthModelMappings_PreservesForkFlag(t *testing.T) { + cfg := &Config{ + OAuthModelMappings: map[string][]ModelNameMapping{ + " CoDeX ": { + {Name: " gpt-5 ", Alias: " g5 ", Fork: true}, + {Name: "gpt-6", Alias: "g6"}, + }, + }, + } + + cfg.SanitizeOAuthModelMappings() + + mappings := cfg.OAuthModelMappings["codex"] + if len(mappings) != 2 { + t.Fatalf("expected 2 sanitized mappings, got %d", len(mappings)) + } + if mappings[0].Name != "gpt-5" || mappings[0].Alias != "g5" || !mappings[0].Fork { + t.Fatalf("expected first mapping to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", mappings[0].Name, mappings[0].Alias, mappings[0].Fork) + } + if mappings[1].Name != "gpt-6" || mappings[1].Alias != "g6" || mappings[1].Fork { + t.Fatalf("expected second mapping to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", mappings[1].Name, mappings[1].Alias, mappings[1].Fork) + } +} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index f0116faa8..0baba498d 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -45,6 +45,7 @@ const ( defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64" antigravityAuthType = "antigravity" refreshSkew = 3000 * time.Second + tokenRefreshTimeout = 30 * time.Second ) var ( @@ -914,7 +915,13 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr if accessToken != "" && expiry.After(time.Now().Add(refreshSkew)) { return accessToken, nil, nil } - updated, errRefresh := e.refreshToken(ctx, auth.Clone()) + refreshCtx := context.Background() + if ctx != nil { + if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil { + refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt) + } + } + updated, errRefresh := e.refreshToken(refreshCtx, auth.Clone()) if errRefresh != nil { return "", nil, errRefresh } @@ -944,7 +951,7 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau httpReq.Header.Set("User-Agent", defaultAntigravityAgent) httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, tokenRefreshTimeout) httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { return auth, errDo diff --git a/internal/translator/openai/claude/openai_claude_response.go b/internal/translator/openai/claude/openai_claude_response.go index 27ab082bb..1629545d2 100644 --- a/internal/translator/openai/claude/openai_claude_response.go +++ b/internal/translator/openai/claude/openai_claude_response.go @@ -299,17 +299,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI inputTokens = promptTokens.Int() outputTokens = completionTokens.Int() } + // Send message_delta with usage + messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason)) + messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens) + messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens) + results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n") + param.MessageDeltaSent = true + + emitMessageStopIfNeeded(param, &results) } - // Send message_delta with usage - messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason)) - messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens) - messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens) - results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n") - param.MessageDeltaSent = true - - emitMessageStopIfNeeded(param, &results) - } return results diff --git a/internal/watcher/diff/oauth_model_mappings.go b/internal/watcher/diff/oauth_model_mappings.go index 9228dbab6..c002855cf 100644 --- a/internal/watcher/diff/oauth_model_mappings.go +++ b/internal/watcher/diff/oauth_model_mappings.go @@ -80,6 +80,9 @@ func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMa continue } key := name + "->" + alias + if mapping.Fork { + key += "|fork" + } if _, exists := seen[key]; exists { continue } diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index b150d80f6..698d0102e 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -1536,6 +1536,9 @@ func (m *Manager) markRefreshPending(id string, now time.Time) bool { } func (m *Manager) refreshAuth(ctx context.Context, id string) { + if ctx == nil { + ctx = context.Background() + } m.mu.RLock() auth := m.auths[id] var exec ProviderExecutor @@ -1548,6 +1551,10 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { } cloned := auth.Clone() updated, err := exec.Refresh(ctx, cloned) + if err != nil && errors.Is(err, context.Canceled) { + log.Debugf("refresh canceled for %s, %s", auth.Provider, auth.ID) + return + } log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err) now := time.Now() if err != nil { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index f249e95c8..9c094c8ce 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -1240,7 +1240,13 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode if len(mappings) == 0 { return models } - forward := make(map[string]string, len(mappings)) + + type mappingEntry struct { + alias string + fork bool + } + + forward := make(map[string]mappingEntry, len(mappings)) for i := range mappings { name := strings.TrimSpace(mappings[i].Name) alias := strings.TrimSpace(mappings[i].Alias) @@ -1254,7 +1260,7 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode if _, exists := forward[key]; exists { continue } - forward[key] = alias + forward[key] = mappingEntry{alias: alias, fork: mappings[i].Fork} } if len(forward) == 0 { return models @@ -1269,10 +1275,45 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode if id == "" { continue } - mappedID := id - if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" { - mappedID = strings.TrimSpace(to) + key := strings.ToLower(id) + entry, ok := forward[key] + if !ok { + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + out = append(out, model) + continue } + mappedID := strings.TrimSpace(entry.alias) + if mappedID == "" { + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + out = append(out, model) + continue + } + + if entry.fork { + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + out = append(out, model) + } + aliasKey := strings.ToLower(mappedID) + if _, exists := seen[aliasKey]; exists { + continue + } + seen[aliasKey] = struct{}{} + clone := *model + clone.ID = mappedID + if clone.Name != "" { + clone.Name = rewriteModelInfoName(clone.Name, id, mappedID) + } + out = append(out, &clone) + continue + } + uniqueKey := strings.ToLower(mappedID) if _, exists := seen[uniqueKey]; exists { continue diff --git a/sdk/cliproxy/service_oauth_model_mappings_test.go b/sdk/cliproxy/service_oauth_model_mappings_test.go new file mode 100644 index 000000000..7d8da08a8 --- /dev/null +++ b/sdk/cliproxy/service_oauth_model_mappings_test.go @@ -0,0 +1,58 @@ +package cliproxy + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestApplyOAuthModelMappings_Rename(t *testing.T) { + cfg := &config.Config{ + OAuthModelMappings: map[string][]config.ModelNameMapping{ + "codex": { + {Name: "gpt-5", Alias: "g5"}, + }, + }, + } + models := []*ModelInfo{ + {ID: "gpt-5", Name: "models/gpt-5"}, + } + + out := applyOAuthModelMappings(cfg, "codex", "oauth", models) + if len(out) != 1 { + t.Fatalf("expected 1 model, got %d", len(out)) + } + if out[0].ID != "g5" { + t.Fatalf("expected model id %q, got %q", "g5", out[0].ID) + } + if out[0].Name != "models/g5" { + t.Fatalf("expected model name %q, got %q", "models/g5", out[0].Name) + } +} + +func TestApplyOAuthModelMappings_ForkAddsAlias(t *testing.T) { + cfg := &config.Config{ + OAuthModelMappings: map[string][]config.ModelNameMapping{ + "codex": { + {Name: "gpt-5", Alias: "g5", Fork: true}, + }, + }, + } + models := []*ModelInfo{ + {ID: "gpt-5", Name: "models/gpt-5"}, + } + + out := applyOAuthModelMappings(cfg, "codex", "oauth", models) + if len(out) != 2 { + t.Fatalf("expected 2 models, got %d", len(out)) + } + if out[0].ID != "gpt-5" { + t.Fatalf("expected first model id %q, got %q", "gpt-5", out[0].ID) + } + if out[1].ID != "g5" { + t.Fatalf("expected second model id %q, got %q", "g5", out[1].ID) + } + if out[1].Name != "models/g5" { + t.Fatalf("expected forked model name %q, got %q", "models/g5", out[1].Name) + } +}