Skip to content

Commit 5418bbc

Browse files
authored
Merge branch 'router-for-me:main' into main
2 parents 8fac6b1 + 89254cf commit 5418bbc

File tree

9 files changed

+247
-34
lines changed

9 files changed

+247
-34
lines changed

internal/api/modules/amp/fallback_handlers.go

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
134134
}
135135

136136
// Normalize model (handles dynamic thinking suffixes)
137-
normalizedModel, _ := util.NormalizeThinkingModel(modelName)
137+
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
138+
thinkingSuffix := ""
139+
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
140+
thinkingSuffix = modelName[len(normalizedModel):]
141+
}
142+
143+
resolveMappedModel := func() (string, []string) {
144+
if fh.modelMapper == nil {
145+
return "", nil
146+
}
147+
148+
mappedModel := fh.modelMapper.MapModel(modelName)
149+
if mappedModel == "" {
150+
mappedModel = fh.modelMapper.MapModel(normalizedModel)
151+
}
152+
mappedModel = strings.TrimSpace(mappedModel)
153+
if mappedModel == "" {
154+
return "", nil
155+
}
156+
157+
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
158+
// already specifies its own thinking suffix.
159+
if thinkingSuffix != "" {
160+
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
161+
if mappedThinkingMetadata == nil {
162+
mappedModel += thinkingSuffix
163+
}
164+
}
165+
166+
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
167+
mappedProviders := util.GetProviderName(mappedBaseModel)
168+
if len(mappedProviders) == 0 {
169+
return "", nil
170+
}
171+
172+
return mappedModel, mappedProviders
173+
}
138174

139175
// Track resolved model for logging (may change if mapping is applied)
140176
resolvedModel := normalizedModel
@@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
147183
if forceMappings {
148184
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
149185
// This allows users to route Amp requests to their preferred OAuth providers
150-
if fh.modelMapper != nil {
151-
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
152-
// Mapping found - check if we have a provider for the mapped model
153-
mappedProviders := util.GetProviderName(mappedModel)
154-
if len(mappedProviders) > 0 {
155-
// Mapping found and provider available - rewrite the model in request body
156-
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
157-
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
158-
// Store mapped model in context for handlers that check it (like gemini bridge)
159-
c.Set(MappedModelContextKey, mappedModel)
160-
resolvedModel = mappedModel
161-
usedMapping = true
162-
providers = mappedProviders
163-
}
164-
}
186+
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
187+
// Mapping found and provider available - rewrite the model in request body
188+
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
189+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
190+
// Store mapped model in context for handlers that check it (like gemini bridge)
191+
c.Set(MappedModelContextKey, mappedModel)
192+
resolvedModel = mappedModel
193+
usedMapping = true
194+
providers = mappedProviders
165195
}
166196

167197
// If no mapping applied, check for local providers
@@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
174204

175205
if len(providers) == 0 {
176206
// No providers configured - check if we have a model mapping
177-
if fh.modelMapper != nil {
178-
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
179-
// Mapping found - check if we have a provider for the mapped model
180-
mappedProviders := util.GetProviderName(mappedModel)
181-
if len(mappedProviders) > 0 {
182-
// Mapping found and provider available - rewrite the model in request body
183-
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
184-
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
185-
// Store mapped model in context for handlers that check it (like gemini bridge)
186-
c.Set(MappedModelContextKey, mappedModel)
187-
resolvedModel = mappedModel
188-
usedMapping = true
189-
providers = mappedProviders
190-
}
191-
}
207+
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
208+
// Mapping found and provider available - rewrite the model in request body
209+
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
210+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
211+
// Store mapped model in context for handlers that check it (like gemini bridge)
212+
c.Set(MappedModelContextKey, mappedModel)
213+
resolvedModel = mappedModel
214+
usedMapping = true
215+
providers = mappedProviders
192216
}
193217
}
194218
}
@@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
222246
// Log: Model was mapped to another model
223247
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
224248
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
225-
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
249+
rewriter := NewResponseRewriter(c.Writer, modelName)
226250
c.Writer = rewriter
227251
// Filter Anthropic-Beta header only for local handling paths
228252
filterAntropicBetaHeader(c)
229253
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
230254
handler(c)
231255
rewriter.Flush()
232-
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
256+
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
233257
} else if len(providers) > 0 {
234258
// Log: Using local provider (free)
235259
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package amp
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"net/http/httputil"
9+
"testing"
10+
11+
"github.com/gin-gonic/gin"
12+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
13+
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
14+
)
15+
16+
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
17+
gin.SetMode(gin.TestMode)
18+
19+
reg := registry.GetGlobalRegistry()
20+
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
21+
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
22+
})
23+
defer reg.UnregisterClient("test-client-amp-fallback")
24+
25+
mapper := NewModelMapper([]config.AmpModelMapping{
26+
{From: "gpt-5.2", To: "test/gpt-5.2"},
27+
})
28+
29+
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
30+
31+
handler := func(c *gin.Context) {
32+
var req struct {
33+
Model string `json:"model"`
34+
}
35+
if err := c.ShouldBindJSON(&req); err != nil {
36+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
37+
return
38+
}
39+
40+
c.JSON(http.StatusOK, gin.H{
41+
"model": req.Model,
42+
"seen_model": req.Model,
43+
})
44+
}
45+
46+
r := gin.New()
47+
r.POST("/chat/completions", fallback.WrapHandler(handler))
48+
49+
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
50+
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
51+
req.Header.Set("Content-Type", "application/json")
52+
w := httptest.NewRecorder()
53+
r.ServeHTTP(w, req)
54+
55+
if w.Code != http.StatusOK {
56+
t.Fatalf("Expected status 200, got %d", w.Code)
57+
}
58+
59+
var resp struct {
60+
Model string `json:"model"`
61+
SeenModel string `json:"seen_model"`
62+
}
63+
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
64+
t.Fatalf("Failed to parse response JSON: %v", err)
65+
}
66+
67+
if resp.Model != "gpt-5.2(xhigh)" {
68+
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
69+
}
70+
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
71+
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
72+
}
73+
}

