diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go index 8f29e1a16..49f28f524 100644 --- a/internal/api/middleware/request_logging.go +++ b/internal/api/middleware/request_logging.go @@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) { } return &RequestInfo{ - URL: url, - Method: method, - Headers: headers, - Body: body, + URL: url, + Method: method, + Headers: headers, + Body: body, + RequestID: logging.GetGinRequestID(c), }, nil } diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index 8005df238..8029e50af 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -15,10 +15,11 @@ import ( // RequestInfo holds essential details of an incoming HTTP request for logging purposes. type RequestInfo struct { - URL string // URL is the request URL. - Method string // Method is the HTTP method (e.g., GET, POST). - Headers map[string][]string // Headers contains the request headers. - Body []byte // Body is the raw request body. + URL string // URL is the request URL. + Method string // Method is the HTTP method (e.g., GET, POST). + Headers map[string][]string // Headers contains the request headers. + Body []byte // Body is the raw request body. + RequestID string // RequestID is the unique identifier for the request. } // ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data. @@ -149,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) { w.requestInfo.Method, w.requestInfo.Headers, w.requestInfo.Body, + w.requestInfo.RequestID, ) if err == nil { w.streamWriter = streamWriter @@ -346,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][] } if loggerWithOptions, ok := w.logger.(interface { - LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error + LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error }); ok { return loggerWithOptions.LogRequestWithOptions( w.requestInfo.URL, @@ -360,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][] apiResponseBody, apiResponseErrors, forceLog, + w.requestInfo.RequestID, ) } @@ -374,5 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][] apiRequestBody, apiResponseBody, apiResponseErrors, + w.requestInfo.RequestID, ) } diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index c18657c9c..924b34529 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -279,16 +279,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp return true } - // Build map for efficient comparison - oldMap := make(map[string]string, len(old.ModelMappings)) + // Build map for efficient and robust comparison + type mappingInfo struct { + to string + regex bool + } + oldMap := make(map[string]mappingInfo, len(old.ModelMappings)) for _, mapping := range old.ModelMappings { - oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To) + oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{ + to: strings.TrimSpace(mapping.To), + regex: mapping.Regex, + } } for _, mapping := range new.ModelMappings { from := strings.TrimSpace(mapping.From) to := strings.TrimSpace(mapping.To) - if oldTo, exists := oldMap[from]; !exists || oldTo != to { + if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex { return true } } diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index bc31c4e56..4b629b629 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -3,6 +3,7 @@ package amp import ( + "regexp" "strings" "sync" @@ -26,13 +27,15 @@ type ModelMapper interface { // DefaultModelMapper implements ModelMapper with thread-safe mapping storage. type DefaultModelMapper struct { mu sync.RWMutex - mappings map[string]string // from -> to (normalized lowercase keys) + mappings map[string]string // exact: from -> to (normalized lowercase keys) + regexps []regexMapping // regex rules evaluated in order } // NewModelMapper creates a new model mapper with the given initial mappings. func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { m := &DefaultModelMapper{ mappings: make(map[string]string), + regexps: nil, } m.UpdateMappings(mappings) return m @@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string { // Check for direct mapping targetModel, exists := m.mappings[normalizedRequest] if !exists { - return "" + // Try regex mappings in order + base, _ := util.NormalizeThinkingModel(requestedModel) + for _, rm := range m.regexps { + if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) { + targetModel = rm.to + exists = true + break + } + } + if !exists { + return "" + } } // Verify target model has available providers @@ -78,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { // Clear and rebuild mappings m.mappings = make(map[string]string, len(mappings)) + m.regexps = make([]regexMapping, 0, len(mappings)) for _, mapping := range mappings { from := strings.TrimSpace(mapping.From) @@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { continue } - // Store with normalized lowercase key for case-insensitive lookup - normalizedFrom := strings.ToLower(from) - m.mappings[normalizedFrom] = to - - log.Debugf("amp model mapping registered: %s -> %s", from, to) + if mapping.Regex { + // Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups + pattern := "(?i)" + from + re, err := regexp.Compile(pattern) + if err != nil { + log.Warnf("amp model mapping: invalid regex %q: %v", from, err) + continue + } + m.regexps = append(m.regexps, regexMapping{re: re, to: to}) + log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to) + } else { + // Store with normalized lowercase key for case-insensitive lookup + normalizedFrom := strings.ToLower(from) + m.mappings[normalizedFrom] = to + log.Debugf("amp model mapping registered: %s -> %s", from, to) + } } if len(m.mappings) > 0 { log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) } + if n := len(m.regexps); n > 0 { + log.Infof("amp model mapping: loaded %d regex mapping(s)", n) + } } // GetMappings returns a copy of current mappings (for debugging/status). @@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string { } return result } + +type regexMapping struct { + re *regexp.Regexp + to string +} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go index 664a17c50..1b36f2128 100644 --- a/internal/api/modules/amp/model_mapping_test.go +++ b/internal/api/modules/amp/model_mapping_test.go @@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) { t.Error("Original map was modified") } } + +func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{ + {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"}, + }) + defer reg.UnregisterClient("test-client-regex-1") + + mappings := []config.AmpModelMapping{ + {From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true}, + } + + mapper := NewModelMapper(mappings) + + // Incoming model has reasoning suffix but should match base via regex + result := mapper.MapModel("gpt-5(high)") + if result != "gemini-2.5-pro" { + t.Errorf("Expected gemini-2.5-pro, got %s", result) + } +} + +func TestModelMapper_Regex_ExactPrecedence(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{ + {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"}, + }) + defer reg.UnregisterClient("test-client-regex-2") + defer reg.UnregisterClient("test-client-regex-3") + + mappings := []config.AmpModelMapping{ + {From: "gpt-5", To: "claude-sonnet-4"}, // exact + {From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex + } + + mapper := NewModelMapper(mappings) + + // Exact match should win over regex + result := mapper.MapModel("gpt-5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) { + // Invalid regex should be skipped and not cause panic + mappings := []config.AmpModelMapping{ + {From: "(", To: "target", Regex: true}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("anything") + if result != "" { + t.Errorf("Expected empty result due to invalid regex, got %s", result) + } +} + +func TestModelMapper_Regex_CaseInsensitive(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client-regex-4") + + mappings := []config.AmpModelMapping{ + {From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index be68dcb97..4f50f494a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -156,6 +156,11 @@ type AmpModelMapping struct { // To is the target model name to route to (e.g., "claude-sonnet-4"). // The target model must have available providers in the registry. To string `yaml:"to" json:"to"` + + // Regex indicates whether the 'from' field should be interpreted as a regular + // expression for matching model names. When true, this mapping is evaluated + // after exact matches and in the order provided. Defaults to false (exact match). + Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"` } // AmpCode groups Amp CLI integration settings including upstream routing, @@ -401,7 +406,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.DisableCooling = false cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository - cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force) + cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force) if err = yaml.Unmarshal(data, &cfg); err != nil { if optional { // In cloud deploy mode, if YAML parsing fails, return empty config instead of error. diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index a4e020b1d..9bfef8adc 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "runtime/debug" + "strings" "time" "github.com/gin-gonic/gin" @@ -14,11 +15,24 @@ import ( log "github.com/sirupsen/logrus" ) +// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking. +var aiAPIPrefixes = []string{ + "/v1/chat/completions", + "/v1/completions", + "/v1/messages", + "/v1/responses", + "/v1beta/models/", + "/api/provider/", +} + const skipGinLogKey = "__gin_skip_request_logging__" // GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses // using logrus. It captures request details including method, path, status code, latency, -// client IP, and any error messages, formatting them in a Gin-style log format. +// client IP, and any error messages. Request ID is only added for AI API requests. +// +// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ... +// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ... // // Returns: // - gin.HandlerFunc: A middleware handler for request logging @@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc { path := c.Request.URL.Path raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery) + // Only generate request ID for AI API paths + var requestID string + if isAIAPIPath(path) { + requestID = GenerateRequestID() + SetGinRequestID(c, requestID) + ctx := WithRequestID(c.Request.Context(), requestID) + c.Request = c.Request.WithContext(ctx) + } + c.Next() if shouldSkipGinRequestLogging(c) { @@ -49,21 +72,38 @@ func GinLogrusLogger() gin.HandlerFunc { clientIP := c.ClientIP() method := c.Request.Method errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String() - timestamp := time.Now().Format("2006/01/02 - 15:04:05") - logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path) + + logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path) if errorMessage != "" { logLine = logLine + " | " + errorMessage } + var entry *log.Entry + if requestID != "" { + entry = log.WithField("request_id", requestID) + } else { + entry = log.WithField("request_id", "--------") + } + switch { case statusCode >= http.StatusInternalServerError: - log.Error(logLine) + entry.Error(logLine) case statusCode >= http.StatusBadRequest: - log.Warn(logLine) + entry.Warn(logLine) default: - log.Info(logLine) + entry.Info(logLine) + } + } +} + +// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking. +func isAIAPIPath(path string) bool { + for _, prefix := range aiAPIPrefixes { + if strings.HasPrefix(path, prefix) { + return true } } + return false } // GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go index 8cfef21ac..f9942b860 100644 --- a/internal/logging/global_logger.go +++ b/internal/logging/global_logger.go @@ -24,7 +24,8 @@ var ( ) // LogFormatter defines a custom log format for logrus. -// This formatter adds timestamp, level, and source location to each log entry. +// This formatter adds timestamp, level, request ID, and source location to each log entry. +// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2 type LogFormatter struct{} // Format renders a single log entry with custom formatting. @@ -38,16 +39,27 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { timestamp := entry.Time.Format("2006-01-02 15:04:05") message := strings.TrimRight(entry.Message, "\r\n") - - // Handle nil Caller (can happen with some log entries) + + reqID := "" + if id, ok := entry.Data["request_id"].(string); ok && id != "" { + reqID = id + } + callerFile := "unknown" callerLine := 0 if entry.Caller != nil { callerFile = filepath.Base(entry.Caller.File) callerLine = entry.Caller.Line } - - formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message) + + levelStr := fmt.Sprintf("%-5s", entry.Level.String()) + + var formatted string + if reqID != "" { + formatted = fmt.Sprintf("[%s] [%s] [%s:%d] | %s | %s\n", timestamp, levelStr, callerFile, callerLine, reqID, message) + } else { + formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, levelStr, callerFile, callerLine, message) + } buffer.WriteString(formatted) return buffer.Bytes(), nil diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index 391f28690..397a4a083 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -43,10 +43,11 @@ type RequestLogger interface { // - response: The raw response data // - apiRequest: The API request data // - apiResponse: The API response data + // - requestID: Optional request ID for log file naming // // Returns: // - error: An error if logging fails, nil otherwise - LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error + LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error // LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks. // @@ -55,11 +56,12 @@ type RequestLogger interface { // - method: The HTTP method // - headers: The request headers // - body: The request body + // - requestID: Optional request ID for log file naming // // Returns: // - StreamingLogWriter: A writer for streaming response chunks // - error: An error if logging initialization fails, nil otherwise - LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) + LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) // IsEnabled returns whether request logging is currently enabled. // @@ -177,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) { // - response: The raw response data // - apiRequest: The API request data // - apiResponse: The API response data +// - requestID: Optional request ID for log file naming // // Returns: // - error: An error if logging fails, nil otherwise -func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error { - return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false) +func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error { + return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID) } // LogRequestWithOptions logs a request with optional forced logging behavior. // The force flag allows writing error logs even when regular request logging is disabled. -func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error { - return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force) +func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error { + return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID) } -func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error { +func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error { if !l.enabled && !force { return nil } @@ -200,10 +203,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st return fmt.Errorf("failed to create logs directory: %w", errEnsure) } - // Generate filename - filename := l.generateFilename(url) + // Generate filename with request ID + filename := l.generateFilename(url, requestID) if force && !l.enabled { - filename = l.generateErrorFilename(url) + filename = l.generateErrorFilename(url, requestID) } filePath := filepath.Join(l.logsDir, filename) @@ -271,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st // - method: The HTTP method // - headers: The request headers // - body: The request body +// - requestID: Optional request ID for log file naming // // Returns: // - StreamingLogWriter: A writer for streaming response chunks // - error: An error if logging initialization fails, nil otherwise -func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) { +func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) { if !l.enabled { return &NoOpStreamingLogWriter{}, nil } @@ -285,8 +289,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ return nil, fmt.Errorf("failed to create logs directory: %w", err) } - // Generate filename - filename := l.generateFilename(url) + // Generate filename with request ID + filename := l.generateFilename(url, requestID) filePath := filepath.Join(l.logsDir, filename) requestHeaders := make(map[string][]string, len(headers)) @@ -330,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ } // generateErrorFilename creates a filename with an error prefix to differentiate forced error logs. -func (l *FileRequestLogger) generateErrorFilename(url string) string { - return fmt.Sprintf("error-%s", l.generateFilename(url)) +func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string { + return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...)) } // ensureLogsDir creates the logs directory if it doesn't exist. @@ -346,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error { } // generateFilename creates a sanitized filename from the URL path and current timestamp. +// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log // // Parameters: // - url: The request URL +// - requestID: Optional request ID to include in filename // // Returns: // - string: A sanitized filename for the log file -func (l *FileRequestLogger) generateFilename(url string) string { +func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string { // Extract path from URL path := url if strings.Contains(url, "?") { @@ -368,12 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string { sanitized := l.sanitizeForFilename(path) // Add timestamp - timestamp := time.Now().Format("2006-01-02T150405-.000000000") - timestamp = strings.Replace(timestamp, ".", "", -1) + timestamp := time.Now().Format("2006-01-02T150405") - id := requestLogID.Add(1) + // Use request ID if provided, otherwise use sequential ID + var idPart string + if len(requestID) > 0 && requestID[0] != "" { + idPart = requestID[0] + } else { + id := requestLogID.Add(1) + idPart = fmt.Sprintf("%d", id) + } - return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id) + return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart) } // sanitizeForFilename replaces characters that are not safe for filenames. diff --git a/internal/logging/requestid.go b/internal/logging/requestid.go new file mode 100644 index 000000000..8bd045d11 --- /dev/null +++ b/internal/logging/requestid.go @@ -0,0 +1,61 @@ +package logging + +import ( + "context" + "crypto/rand" + "encoding/hex" + + "github.com/gin-gonic/gin" +) + +// requestIDKey is the context key for storing/retrieving request IDs. +type requestIDKey struct{} + +// ginRequestIDKey is the Gin context key for request IDs. +const ginRequestIDKey = "__request_id__" + +// GenerateRequestID creates a new 8-character hex request ID. +func GenerateRequestID() string { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + return "00000000" + } + return hex.EncodeToString(b) +} + +// WithRequestID returns a new context with the request ID attached. +func WithRequestID(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, requestIDKey{}, requestID) +} + +// GetRequestID retrieves the request ID from the context. +// Returns empty string if not found. +func GetRequestID(ctx context.Context) string { + if ctx == nil { + return "" + } + if id, ok := ctx.Value(requestIDKey{}).(string); ok { + return id + } + return "" +} + +// SetGinRequestID stores the request ID in the Gin context. +func SetGinRequestID(c *gin.Context, requestID string) { + if c != nil { + c.Set(ginRequestIDKey, requestID) + } +} + +// GetGinRequestID retrieves the request ID from the Gin context. +func GetGinRequestID(c *gin.Context) string { + if c == nil { + return "" + } + if id, exists := c.Get(ginRequestIDKey); exists { + if s, ok := id.(string); ok { + return s + } + } + return "" +} diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 20f855856..960e55ace 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -14,6 +14,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -217,13 +218,39 @@ func (h *BaseAPIHandler) GetAlt(c *gin.Context) string { // Parameters: // - handler: The API handler associated with the request. // - c: The Gin context of the current request. -// - ctx: The parent context. +// - ctx: The parent context (caller values/deadlines are preserved; request context adds cancellation and request ID). // // Returns: // - context.Context: The new context with cancellation and embedded values. // - APIHandlerCancelFunc: A function to cancel the context and log the response. func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) { - newCtx, cancel := context.WithCancel(ctx) + parentCtx := ctx + if parentCtx == nil { + parentCtx = context.Background() + } + + var requestCtx context.Context + if c != nil && c.Request != nil { + requestCtx = c.Request.Context() + } + + if requestCtx != nil && logging.GetRequestID(parentCtx) == "" { + if requestID := logging.GetRequestID(requestCtx); requestID != "" { + parentCtx = logging.WithRequestID(parentCtx, requestID) + } else if requestID := logging.GetGinRequestID(c); requestID != "" { + parentCtx = logging.WithRequestID(parentCtx, requestID) + } + } + newCtx, cancel := context.WithCancel(parentCtx) + if requestCtx != nil && requestCtx != parentCtx { + go func() { + select { + case <-requestCtx.Done(): + cancel() + case <-newCtx.Done(): + } + }() + } newCtx = context.WithValue(newCtx, "gin", c) newCtx = context.WithValue(newCtx, "handler", handler) return newCtx, func(params ...interface{}) { diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index b3876de66..e2d4bbf10 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -12,6 +12,7 @@ import ( "time" "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -389,17 +390,18 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req accountType, accountInfo := auth.AccountInfo() proxyInfo := auth.ProxyInfo() + entry := logEntryWithRequestID(ctx) if accountType == "api_key" { if proxyInfo != "" { - log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo) + entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo) } else { - log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model) + entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model) } } else if accountType == "oauth" { if proxyInfo != "" { - log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo) + entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo) } else { - log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model) + entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model) } } @@ -449,17 +451,18 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, accountType, accountInfo := auth.AccountInfo() proxyInfo := auth.ProxyInfo() + entry := logEntryWithRequestID(ctx) if accountType == "api_key" { if proxyInfo != "" { - log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo) + entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo) } else { - log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model) + entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model) } } else if accountType == "oauth" { if proxyInfo != "" { - log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo) + entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo) } else { - log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model) + entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model) } } @@ -509,17 +512,18 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string accountType, accountInfo := auth.AccountInfo() proxyInfo := auth.ProxyInfo() + entry := logEntryWithRequestID(ctx) if accountType == "api_key" { if proxyInfo != "" { - log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo) + entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo) } else { - log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model) + entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model) } } else if accountType == "oauth" { if proxyInfo != "" { - log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo) + entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo) } else { - log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model) + entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model) } } @@ -1606,6 +1610,17 @@ type RequestPreparer interface { PrepareRequest(req *http.Request, auth *Auth) error } +// logEntryWithRequestID returns a logrus entry with request_id field if available in context. +func logEntryWithRequestID(ctx context.Context) *log.Entry { + if ctx == nil { + return log.NewEntry(log.StandardLogger()) + } + if reqID := logging.GetRequestID(ctx); reqID != "" { + return log.WithField("request_id", reqID) + } + return log.NewEntry(log.StandardLogger()) +} + // InjectCredentials delegates per-provider HTTP request preparation when supported. // If the registered executor for the auth provider implements RequestPreparer, // it will be invoked to modify the request (e.g., add headers).