Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions internal/api/middleware/request_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
}

return &RequestInfo{
URL: url,
Method: method,
Headers: headers,
Body: body,
URL: url,
Method: method,
Headers: headers,
Body: body,
RequestID: logging.GetGinRequestID(c),
}, nil
}

Expand Down
14 changes: 9 additions & 5 deletions internal/api/middleware/response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import (

// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
RequestID string // RequestID is the unique identifier for the request.
}

// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
Expand Down Expand Up @@ -149,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
w.requestInfo.Method,
w.requestInfo.Headers,
w.requestInfo.Body,
w.requestInfo.RequestID,
)
if err == nil {
w.streamWriter = streamWriter
Expand Down Expand Up @@ -346,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
}

if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
Expand All @@ -360,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiResponseBody,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
)
}

Expand All @@ -374,5 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiRequestBody,
apiResponseBody,
apiResponseErrors,
w.requestInfo.RequestID,
)
}
15 changes: 11 additions & 4 deletions internal/api/modules/amp/amp.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,16 +279,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
return true
}

// Build map for efficient comparison
oldMap := make(map[string]string, len(old.ModelMappings))
// Build map for efficient and robust comparison
type mappingInfo struct {
to string
regex bool
}
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
to: strings.TrimSpace(mapping.To),
regex: mapping.Regex,
}
}

for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
return true
}
}
Comment on lines +282 to 301

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

当前的 hasModelMappingsChanged 实现使用 map 来比较新旧模型映射,这会导致正则表达式映射的顺序信息丢失。根据 internal/api/modules/amp/model_mapping.go 中的实现,正则表达式的匹配顺序至关重要。

如果用户在配置文件中仅仅调整了正则表达式规则的顺序,当前的实现将无法检测到这一变更,导致路由逻辑不会更新,这可能会引发难以调试的非预期路由行为。

建议修改此函数,以确保在比较时能考虑到规则的顺序。一个更简单且健壮的方法是直接按顺序比较 ModelMappings 切片中的每个元素。

	// Direct comparison is needed to respect the order of regex mappings.
	for i := range new.ModelMappings {
		oldMapping := old.ModelMappings[i]
		newMapping := new.ModelMappings[i]

		if strings.TrimSpace(oldMapping.From) != strings.TrimSpace(newMapping.From) ||
			strings.TrimSpace(oldMapping.To) != strings.TrimSpace(newMapping.To) ||
			oldMapping.Regex != newMapping.Regex {
			return true
		}
	}

Expand Down
48 changes: 41 additions & 7 deletions internal/api/modules/amp/model_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package amp

import (
"regexp"
"strings"
"sync"

Expand All @@ -26,13 +27,15 @@ type ModelMapper interface {
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct {
mu sync.RWMutex
mappings map[string]string // from -> to (normalized lowercase keys)
mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
}

// NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{
mappings: make(map[string]string),
regexps: nil,
}
m.UpdateMappings(mappings)
return m
Expand All @@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
// Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest]
if !exists {
return ""
// Try regex mappings in order
base, _ := util.NormalizeThinkingModel(requestedModel)
for _, rm := range m.regexps {
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
targetModel = rm.to
exists = true
break
}
}
if !exists {
return ""
}
}

// Verify target model has available providers
Expand All @@ -78,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {

// Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings))
m.regexps = make([]regexMapping, 0, len(mappings))

for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From)
Expand All @@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
continue
}

// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to

log.Debugf("amp model mapping registered: %s -> %s", from, to)
if mapping.Regex {
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
pattern := "(?i)" + from
re, err := regexp.Compile(pattern)
if err != nil {
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
continue
}
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
} else {
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
}

if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
}
if n := len(m.regexps); n > 0 {
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
}
}

// GetMappings returns a copy of current mappings (for debugging/status).
Expand All @@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
}
return result
}

type regexMapping struct {
re *regexp.Regexp
to string
}
78 changes: 78 additions & 0 deletions internal/api/modules/amp/model_mapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
t.Error("Original map was modified")
}
}

func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-1")

mappings := []config.AmpModelMapping{
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
}

mapper := NewModelMapper(mappings)

// Incoming model has reasoning suffix but should match base via regex
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro" {
t.Errorf("Expected gemini-2.5-pro, got %s", result)
}
}

func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-2")
defer reg.UnregisterClient("test-client-regex-3")

mappings := []config.AmpModelMapping{
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
}

mapper := NewModelMapper(mappings)

// Exact match should win over regex
result := mapper.MapModel("gpt-5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}

func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
// Invalid regex should be skipped and not cause panic
mappings := []config.AmpModelMapping{
{From: "(", To: "target", Regex: true},
}

mapper := NewModelMapper(mappings)

result := mapper.MapModel("anything")
if result != "" {
t.Errorf("Expected empty result due to invalid regex, got %s", result)
}
}

func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-regex-4")

mappings := []config.AmpModelMapping{
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
}

mapper := NewModelMapper(mappings)

result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
7 changes: 6 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ type AmpModelMapping struct {
// To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry.
To string `yaml:"to" json:"to"`

// Regex indicates whether the 'from' field should be interpreted as a regular
// expression for matching model names. When true, this mapping is evaluated
// after exact matches and in the order provided. Defaults to false (exact match).
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
}

// AmpCode groups Amp CLI integration settings including upstream routing,
Expand Down Expand Up @@ -401,7 +406,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.DisableCooling = false
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
if err = yaml.Unmarshal(data, &cfg); err != nil {
if optional {
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
Expand Down
52 changes: 46 additions & 6 deletions internal/logging/gin_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,32 @@ import (
"fmt"
"net/http"
"runtime/debug"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)

// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
var aiAPIPrefixes = []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/messages",
"/v1/responses",
"/v1beta/models/",
"/api/provider/",
}

const skipGinLogKey = "__gin_skip_request_logging__"

// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
// using logrus. It captures request details including method, path, status code, latency,
// client IP, and any error messages, formatting them in a Gin-style log format.
// client IP, and any error messages. Request ID is only added for AI API requests.
//
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
//
// Returns:
// - gin.HandlerFunc: A middleware handler for request logging
Expand All @@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc {
path := c.Request.URL.Path
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)

// Only generate request ID for AI API paths
var requestID string
if isAIAPIPath(path) {
requestID = GenerateRequestID()
SetGinRequestID(c, requestID)
ctx := WithRequestID(c.Request.Context(), requestID)
c.Request = c.Request.WithContext(ctx)
}

c.Next()

if shouldSkipGinRequestLogging(c) {
Expand All @@ -49,21 +72,38 @@ func GinLogrusLogger() gin.HandlerFunc {
clientIP := c.ClientIP()
method := c.Request.Method
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)

logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
if errorMessage != "" {
logLine = logLine + " | " + errorMessage
}

var entry *log.Entry
if requestID != "" {
entry = log.WithField("request_id", requestID)
} else {
entry = log.WithField("request_id", "--------")
}

switch {
case statusCode >= http.StatusInternalServerError:
log.Error(logLine)
entry.Error(logLine)
case statusCode >= http.StatusBadRequest:
log.Warn(logLine)
entry.Warn(logLine)
default:
log.Info(logLine)
entry.Info(logLine)
}
}
}

// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
func isAIAPIPath(path string) bool {
for _, prefix := range aiAPIPrefixes {
if strings.HasPrefix(path, prefix) {
return true
}
}
return false
}

// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
Expand Down
Loading