internal/api/modules/amp/model_mapping.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
5959
}
6060

6161
// Verify target model has available providers
62-
providers := util.GetProviderName(targetModel)
62+
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
63+
providers := util.GetProviderName(normalizedTarget)
6364
if len(providers) == 0 {
6465
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
6566
return ""

internal/api/modules/amp/model_mapping_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
7171
}
7272
}
7373

74+
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
75+
reg := registry.GetGlobalRegistry()
76+
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
77+
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
78+
})
79+
defer reg.UnregisterClient("test-client-thinking")
80+
81+
mappings := []config.AmpModelMapping{
82+
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
83+
}
84+
85+
mapper := NewModelMapper(mappings)
86+
87+
result := mapper.MapModel("gpt-5.2-alias")
88+
if result != "gpt-5.2(xhigh)" {
89+
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
90+
}
91+
}
92+
7493
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
7594
reg := registry.GetGlobalRegistry()
7695
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{

internal/registry/model_definitions.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,21 @@ func GetGeminiModels() []*ModelInfo {
162162
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
163163
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
164164
},
165+
{
166+
ID: "gemini-3-flash-preview",
167+
Object: "model",
168+
Created: 1765929600,
169+
OwnedBy: "google",
170+
Type: "gemini",
171+
Name: "models/gemini-3-flash-preview",
172+
Version: "3.0",
173+
DisplayName: "Gemini 3 Flash Preview",
174+
Description: "Gemini 3 Flash Preview",
175+
InputTokenLimit: 1048576,
176+
OutputTokenLimit: 65536,
177+
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
178+
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
179+
},
165180
{
166181
ID: "gemini-3-pro-image-preview",
167182
Object: "model",

internal/runtime/executor/antigravity_executor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
9797
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
9898
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
9999
translated = normalizeAntigravityThinking(req.Model, translated)
100+
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
100101

101102
baseURLs := antigravityBaseURLFallbackOrder(auth)
102103
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -191,6 +192,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
191192
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
192193
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
193194
translated = normalizeAntigravityThinking(req.Model, translated)
195+
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
194196

195197
baseURLs := antigravityBaseURLFallbackOrder(auth)
196198
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -524,6 +526,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
524526
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
525527
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
526528
translated = normalizeAntigravityThinking(req.Model, translated)
529+
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
527530

528531
baseURLs := antigravityBaseURLFallbackOrder(auth)
529532
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)

internal/util/gemini_schema.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ func flattenTypeArrays(jsonStr string) string {
296296
func removeUnsupportedKeywords(jsonStr string) string {
297297
keywords := append(unsupportedConstraints,
298298
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
299+
"propertyNames", // Gemini doesn't support property name validation
299300
)
300301
for _, key := range keywords {
301302
for _, p := range findPaths(jsonStr, key) {

internal/util/gemini_schema_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,71 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
596596
}
597597
}
598598

599+
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
600+
// propertyNames is used to validate object property names (e.g., must match a pattern)
601+
// Gemini doesn't support this keyword and will reject requests containing it
602+
input := `{
603+
"type": "object",
604+
"properties": {
605+
"metadata": {
606+
"type": "object",
607+
"propertyNames": {
608+
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
609+
},
610+
"additionalProperties": {
611+
"type": "string"
612+
}
613+
}
614+
}
615+
}`
616+
617+
expected := `{
618+
"type": "object",
619+
"properties": {
620+
"metadata": {
621+
"type": "object"
622+
}
623+
}
624+
}`
625+
626+
result := CleanJSONSchemaForGemini(input)
627+
compareJSON(t, expected, result)
628+
629+
// Verify propertyNames is completely removed
630+
if strings.Contains(result, "propertyNames") {
631+
t.Errorf("propertyNames keyword should be removed, got: %s", result)
632+
}
633+
}
634+
635+
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
636+
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
637+
input := `{
638+
"type": "object",
639+
"properties": {
640+
"items": {
641+
"type": "array",
642+
"items": {
643+
"type": "object",
644+
"properties": {
645+
"config": {
646+
"type": "object",
647+
"propertyNames": {
648+
"type": "string"
649+
}
650+
}
651+
}
652+
}
653+
}
654+
}
655+
}`
656+
657+
result := CleanJSONSchemaForGemini(input)
658+
659+
if strings.Contains(result, "propertyNames") {
660+
t.Errorf("Nested propertyNames should be removed, got: %s", result)
661+
}
662+
}
663+
599664
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
600665
var expMap, actMap map[string]interface{}
601666
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)

internal/util/gemini_thinking.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool)
136136
updated = rewritten
137137
}
138138
}
139+
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
140+
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
141+
}
142+
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
143+
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
144+
}
139145
return updated
140146
}
141147

@@ -167,6 +173,12 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo
167173
updated = rewritten
168174
}
169175
}
176+
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
177+
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
178+
}
179+
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
180+
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
181+
}
170182
return updated
171183
}
172184

0 commit comments

Comments
 (0)