diff --git a/.github/workflows/release-pipeline.yml b/.github/workflows/release-pipeline.yml index 1fa93c820..5f521e16f 100644 --- a/.github/workflows/release-pipeline.yml +++ b/.github/workflows/release-pipeline.yml @@ -3,7 +3,7 @@ name: Release Pipeline # Triggers automatically on push to main when any version file changes on: push: - branches: ["main"] + branches: ["main", "v1.4.0"] # Prevent concurrent runs concurrency: @@ -594,7 +594,7 @@ jobs: fi # Build the message with proper formatting - MESSAGE=$(printf "🚀 **Release Pipeline Complete**\n\n**Components:**\n• Core: %s\n• Framework: %s\n• Plugins: %s\n• Bifrost HTTP: %s\n\n**Details:**\n• Branch: \`main\`\n• Commit: \`%.8s\`\n• Author: %s\n\n[View Workflow Run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" "$CORE_STATUS" "$FRAMEWORK_STATUS" "$PLUGINS_STATUS" "$BIFROST_STATUS" "${{ github.sha }}" "${{ github.actor }}") + MESSAGE=$(printf "🚀 **Release Pipeline Complete**\n\n**Components:**\n• Core: %s\n• Framework: %s\n• Plugins: %s\n• Bifrost HTTP: %s\n\n**Details:**\n• Branch: \`${{ github.ref_name }}\`\n• Commit: \`%.8s\`\n• Author: %s\n\n[View Workflow Run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" "$CORE_STATUS" "$FRAMEWORK_STATUS" "$PLUGINS_STATUS" "$BIFROST_STATUS" "${{ github.sha }}" "${{ github.actor }}") payload="$(jq -n --arg content "$MESSAGE" '{content:$content}')" curl -sS -H "Content-Type: application/json" -d "$payload" "$DISCORD_WEBHOOK" diff --git a/.github/workflows/scripts/push-mintlify-changelog.sh b/.github/workflows/scripts/push-mintlify-changelog.sh index cb322ef42..26184f562 100755 --- a/.github/workflows/scripts/push-mintlify-changelog.sh +++ b/.github/workflows/scripts/push-mintlify-changelog.sh @@ -236,7 +236,12 @@ if ! grep -q "\"$route\"" docs/docs.json; then fi # Pulling again before committing -git pull origin main +CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" +if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" +fi +git pull origin "$CURRENT_BRANCH" # Commit and push changes git add docs/changelogs/$VERSION.mdx git add docs/docs.json @@ -247,4 +252,4 @@ done git config user.name "github-actions[bot]" git config user.email "41898282+github-actions[bot]@users.noreply.github.com" git commit -m "Adds changelog for $VERSION --skip-pipeline" -git push origin main +git push origin "$CURRENT_BRANCH" diff --git a/.github/workflows/scripts/release-bifrost-http.sh b/.github/workflows/scripts/release-bifrost-http.sh index b469b1107..1244ee821 100755 --- a/.github/workflows/scripts/release-bifrost-http.sh +++ b/.github/workflows/scripts/release-bifrost-http.sh @@ -358,7 +358,12 @@ echo "✅ Transport build validation successful" # Commit and push changes if any # First, pull latest changes to avoid conflicts -git pull origin main +CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" +if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" +fi +git pull origin "$CURRENT_BRANCH" # Stage any changes made to transports/ git add transports/ diff --git a/.github/workflows/scripts/release-framework.sh b/.github/workflows/scripts/release-framework.sh index 2a4f88c7c..e9f469adf 100755 --- a/.github/workflows/scripts/release-framework.sh +++ b/.github/workflows/scripts/release-framework.sh @@ -111,6 +111,10 @@ if ! git diff --cached --quiet; then git commit -m "framework: bump core to $CORE_VERSION --skip-pipeline" # Push the bump so go.mod/go.sum changes are recorded on the branch CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" + if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" + fi git push origin "$CURRENT_BRANCH" echo "🔧 Pushed framework bump to $CURRENT_BRANCH" else diff --git a/core/bifrost.go b/core/bifrost.go index 4d8a021b8..73af3c347 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -67,6 +67,7 @@ type Bifrost struct { pluginPipelinePool sync.Pool // Pool for PluginPipeline objects bifrostRequestPool sync.Pool // Pool for BifrostRequest objects logger schemas.Logger // logger instance, default logger is used if not provided + tracer atomic.Value // tracer for distributed tracing (stores schemas.Tracer, NoOpTracer if not configured) mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured) mcpInitOnce sync.Once // Ensures MCP manager is initialized only once dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. @@ -77,12 +78,32 @@ type Bifrost struct { type PluginPipeline struct { plugins []schemas.Plugin logger schemas.Logger + tracer schemas.Tracer // Number of PreHooks that were executed (used to determine which PostHooks to run in reverse order) executedPreHooks int // Errors from PreHooks and PostHooks preHookErrors []error postHookErrors []error + + // Streaming post-hook timing accumulation (for aggregated spans) + postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name + postHookPluginOrder []string // order in which post-hooks ran (for nested span creation) + chunkCount int +} + +// pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks +type pluginTimingAccumulator struct { + totalDuration time.Duration + invocations int + errors int +} + +// tracerWrapper wraps a Tracer to ensure atomic.Value stores consistent types. +// This is necessary because atomic.Value.Store() panics if called with values +// of different concrete types, even if they implement the same interface. +type tracerWrapper struct { + tracer schemas.Tracer } // Global logger instance which is set in the Init function @@ -104,6 +125,13 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { } providerUtils.SetLogger(config.Logger) + + // Initialize tracer (use NoOpTracer if not provided) + tracer := config.Tracer + if tracer == nil { + tracer = schemas.DefaultTracer() + } + bifrostCtx, cancel := context.WithCancel(ctx) bifrost := &Bifrost{ ctx: bifrostCtx, @@ -115,6 +143,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { keySelector: config.KeySelector, logger: config.Logger, } + bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) bifrost.plugins.Store(&config.Plugins) // Initialize providers slice @@ -221,6 +250,20 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { return bifrost, nil } +// SetTracer sets the tracer for the Bifrost instance. +func (bifrost *Bifrost) SetTracer(tracer schemas.Tracer) { + if tracer == nil { + // Fall back to no-op tracer if not provided + tracer = schemas.DefaultTracer() + } + bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) +} + +// getTracer returns the tracer from atomic storage with type assertion. +func (bifrost *Bifrost) getTracer() schemas.Tracer { + return bifrost.tracer.Load().(*tracerWrapper).tracer +} + // ReloadConfig reloads the config from DB // Currently we only update account and drop excess requests // We will keep on adding other aspects as required @@ -318,7 +361,7 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr response, bifrostErr := executeRequestWithRetries(&ctx, config, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return provider.ListModels(ctx, keys, request) - }, schemas.ListModelsRequest, req.Provider, "") + }, schemas.ListModelsRequest, req.Provider, "", bifrost.getTracer(), nil) if bifrostErr != nil { bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ RequestType: schemas.ListModelsRequest, @@ -1971,7 +2014,6 @@ func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecution return nil } - // PROVIDER MANAGEMENT // createBaseProvider creates a provider based on the base provider type @@ -2351,9 +2393,19 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR bifrost.logger.Debug(fmt.Sprintf("trying fallback provider %s with model %s", fallback.Provider, fallback.Model)) ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) + // Start span for fallback attempt + tracer := bifrost.getTracer() + spanCtx, handle := tracer.StartSpan(ctx, fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model), schemas.SpanKindFallback) + tracer.SetAttribute(handle, schemas.AttrProviderName, string(fallback.Provider)) + tracer.SetAttribute(handle, schemas.AttrRequestModel, fallback.Model) + tracer.SetAttribute(handle, "fallback.index", i+1) + ctx = spanCtx + fallbackReq := bifrost.prepareFallbackRequest(req, fallback) if fallbackReq == nil { bifrost.logger.Debug(fmt.Sprintf("fallback provider %s with model %s is nil", fallback.Provider, fallback.Model)) + tracer.SetAttribute(handle, "error", "fallback request preparation failed") + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback request preparation failed") continue } @@ -2361,9 +2413,16 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR result, fallbackErr := bifrost.tryRequest(ctx, fallbackReq) if fallbackErr == nil { bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + tracer.EndSpan(handle, schemas.SpanStatusOk, "") return result, nil } + // End span with error status + if fallbackErr.Error != nil { + tracer.SetAttribute(handle, "error", fallbackErr.Error.Message) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback failed") + // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -2433,8 +2492,18 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackIndex, i+1) ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) + // Start span for fallback attempt + tracer := bifrost.getTracer() + spanCtx, handle := tracer.StartSpan(ctx, fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model), schemas.SpanKindFallback) + tracer.SetAttribute(handle, schemas.AttrProviderName, string(fallback.Provider)) + tracer.SetAttribute(handle, schemas.AttrRequestModel, fallback.Model) + tracer.SetAttribute(handle, "fallback.index", i+1) + ctx = spanCtx + fallbackReq := bifrost.prepareFallbackRequest(req, fallback) if fallbackReq == nil { + tracer.SetAttribute(handle, "error", "fallback request preparation failed") + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback request preparation failed") continue } @@ -2442,9 +2511,16 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi result, fallbackErr := bifrost.tryStreamRequest(ctx, fallbackReq) if fallbackErr == nil { bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + tracer.EndSpan(handle, schemas.SpanStatusOk, "") return result, nil } + // End span with error status + if fallbackErr.Error != nil { + tracer.SetAttribute(handle, "error", fallbackErr.Error.Message) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback failed") + // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -2771,6 +2847,8 @@ func executeRequestWithRetries[T any]( requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, + tracer schemas.Tracer, + req *schemas.BifrostRequest, ) (T, *schemas.BifrostError) { var result T var bifrostError *schemas.BifrostError @@ -2799,9 +2877,76 @@ func executeRequestWithRetries[T any]( logger.Debug("attempting %s request for provider %s", requestType, providerKey) + // Start span for LLM call (or retry attempt) + var spanName string + var spanKind schemas.SpanKind + if attempts > 0 { + spanName = fmt.Sprintf("retry.attempt.%d", attempts) + spanKind = schemas.SpanKindRetry + } else { + spanName = "llm.call" + spanKind = schemas.SpanKindLLMCall + } + spanCtx, handle := tracer.StartSpan(*ctx, spanName, spanKind) + tracer.SetAttribute(handle, schemas.AttrProviderName, string(providerKey)) + tracer.SetAttribute(handle, schemas.AttrRequestModel, model) + tracer.SetAttribute(handle, "request.type", string(requestType)) + if attempts > 0 { + tracer.SetAttribute(handle, "retry.count", attempts) + } + + // Populate LLM request attributes (messages, parameters, etc.) + if req != nil { + tracer.PopulateLLMRequestAttributes(handle, req) + } + + // Update context with span ID + *ctx = spanCtx + + // Store tracer in context BEFORE calling requestHandler, so streaming goroutines + // have access to it for completing deferred spans when the stream ends. + // The streaming goroutine captures the context when it starts, so these values + // must be set before requestHandler() is called. + *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyTracer, tracer) + + // Record stream start time for TTFT calculation (only for streaming requests) + // This is also used by RunPostHooks to detect streaming mode + if IsStreamRequestType(requestType) { + streamStartTime := time.Now() + *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyStreamStartTime, streamStartTime) + } + // Attempt the request result, bifrostError = requestHandler() + // Check if result is a streaming channel - if so, defer span completion + if _, isStreamChan := any(result).(chan *schemas.BifrostStream); isStreamChan { + // For streaming requests, store the span handle in TraceStore keyed by trace ID + // This allows the provider's streaming goroutine to retrieve it later + if traceID, ok := (*ctx).Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.StoreDeferredSpan(traceID, handle) + } + // Don't end the span here - it will be ended when streaming completes + } else { + // Populate LLM response attributes for non-streaming responses + if resp, ok := any(result).(*schemas.BifrostResponse); ok { + tracer.PopulateLLMResponseAttributes(handle, resp, bifrostError) + } + + // End span with appropriate status + if bifrostError != nil { + if bifrostError.Error != nil { + tracer.SetAttribute(handle, "error", bifrostError.Error.Message) + } + if bifrostError.StatusCode != nil { + tracer.SetAttribute(handle, "status_code", *bifrostError.StatusCode) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "request failed") + } else { + tracer.EndSpan(handle, schemas.SpanStatusOk, "") + } + } + logger.Debug("request %s for provider %s completed", requestType, providerKey) // Check if successful or if we should retry @@ -2897,8 +3042,16 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } } else { // Use the custom provider name for actual key selection, but pass base provider type for key validation + // Start span for key selection + keyTracer := bifrost.getTracer() + keySpanCtx, keyHandle := keyTracer.StartSpan(req.Context, "key.selection", schemas.SpanKindInternal) + keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(provider.GetProviderKey())) + keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model) + key, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) if err != nil { + keyTracer.SetAttribute(keyHandle, "error", err.Error()) + keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error()) bifrost.logger.Debug("error selecting key for model %s: %v", model, err) req.Err <- schemas.BifrostError{ IsBifrostError: false, @@ -2914,6 +3067,11 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } continue } + keyTracer.SetAttribute(keyHandle, "key.id", key.ID) + keyTracer.SetAttribute(keyHandle, "key.name", key.Name) + keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "") + // Update context with span ID for subsequent operations + req.Context = keySpanCtx req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyID, key.ID) req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyName, key.Name) } @@ -2930,20 +3088,32 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } return resp, nil } + // Store a finalizer callback to create aggregated post-hook spans at stream end + // This closure captures the pipeline reference and releases it after finalization + postHookSpanFinalizer := func(ctx context.Context) { + pipeline.FinalizeStreamingPostHookSpans(ctx) + // Release the pipeline AFTER finalizing spans (not before streaming completes) + bifrost.releasePluginPipeline(pipeline) + } + req.Context = context.WithValue(req.Context, schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer) } // Execute request with retries + reqTracer := bifrost.getTracer() if IsStreamRequestType(req.RequestType) { stream, bifrostError = executeRequestWithRetries(&req.Context, config, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { return bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner) - }, req.RequestType, provider.GetProviderKey(), model) + }, req.RequestType, provider.GetProviderKey(), model, reqTracer, &req.BifrostRequest) } else { result, bifrostError = executeRequestWithRetries(&req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) { return bifrost.handleProviderRequest(provider, req, key, keys) - }, req.RequestType, provider.GetProviderKey(), model) + }, req.RequestType, provider.GetProviderKey(), model, reqTracer, &req.BifrostRequest) } - if pipeline != nil { + // Release pipeline immediately for non-streaming requests only + // For streaming, the pipeline is released in the postHookSpanFinalizer after streaming completes + // Exception: if streaming request has an error, release immediately since finalizer won't be called + if pipeline != nil && (!IsStreamRequestType(req.RequestType) || bifrostError != nil) { bifrost.releasePluginPipeline(pipeline) } @@ -3162,12 +3332,33 @@ func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostR *ctx = pluginCtx.GetParentCtxWithUserValues() }() for i, plugin := range p.plugins { - p.logger.Debug("running pre-hook for plugin %s", plugin.GetName()) + pluginName := plugin.GetName() + p.logger.Debug("running pre-hook for plugin %s", pluginName) + + // Start span for this plugin's PreHook + spanCtx, handle := p.tracer.StartSpan(pluginCtx, fmt.Sprintf("plugin.%s.prehook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) + // Update pluginCtx with span context for nested operations + if spanCtx != nil { + if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { + pluginCtx.SetValue(schemas.BifrostContextKeySpanID, spanID) + } + } + req, shortCircuit, err = plugin.PreHook(pluginCtx, req) + + // End span with appropriate status if err != nil { + p.tracer.SetAttribute(handle, "error", err.Error()) + p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) p.preHookErrors = append(p.preHookErrors, err) - p.logger.Warn("error in PreHook for plugin %s: %v", plugin.GetName(), err) + p.logger.Warn("error in PreHook for plugin %s: %v", pluginName, err) + } else if shortCircuit != nil { + p.tracer.SetAttribute(handle, "short_circuit", true) + p.tracer.EndSpan(handle, schemas.SpanStatusOk, "short-circuit") + } else { + p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") } + p.executedPreHooks = i + 1 if shortCircuit != nil { return req, shortCircuit, p.executedPreHooks // short-circuit: only plugins up to and including i ran @@ -3180,6 +3371,7 @@ func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostR // Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response). // Returns the final response and error after all hooks. If both are set, error takes precedence unless error is nil. // runFrom is the count of plugins whose PreHooks ran; PostHooks will run in reverse from index (runFrom - 1) down to 0 +// For streaming requests, it accumulates timing per plugin instead of creating individual spans per chunk. func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostResponse, *schemas.BifrostError) { // Defensive: ensure count is within valid bounds if runFrom < 0 { @@ -3188,20 +3380,60 @@ func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.Bifros if runFrom > len(p.plugins) { runFrom = len(p.plugins) } + + // Detect streaming mode - if StreamStartTime is set, we're in a streaming context + isStreaming := (*ctx).Value(schemas.BifrostContextKeyStreamStartTime) != nil + var err error pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(*ctx, 10*time.Second) defer cancel() for i := runFrom - 1; i >= 0; i-- { plugin := p.plugins[i] - p.logger.Debug("running post-hook for plugin %s", plugin.GetName()) - resp, bifrostErr, err = plugin.PostHook(pluginCtx, resp, bifrostErr) - if err != nil { - p.postHookErrors = append(p.postHookErrors, err) - p.logger.Warn("error in PostHook for plugin %s: %v", plugin.GetName(), err) + pluginName := plugin.GetName() + p.logger.Debug("running post-hook for plugin %s", pluginName) + + if isStreaming { + // For streaming: accumulate timing, don't create individual spans per chunk + start := time.Now() + resp, bifrostErr, err = plugin.PostHook(pluginCtx, resp, bifrostErr) + duration := time.Since(start) + + p.accumulatePluginTiming(pluginName, duration, err != nil) + if err != nil { + p.postHookErrors = append(p.postHookErrors, err) + p.logger.Warn("error in PostHook for plugin %s: %v", pluginName, err) + } + } else { + // For non-streaming: create span per plugin (existing behavior) + spanCtx, handle := p.tracer.StartSpan(pluginCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) + // Update pluginCtx with span context for nested operations + if spanCtx != nil { + if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { + pluginCtx.SetValue(schemas.BifrostContextKeySpanID, spanID) + } + } + + resp, bifrostErr, err = plugin.PostHook(pluginCtx, resp, bifrostErr) + + // End span with appropriate status + if err != nil { + p.tracer.SetAttribute(handle, "error", err.Error()) + p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) + p.postHookErrors = append(p.postHookErrors, err) + p.logger.Warn("error in PostHook for plugin %s: %v", pluginName, err) + } else { + p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") + } } // If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that // If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that } + + // Increment chunk count for streaming + if isStreaming { + p.chunkCount++ + } + // Capturing plugin ctx values and putting them in the request context *ctx = pluginCtx.GetParentCtxWithUserValues() // Final logic: if both are set, error takes precedence, unless error is nil @@ -3221,6 +3453,91 @@ func (p *PluginPipeline) resetPluginPipeline() { p.executedPreHooks = 0 p.preHookErrors = p.preHookErrors[:0] p.postHookErrors = p.postHookErrors[:0] + // Reset streaming timing accumulation + p.chunkCount = 0 + if p.postHookTimings != nil { + clear(p.postHookTimings) + } + p.postHookPluginOrder = p.postHookPluginOrder[:0] +} + +// accumulatePluginTiming accumulates timing for a plugin during streaming +func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time.Duration, hasError bool) { + if p.postHookTimings == nil { + p.postHookTimings = make(map[string]*pluginTimingAccumulator) + } + timing, ok := p.postHookTimings[pluginName] + if !ok { + timing = &pluginTimingAccumulator{} + p.postHookTimings[pluginName] = timing + // Track order on first occurrence (first chunk) + p.postHookPluginOrder = append(p.postHookPluginOrder, pluginName) + } + timing.totalDuration += duration + timing.invocations++ + if hasError { + timing.errors++ + } +} + +// FinalizeStreamingPostHookSpans creates aggregated spans for each plugin after streaming completes. +// This should be called once at the end of streaming to create one span per plugin with average timing. +// Spans are nested to mirror the pre-hook hierarchy (each post-hook is a child of the previous one). +func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) { + if p.postHookTimings == nil || len(p.postHookTimings) == 0 || len(p.postHookPluginOrder) == 0 { + return + } + + // Collect handles and timing info to end spans in reverse order + type spanInfo struct { + handle schemas.SpanHandle + hasErrors bool + } + spans := make([]spanInfo, 0, len(p.postHookPluginOrder)) + currentCtx := ctx + + // Start spans in execution order (nested: each is a child of the previous) + for _, pluginName := range p.postHookPluginOrder { + timing, ok := p.postHookTimings[pluginName] + if !ok || timing.invocations == 0 { + continue + } + + // Create span as child of the previous span (nested hierarchy) + newCtx, handle := p.tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) + if handle == nil { + continue + } + + // Calculate average duration in milliseconds + avgMs := float64(timing.totalDuration.Milliseconds()) / float64(timing.invocations) + + // Set aggregated attributes + p.tracer.SetAttribute(handle, schemas.AttrPluginInvocations, timing.invocations) + p.tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs) + p.tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, timing.totalDuration.Milliseconds()) + + if timing.errors > 0 { + p.tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, timing.errors) + } + + spans = append(spans, spanInfo{handle: handle, hasErrors: timing.errors > 0}) + currentCtx = newCtx + } + + // End spans in reverse order (innermost first, like unwinding a call stack) + for i := len(spans) - 1; i >= 0; i-- { + if spans[i].hasErrors { + p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed") + } else { + p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "") + } + } +} + +// GetChunkCount returns the number of chunks processed during streaming +func (p *PluginPipeline) GetChunkCount() int { + return p.chunkCount } // getPluginPipeline gets a PluginPipeline from the pool and configures it @@ -3228,6 +3545,7 @@ func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline { pipeline := bifrost.pluginPipelinePool.Get().(*PluginPipeline) pipeline.plugins = *bifrost.plugins.Load() pipeline.logger = bifrost.logger + pipeline.tracer = bifrost.getTracer() return pipeline } @@ -3625,6 +3943,11 @@ func (bifrost *Bifrost) Shutdown() { } } + // Stop the tracerWrapper to clean up background goroutines + if tracerWrapper := bifrost.tracer.Load().(*tracerWrapper); tracerWrapper != nil && tracerWrapper.tracer != nil { + tracerWrapper.tracer.Stop() + } + // Cleanup plugins for _, plugin := range *bifrost.plugins.Load() { err := plugin.Cleanup() diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 6b642d632..7d6d29fe0 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -68,6 +68,8 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + schemas.DefaultTracer(), + nil, ) if callCount != 1 { @@ -101,6 +103,8 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + schemas.DefaultTracer(), + nil, ) if callCount != 3 { @@ -134,6 +138,8 @@ func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + schemas.DefaultTracer(), + nil, ) // Should try: initial + 2 retries = 3 total attempts @@ -196,6 +202,8 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + schemas.DefaultTracer(), + nil, ) if callCount != 1 { @@ -268,6 +276,8 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + schemas.DefaultTracer(), + nil, ) // Should try: initial + 1 retry = 2 total attempts @@ -496,6 +506,8 @@ func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) { schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + schemas.DefaultTracer(), + nil, ) // Verify call progression diff --git a/core/changelog.md b/core/changelog.md index b717cf5e0..1335fa236 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -1,3 +1,35 @@ - feat: added code mode to mcp - feat: added health monitoring to mcp -- feat: added responses format tool execution support to mcp \ No newline at end of file +- feat: added responses format tool execution support to mcp +- feat: adds central tracer for e2e tracing + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor removed, replaced with HTTPTransportMiddleware** + + The `TransportInterceptor` method has been removed from the `Plugin` interface in `schemas/plugin.go`. All plugins must now implement `HTTPTransportMiddleware()` instead. + + **Old API (removed in core v1.3.0):** + ```go + TransportInterceptor(ctx *BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + ``` + + **New API (core v1.3.0+):** + ```go + HTTPTransportMiddleware() BifrostHTTPMiddleware + // where BifrostHTTPMiddleware = func(next fasthttp.RequestHandler) fasthttp.RequestHandler + ``` + + **Key changes:** + - Method renamed: `TransportInterceptor` -> `HTTPTransportMiddleware` + - Return type changed: Now returns a middleware function instead of modified headers/body + - New import required: `github.com/valyala/fasthttp` + - Flow control: Must call `next(ctx)` explicitly to continue the middleware chain + - New capability: Can now intercept and modify responses (not just requests) + + **Migration for plugin consumers:** + 1. Update your plugin to implement `HTTPTransportMiddleware()` instead of `TransportInterceptor()` + 2. If your plugin doesn't need HTTP transport interception, return `nil` from `HTTPTransportMiddleware()` + 3. Update tests to verify the new middleware signature + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for complete instructions and code examples. \ No newline at end of file diff --git a/core/go.mod b/core/go.mod index 724375d92..a9fc3c141 100644 --- a/core/go.mod +++ b/core/go.mod @@ -47,7 +47,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect - github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect diff --git a/core/go.sum b/core/go.sum index 656faafb2..759671d03 100644 --- a/core/go.sum +++ b/core/go.sum @@ -72,8 +72,7 @@ github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5Nq github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= -github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= diff --git a/core/mcp.go b/core/mcp.go index b409cf49e..be9fb580b 100644 --- a/core/mcp.go +++ b/core/mcp.go @@ -1135,7 +1135,6 @@ func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) ( if config.InProcessServer == nil { return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") } - // Create in-process client directly connected to the provided server inProcessClient, err := client.NewInProcessClient(config.InProcessServer) if err != nil { diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 847e150c8..6ac7d8035 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -877,16 +877,30 @@ func SendInProgressEventResponsesChunk(ctx context.Context, postHookRunner schem // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. +// It also completes the deferred LLM span when the final chunk is sent (StreamEndIndicator is true). func ProcessAndSendResponse( ctx context.Context, postHookRunner schemas.PostHookRunner, response *schemas.BifrostResponse, responseChan chan *schemas.BifrostStream, ) { - // Run post hooks on the response + // Accumulate chunk for tracing (common for all providers) + if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.AddStreamingChunk(traceID, response) + } + } + + // Run post hooks on the response first so span reflects post-processed data processedResponse, processedError := postHookRunner(&ctx, response, nil) if HandleStreamControlSkip(processedError) { + // Even if skipping, complete the deferred span if this is the final chunk + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(&ctx, processedResponse, processedError) + } + } return } @@ -907,12 +921,20 @@ func ProcessAndSendResponse( case <-ctx.Done(): return } + + // Check if this is the final chunk and complete deferred span with post-processed data + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(&ctx, processedResponse, processedError) + } + } } // ProcessAndSendBifrostError handles post-hook processing and sends the bifrost error to the channel. // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. +// It also completes the deferred LLM span when the final chunk is sent (StreamEndIndicator is true). func ProcessAndSendBifrostError( ctx context.Context, postHookRunner schemas.PostHookRunner, @@ -920,10 +942,16 @@ func ProcessAndSendBifrostError( responseChan chan *schemas.BifrostStream, logger schemas.Logger, ) { - // Send scanner error through channel + // Run post hooks first so span reflects post-processed data processedResponse, processedError := postHookRunner(&ctx, nil, bifrostErr) if HandleStreamControlSkip(processedError) { + // Even if skipping, complete the deferred span if this is the final chunk + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(&ctx, processedResponse, processedError) + } + } return } @@ -943,6 +971,13 @@ func ProcessAndSendBifrostError( case responseChan <- streamResponse: case <-ctx.Done(): } + + // Check if this is the final chunk and complete deferred span with post-processed data + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(&ctx, processedResponse, processedError) + } + } } // ProcessAndSendError handles post-hook processing and sends the error to the channel. @@ -1385,3 +1420,94 @@ func GetBudgetTokensFromReasoningEffort( return budget, nil } + +// completeDeferredSpan completes the deferred LLM span for streaming requests. +// This is called when the final chunk is processed (when StreamEndIndicator is true). +// It retrieves the deferred span handle from TraceStore using the trace ID from context, +// populates response attributes from accumulated chunks, and ends the span. +func completeDeferredSpan(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) { + if ctx == nil { + return + } + + // Get the trace ID from context (this IS available in the provider's goroutine) + traceID, ok := (*ctx).Value(schemas.BifrostContextKeyTraceID).(string) + if !ok || traceID == "" { + return + } + + // Get the tracer from context + tracerVal := (*ctx).Value(schemas.BifrostContextKeyTracer) + if tracerVal == nil { + return + } + tracer, ok := tracerVal.(schemas.Tracer) + if !ok || tracer == nil { + return + } + + // Get the deferred span handle from TraceStore using trace ID + handle := tracer.GetDeferredSpanHandle(traceID) + if handle == nil { + return + } + + // Set total latency from the final chunk + if result != nil { + extraFields := result.GetExtraFields() + if extraFields.Latency > 0 { + tracer.SetAttribute(handle, "gen_ai.response.total_latency_ms", extraFields.Latency) + } + } + + // Get accumulated response with full data (content, tool calls, reasoning, etc.) + // This builds a complete BifrostResponse from all the streaming chunks + accumulatedResp, ttftMs, chunkCount := tracer.GetAccumulatedChunks(traceID) + if accumulatedResp != nil { + // Use accumulated response for attributes (includes full content, tool calls, etc.) + tracer.PopulateLLMResponseAttributes(handle, accumulatedResp, err) + + // Set Time to First Token (TTFT) attribute + if ttftMs > 0 { + tracer.SetAttribute(handle, schemas.AttrTimeToFirstToken, ttftMs) + } + + // Set total chunks attribute + if chunkCount > 0 { + tracer.SetAttribute(handle, schemas.AttrTotalChunks, chunkCount) + } + } else if result != nil { + // Fall back to final chunk if no accumulated data (shouldn't happen normally) + tracer.PopulateLLMResponseAttributes(handle, result, err) + } + + // Finalize aggregated post-hook spans before ending the LLM span + // This creates one span per plugin with average execution time + // We need to set the llm.call span ID in context so post-hook spans become its children + if finalizer, ok := (*ctx).Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { + // Get the deferred span ID (the llm.call span) to set as parent for post-hook spans + spanID := tracer.GetDeferredSpanID(traceID) + if spanID != "" { + finalizerCtx := context.WithValue(*ctx, schemas.BifrostContextKeySpanID, spanID) + finalizer(finalizerCtx) + } else { + finalizer(*ctx) + } + } + + // End span with appropriate status + if err != nil { + if err.Error != nil { + tracer.SetAttribute(handle, "error", err.Error.Message) + } + if err.StatusCode != nil { + tracer.SetAttribute(handle, "status_code", *err.StatusCode) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "streaming request failed") + } else { + tracer.EndSpan(handle, schemas.SpanStatusOk, "") + } + + // Clear the deferred span from TraceStore + tracer.ClearDeferredSpan(traceID) +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 6646be966..94a12ab2b 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -22,6 +22,7 @@ type BifrostConfig struct { Account Account Plugins []Plugin Logger Logger + Tracer Tracer // Tracer for distributed tracing (nil = NoOpTracer) InitialPoolSize int // Initial pool size for sync pools in Bifrost. Higher values will reduce memory allocations but will increase memory usage. DropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. MCPConfig *MCPConfig // MCP (Model Context Protocol) configuration for tool integration @@ -140,6 +141,13 @@ const ( BifrostMCPAgentOriginalRequestID BifrostContextKey = "bifrost-mcp-agent-original-request-id" // string (to store the original request ID for MCP agent mode) BifrostContextKeyStructuredOutputToolName BifrostContextKey = "bifrost-structured-output-tool-name" // string (to store the name of the structured output tool (set by bifrost)) BifrostContextKeyUserAgent BifrostContextKey = "bifrost-user-agent" // string (set by bifrost) + BifrostContextKeyTraceID BifrostContextKey = "bifrost-trace-id" // string (trace ID for distributed tracing - set by tracing middleware) + BifrostContextKeySpanID BifrostContextKey = "bifrost-span-id" // string (current span ID for child span creation - set by tracer) + BifrostContextKeyStreamStartTime BifrostContextKey = "bifrost-stream-start-time" // time.Time (start time for streaming TTFT calculation - set by bifrost) + BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost) + BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers) + BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) + BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) ) // NOTE: for custom plugin implementation dealing with streaming short circuit, diff --git a/core/schemas/context.go b/core/schemas/context.go index bcfae8108..4af3a6370 100644 --- a/core/schemas/context.go +++ b/core/schemas/context.go @@ -23,6 +23,7 @@ var reservedKeys = []any{ BifrostContextKeySkipKeySelection, BifrostContextKeyExtraHeaders, BifrostContextKeyURLPath, + BifrostContextKeyDeferTraceCompletion, } // BifrostContext is a custom context.Context implementation that tracks user-set values. diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index 2b7d5be93..6aa54a176 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -1,6 +1,12 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas +import ( + "context" + + "github.com/valyala/fasthttp" +) + // PluginShortCircuit represents a plugin's decision to short-circuit the normal flow. // It can contain either a response (success short-circuit), a stream (streaming short-circuit), or an error (error short-circuit). type PluginShortCircuit struct { @@ -27,6 +33,10 @@ type PluginStatus struct { Logs []string `json:"logs"` } +// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport +// It follows the standard pattern: receives the next handler and returns a new handler +type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler + // Plugin defines the interface for Bifrost plugins. // Plugins can intercept and modify requests and responses at different stages // of the processing pipeline. @@ -35,7 +45,7 @@ type PluginStatus struct { // PostHooks are executed in the reverse order of PreHooks. // // Execution order: -// 1. TransportInterceptor (HTTP transport only, modifies raw headers/body before entering Bifrost core) +// 1. HTTPTransportMiddleware (HTTP transport only, modifies raw headers/body before entering Bifrost core) // 2. PreHook (executed in registration order) // 3. Provider call // 4. PostHook (executed in reverse order of PreHooks) @@ -62,11 +72,11 @@ type Plugin interface { // GetName returns the name of the plugin. GetName() string - // TransportInterceptor is called at the HTTP transport layer before requests enter Bifrost core. - // It allows plugins to modify raw HTTP headers and body before transformation into BifrostRequest. + // HTTPTransportMiddleware is called at the HTTP transport layer before requests enter Bifrost core. + // It allows plugins to modify the request and response before they are processed by the next middleware. // Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly. - // Returns modified headers, modified body, and any error that occurred during interception. - TransportInterceptor(ctx *BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + // Returns a new handler that will be called next in the middleware chain. + HTTPTransportMiddleware() BifrostHTTPMiddleware // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. @@ -95,3 +105,33 @@ type PluginConfig struct { Version *int16 `json:"version,omitempty"` Config any `json:"config,omitempty"` } + +// ObservabilityPlugin is an interface for plugins that receive completed traces +// for forwarding to observability backends (e.g., OTEL collectors, Datadog, etc.) +// +// ObservabilityPlugins are called asynchronously after the HTTP response has been +// written to the wire, ensuring they don't add latency to the client response. +// +// Plugins implementing this interface will: +// 1. Continue to work as regular plugins via PreHook/PostHook +// 2. Additionally receive completed traces via the Inject method +// +// Example backends: OpenTelemetry collectors, Datadog, Jaeger, Maxim, etc. +// +// Note: Go type assertion (plugin.(ObservabilityPlugin)) is used to identify +// plugins implementing this interface - no marker method is needed. +type ObservabilityPlugin interface { + Plugin + + // Inject receives a completed trace for forwarding to observability backends. + // This method is called asynchronously after the response has been written to the client. + // The trace contains all spans that were added during request processing. + // + // Implementations should: + // - Convert the trace to their backend's format + // - Send the trace to the backend (can be async) + // - Handle errors gracefully (log and continue) + // + // The context passed is a fresh background context, not the request context. + Inject(ctx context.Context, trace *Trace) error +} diff --git a/core/schemas/trace.go b/core/schemas/trace.go new file mode 100644 index 000000000..6b0b805d4 --- /dev/null +++ b/core/schemas/trace.go @@ -0,0 +1,333 @@ +// Package schemas defines the core schemas and types used by the Bifrost system. +package schemas + +import ( + "sync" + "time" +) + +// Trace represents a distributed trace that captures the full lifecycle of a request +type Trace struct { + TraceID string // Unique identifier for this trace + ParentID string // Parent trace ID from incoming W3C traceparent header + RootSpan *Span // The root span of this trace + Spans []*Span // All spans in this trace + StartTime time.Time // When the trace started + EndTime time.Time // When the trace completed + Attributes map[string]any // Additional attributes for the trace + mu sync.Mutex // Mutex for thread-safe span operations +} + +// AddSpan adds a span to the trace in a thread-safe manner +func (t *Trace) AddSpan(span *Span) { + t.mu.Lock() + defer t.mu.Unlock() + t.Spans = append(t.Spans, span) +} + +// GetSpan retrieves a span by ID +func (t *Trace) GetSpan(spanID string) *Span { + t.mu.Lock() + defer t.mu.Unlock() + for _, span := range t.Spans { + if span.SpanID == spanID { + return span + } + } + return nil +} + +// Reset clears the trace for reuse from pool +func (t *Trace) Reset() { + t.TraceID = "" + t.ParentID = "" + t.RootSpan = nil + t.Spans = t.Spans[:0] + t.StartTime = time.Time{} + t.EndTime = time.Time{} + t.Attributes = nil +} + +// Span represents a single operation within a trace +type Span struct { + SpanID string // Unique identifier for this span + ParentID string // Parent span ID (empty for root span) + TraceID string // The trace this span belongs to + Name string // Name of the operation + Kind SpanKind // Type of span (LLM call, plugin, etc.) + StartTime time.Time // When the span started + EndTime time.Time // When the span completed + Status SpanStatus // Status of the operation + StatusMsg string // Optional status message (for errors) + Attributes map[string]any // Additional attributes for the span + Events []SpanEvent // Events that occurred during the span + mu sync.Mutex // Mutex for thread-safe attribute operations +} + +// SetAttribute sets an attribute on the span in a thread-safe manner +func (s *Span) SetAttribute(key string, value any) { + s.mu.Lock() + defer s.mu.Unlock() + if s.Attributes == nil { + s.Attributes = make(map[string]any) + } + s.Attributes[key] = value +} + +// AddEvent adds an event to the span in a thread-safe manner +func (s *Span) AddEvent(event SpanEvent) { + s.mu.Lock() + defer s.mu.Unlock() + s.Events = append(s.Events, event) +} + +// End marks the span as complete with the given status +func (s *Span) End(status SpanStatus, statusMsg string) { + s.EndTime = time.Now() + s.Status = status + s.StatusMsg = statusMsg +} + +// Reset clears the span for reuse from pool +func (s *Span) Reset() { + s.SpanID = "" + s.ParentID = "" + s.TraceID = "" + s.Name = "" + s.Kind = SpanKindUnspecified + s.StartTime = time.Time{} + s.EndTime = time.Time{} + s.Status = SpanStatusUnset + s.StatusMsg = "" + s.Attributes = nil + s.Events = s.Events[:0] +} + +// SpanEvent represents a time-stamped event within a span +type SpanEvent struct { + Name string // Name of the event + Timestamp time.Time // When the event occurred + Attributes map[string]any // Additional attributes for the event +} + +// SpanKind represents the type of operation a span represents +// These are LLM-specific kinds designed for AI gateway observability +type SpanKind string + +const ( + // SpanKindUnspecified is the default span kind + SpanKindUnspecified SpanKind = "" + // SpanKindLLMCall represents a call to an LLM provider + SpanKindLLMCall SpanKind = "llm.call" + // SpanKindPlugin represents plugin execution (PreHook/PostHook) + SpanKindPlugin SpanKind = "plugin" + // SpanKindMCPTool represents an MCP tool invocation + SpanKindMCPTool SpanKind = "mcp.tool" + // SpanKindRetry represents a retry attempt + SpanKindRetry SpanKind = "retry" + // SpanKindFallback represents a fallback to another provider + SpanKindFallback SpanKind = "fallback" + // SpanKindHTTPRequest represents the root HTTP request span + SpanKindHTTPRequest SpanKind = "http.request" + // SpanKindEmbedding represents an embedding request + SpanKindEmbedding SpanKind = "embedding" + // SpanKindSpeech represents a text-to-speech request + SpanKindSpeech SpanKind = "speech" + // SpanKindTranscription represents a speech-to-text request + SpanKindTranscription SpanKind = "transcription" + // SpanKindInternal represents internal operations (key selection, etc.) + SpanKindInternal SpanKind = "internal" +) + +// SpanStatus represents the status of a span's operation +type SpanStatus string + +const ( + // SpanStatusUnset indicates status has not been set + SpanStatusUnset SpanStatus = "unset" + // SpanStatusOk indicates the operation completed successfully + SpanStatusOk SpanStatus = "ok" + // SpanStatusError indicates the operation failed + SpanStatusError SpanStatus = "error" +) + +// LLM Attribute Keys (gen_ai.* namespace) +// These follow the OpenTelemetry semantic conventions for GenAI +// and are compatible with both OTEL and Datadog backends. +const ( + // Provider and Model Attributes + AttrProviderName = "gen_ai.provider.name" + AttrRequestModel = "gen_ai.request.model" + + // Request Parameter Attributes + AttrMaxTokens = "gen_ai.request.max_tokens" + AttrTemperature = "gen_ai.request.temperature" + AttrTopP = "gen_ai.request.top_p" + AttrStopSequences = "gen_ai.request.stop_sequences" + AttrPresencePenalty = "gen_ai.request.presence_penalty" + AttrFrequencyPenalty = "gen_ai.request.frequency_penalty" + AttrParallelToolCall = "gen_ai.request.parallel_tool_calls" + AttrRequestUser = "gen_ai.request.user" + AttrBestOf = "gen_ai.request.best_of" + AttrEcho = "gen_ai.request.echo" + AttrLogitBias = "gen_ai.request.logit_bias" + AttrLogProbs = "gen_ai.request.logprobs" + AttrN = "gen_ai.request.n" + AttrSeed = "gen_ai.request.seed" + AttrSuffix = "gen_ai.request.suffix" + AttrDimensions = "gen_ai.request.dimensions" + AttrEncodingFormat = "gen_ai.request.encoding_format" + AttrLanguage = "gen_ai.request.language" + AttrPrompt = "gen_ai.request.prompt" + AttrResponseFormat = "gen_ai.request.response_format" + AttrFormat = "gen_ai.request.format" + AttrVoice = "gen_ai.request.voice" + AttrMultiVoiceConfig = "gen_ai.request.multi_voice_config" + AttrInstructions = "gen_ai.request.instructions" + AttrSpeed = "gen_ai.request.speed" + AttrMessageCount = "gen_ai.request.message_count" + + // Response Attributes + AttrResponseID = "gen_ai.response.id" + AttrResponseModel = "gen_ai.response.model" + AttrFinishReason = "gen_ai.response.finish_reason" + AttrSystemFprint = "gen_ai.response.system_fingerprint" + AttrServiceTier = "gen_ai.response.service_tier" + AttrCreated = "gen_ai.response.created" + AttrObject = "gen_ai.response.object" + AttrTimeToFirstToken = "gen_ai.response.time_to_first_token" + AttrTotalChunks = "gen_ai.response.total_chunks" + + // Plugin Attributes (for aggregated streaming post-hook spans) + AttrPluginInvocations = "plugin.invocation_count" + AttrPluginAvgDurationMs = "plugin.avg_duration_ms" + AttrPluginTotalDurationMs = "plugin.total_duration_ms" + AttrPluginErrorCount = "plugin.error_count" + + // Usage Attributes + AttrPromptTokens = "gen_ai.usage.prompt_tokens" + AttrCompletionTokens = "gen_ai.usage.completion_tokens" + AttrTotalTokens = "gen_ai.usage.total_tokens" + AttrInputTokens = "gen_ai.usage.input_tokens" + AttrOutputTokens = "gen_ai.usage.output_tokens" + AttrUsageCost = "gen_ai.usage.cost" + + // Error Attributes + AttrError = "gen_ai.error" + AttrErrorType = "gen_ai.error.type" + AttrErrorCode = "gen_ai.error.code" + + // Input/Output Attributes + AttrInputText = "gen_ai.input.text" + AttrInputMessages = "gen_ai.input.messages" + AttrInputSpeech = "gen_ai.input.speech" + AttrInputEmbedding = "gen_ai.input.embedding" + AttrOutputMessages = "gen_ai.output.messages" + + // Bifrost Context Attributes + AttrVirtualKeyID = "gen_ai.virtual_key_id" + AttrVirtualKeyName = "gen_ai.virtual_key_name" + AttrSelectedKeyID = "gen_ai.selected_key_id" + AttrSelectedKeyName = "gen_ai.selected_key_name" + AttrTeamID = "gen_ai.team_id" + AttrTeamName = "gen_ai.team_name" + AttrCustomerID = "gen_ai.customer_id" + AttrCustomerName = "gen_ai.customer_name" + AttrNumberOfRetries = "gen_ai.number_of_retries" + AttrFallbackIndex = "gen_ai.fallback_index" + + // Responses API Request Attributes + AttrPromptCacheKey = "gen_ai.request.prompt_cache_key" + AttrReasoningEffort = "gen_ai.request.reasoning_effort" + AttrReasoningSummary = "gen_ai.request.reasoning_summary" + AttrReasoningGenSummary = "gen_ai.request.reasoning_generate_summary" + AttrSafetyIdentifier = "gen_ai.request.safety_identifier" + AttrStore = "gen_ai.request.store" + AttrTextVerbosity = "gen_ai.request.text_verbosity" + AttrTextFormatType = "gen_ai.request.text_format_type" + AttrTopLogProbs = "gen_ai.request.top_logprobs" + AttrToolChoiceType = "gen_ai.request.tool_choice_type" + AttrToolChoiceName = "gen_ai.request.tool_choice_name" + AttrTools = "gen_ai.request.tools" + AttrTruncation = "gen_ai.request.truncation" + + // Responses API Response Attributes + AttrRespInclude = "gen_ai.responses.include" + AttrRespMaxOutputTokens = "gen_ai.responses.max_output_tokens" + AttrRespMaxToolCalls = "gen_ai.responses.max_tool_calls" + AttrRespMetadata = "gen_ai.responses.metadata" + AttrRespPreviousRespID = "gen_ai.responses.previous_response_id" + AttrRespPromptCacheKey = "gen_ai.responses.prompt_cache_key" + AttrRespReasoningText = "gen_ai.responses.reasoning" + AttrRespReasoningEffort = "gen_ai.responses.reasoning_effort" + AttrRespReasoningGenSum = "gen_ai.responses.reasoning_generate_summary" + AttrRespSafetyIdentifier = "gen_ai.responses.safety_identifier" + AttrRespStore = "gen_ai.responses.store" + AttrRespTemperature = "gen_ai.responses.temperature" + AttrRespTextVerbosity = "gen_ai.responses.text_verbosity" + AttrRespTextFormatType = "gen_ai.responses.text_format_type" + AttrRespTopLogProbs = "gen_ai.responses.top_logprobs" + AttrRespTopP = "gen_ai.responses.top_p" + AttrRespToolChoiceType = "gen_ai.responses.tool_choice_type" + AttrRespToolChoiceName = "gen_ai.responses.tool_choice_name" + AttrRespTruncation = "gen_ai.responses.truncation" + AttrRespTools = "gen_ai.responses.tools" + + // Batch Operation Attributes + AttrBatchID = "gen_ai.batch.id" + AttrBatchStatus = "gen_ai.batch.status" + AttrBatchObject = "gen_ai.batch.object" + AttrBatchEndpoint = "gen_ai.batch.endpoint" + AttrBatchInputFileID = "gen_ai.batch.input_file_id" + AttrBatchOutputFileID = "gen_ai.batch.output_file_id" + AttrBatchErrorFileID = "gen_ai.batch.error_file_id" + AttrBatchCompletionWin = "gen_ai.batch.completion_window" + AttrBatchCreatedAt = "gen_ai.batch.created_at" + AttrBatchExpiresAt = "gen_ai.batch.expires_at" + AttrBatchRequestsCount = "gen_ai.batch.requests_count" + AttrBatchDataCount = "gen_ai.batch.data_count" + AttrBatchResultsCount = "gen_ai.batch.results_count" + AttrBatchHasMore = "gen_ai.batch.has_more" + AttrBatchMetadata = "gen_ai.batch.metadata" + AttrBatchLimit = "gen_ai.batch.limit" + AttrBatchAfter = "gen_ai.batch.after" + AttrBatchBeforeID = "gen_ai.batch.before_id" + AttrBatchAfterID = "gen_ai.batch.after_id" + AttrBatchPageToken = "gen_ai.batch.page_token" + AttrBatchPageSize = "gen_ai.batch.page_size" + AttrBatchCountTotal = "gen_ai.batch.request_counts.total" + AttrBatchCountCompleted = "gen_ai.batch.request_counts.completed" + AttrBatchCountFailed = "gen_ai.batch.request_counts.failed" + AttrBatchFirstID = "gen_ai.batch.first_id" + AttrBatchLastID = "gen_ai.batch.last_id" + AttrBatchInProgressAt = "gen_ai.batch.in_progress_at" + AttrBatchFinalizingAt = "gen_ai.batch.finalizing_at" + AttrBatchCompletedAt = "gen_ai.batch.completed_at" + AttrBatchFailedAt = "gen_ai.batch.failed_at" + AttrBatchExpiredAt = "gen_ai.batch.expired_at" + AttrBatchCancellingAt = "gen_ai.batch.cancelling_at" + AttrBatchCancelledAt = "gen_ai.batch.cancelled_at" + AttrBatchNextCursor = "gen_ai.batch.next_cursor" + + // Transcription Response Attributes + AttrInputTokenDetailsText = "gen_ai.usage.input_token_details.text_tokens" + AttrInputTokenDetailsAudio = "gen_ai.usage.input_token_details.audio_tokens" + + // File Operation Attributes + AttrFileID = "gen_ai.file.id" + AttrFileObject = "gen_ai.file.object" + AttrFileFilename = "gen_ai.file.filename" + AttrFilePurpose = "gen_ai.file.purpose" + AttrFileBytes = "gen_ai.file.bytes" + AttrFileCreatedAt = "gen_ai.file.created_at" + AttrFileStatus = "gen_ai.file.status" + AttrFileStorageBackend = "gen_ai.file.storage_backend" + AttrFileDataCount = "gen_ai.file.data_count" + AttrFileHasMore = "gen_ai.file.has_more" + AttrFileDeleted = "gen_ai.file.deleted" + AttrFileContentType = "gen_ai.file.content_type" + AttrFileContentBytes = "gen_ai.file.content_bytes" + AttrFileLimit = "gen_ai.file.limit" + AttrFileAfter = "gen_ai.file.after" + AttrFileOrder = "gen_ai.file.order" +) diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go new file mode 100644 index 000000000..533ba8831 --- /dev/null +++ b/core/schemas/tracer.go @@ -0,0 +1,187 @@ +// Package schemas defines the core schemas and types used by the Bifrost system. +package schemas + +import ( + "context" + "time" +) + +// SpanHandle is an opaque handle to a span, implementation-specific. +// Different Tracer implementations can use their own concrete types. +type SpanHandle interface{} + +// StreamAccumulatorResult contains the accumulated data from streaming chunks. +// This is the return type for tracer's streaming accumulation methods. +type StreamAccumulatorResult struct { + IsFinal bool // Whether this is the final chunk + RequestID string // Request ID + Model string // Model used + Provider ModelProvider // Provider used + Status string // Status of the stream + Latency int64 // Latency in milliseconds + TimeToFirstToken int64 // Time to first token in milliseconds + OutputMessage *ChatMessage // Accumulated output message + OutputMessages []ResponsesMessage // For responses API + TokenUsage *BifrostLLMUsage // Token usage + Cost *float64 // Cost in dollars + ErrorDetails *BifrostError // Error details if any + AudioOutput *BifrostSpeechResponse // For speech streaming + TranscriptionOutput *BifrostTranscriptionResponse // For transcription streaming + FinishReason *string // Finish reason + RawResponse *string // Raw response + RawRequest interface{} // Raw request +} + +// Tracer defines the interface for distributed tracing in Bifrost. +// Implementations can be injected via BifrostConfig to enable automatic instrumentation. +// The interface is designed to be minimal and implementation-agnostic. +type Tracer interface { + // CreateTrace creates a new trace with optional parent ID and returns the trace ID. + // The parentID can be extracted from W3C traceparent headers for distributed tracing. + CreateTrace(parentID string) string + + // EndTrace completes a trace and returns the trace data for observation/export. + // After this call, the trace is removed from active tracking and returned for cleanup. + // Returns nil if trace not found. + EndTrace(traceID string) *Trace + + // StartSpan creates a new span as a child of the current span in context. + // Returns updated context with new span and a handle for the span. + // The context should be used for subsequent operations to maintain span hierarchy. + StartSpan(ctx context.Context, name string, kind SpanKind) (context.Context, SpanHandle) + + // EndSpan completes a span with status and optional message. + // Should be called when the operation represented by the span is complete. + EndSpan(handle SpanHandle, status SpanStatus, statusMsg string) + + // SetAttribute sets an attribute on the span. + // Attributes provide additional context about the operation. + SetAttribute(handle SpanHandle, key string, value any) + + // AddEvent adds a timestamped event to the span. + // Events represent discrete occurrences during the span's lifetime. + AddEvent(handle SpanHandle, name string, attrs map[string]any) + + // PopulateLLMRequestAttributes populates all LLM-specific request attributes on the span. + // This includes model parameters, input messages, temperature, max tokens, etc. + PopulateLLMRequestAttributes(handle SpanHandle, req *BifrostRequest) + + // PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span. + // This includes output messages, tokens, usage stats, and error information if present. + PopulateLLMResponseAttributes(handle SpanHandle, resp *BifrostResponse, err *BifrostError) + + // StoreDeferredSpan stores a span handle for later completion (used for streaming requests). + // The span handle is stored keyed by trace ID so it can be retrieved when the stream completes. + StoreDeferredSpan(traceID string, handle SpanHandle) + + // GetDeferredSpanHandle retrieves a deferred span handle by trace ID. + // Returns nil if no deferred span exists for the given trace ID. + GetDeferredSpanHandle(traceID string) SpanHandle + + // ClearDeferredSpan removes the deferred span handle for a trace ID. + // Should be called after the deferred span has been completed. + ClearDeferredSpan(traceID string) + + // GetDeferredSpanID returns the span ID for the deferred span. + // Returns empty string if no deferred span exists. + GetDeferredSpanID(traceID string) string + + // AddStreamingChunk accumulates a streaming chunk for the deferred span. + // Pass the full BifrostResponse to capture content, tool calls, reasoning, etc. + // This is called for each streaming chunk to build up the complete response. + AddStreamingChunk(traceID string, response *BifrostResponse) + + // GetAccumulatedChunks returns the accumulated BifrostResponse, TTFT, and chunk count for a deferred span. + // Returns the built response (with content, tool calls, etc.), time-to-first-token in ms, and total chunk count. + // Returns nil, 0, 0 if no accumulated data exists. + GetAccumulatedChunks(traceID string) (response *BifrostResponse, ttftMs int64, chunkCount int) + + // CreateStreamAccumulator creates a new stream accumulator for the given trace ID. + // This should be called at the start of a streaming request. + CreateStreamAccumulator(traceID string, startTime time.Time) + + // CleanupStreamAccumulator removes the stream accumulator for the given trace ID. + // This should be called after the streaming request is complete. + CleanupStreamAccumulator(traceID string) + + // ProcessStreamingChunk processes a streaming chunk and accumulates it. + // Returns the accumulated result. IsFinal will be true when the stream is complete. + // This method is used by plugins to access accumulated streaming data. + // The ctx parameter must contain the stream end indicator for proper final chunk detection. + ProcessStreamingChunk(ctx *BifrostContext, traceID string, result *BifrostResponse, err *BifrostError) *StreamAccumulatorResult + + // Stop releases resources associated with the tracer. + // Should be called during shutdown to stop background goroutines. + Stop() +} + +// NoOpTracer is a tracer that does nothing (default when tracing disabled). +// It satisfies the Tracer interface but performs no actual tracing operations. +type NoOpTracer struct{} + +// CreateTrace returns an empty string (no trace created). +func (n *NoOpTracer) CreateTrace(_ string) string { return "" } + +// EndTrace returns nil (no trace to end). +func (n *NoOpTracer) EndTrace(_ string) *Trace { return nil } + +// StartSpan returns the context unchanged and a nil handle. +func (n *NoOpTracer) StartSpan(ctx context.Context, _ string, _ SpanKind) (context.Context, SpanHandle) { + return ctx, nil +} + +// EndSpan does nothing. +func (n *NoOpTracer) EndSpan(_ SpanHandle, _ SpanStatus, _ string) {} + +// SetAttribute does nothing. +func (n *NoOpTracer) SetAttribute(_ SpanHandle, _ string, _ any) {} + +// AddEvent does nothing. +func (n *NoOpTracer) AddEvent(_ SpanHandle, _ string, _ map[string]any) {} + +// PopulateLLMRequestAttributes does nothing. +func (n *NoOpTracer) PopulateLLMRequestAttributes(_ SpanHandle, _ *BifrostRequest) {} + +// PopulateLLMResponseAttributes does nothing. +func (n *NoOpTracer) PopulateLLMResponseAttributes(_ SpanHandle, _ *BifrostResponse, _ *BifrostError) {} + +// StoreDeferredSpan does nothing. +func (n *NoOpTracer) StoreDeferredSpan(_ string, _ SpanHandle) {} + +// GetDeferredSpanHandle returns nil. +func (n *NoOpTracer) GetDeferredSpanHandle(_ string) SpanHandle { return nil } + +// ClearDeferredSpan does nothing. +func (n *NoOpTracer) ClearDeferredSpan(_ string) {} + +// GetDeferredSpanID returns empty string. +func (n *NoOpTracer) GetDeferredSpanID(_ string) string { return "" } + +// AddStreamingChunk does nothing. +func (n *NoOpTracer) AddStreamingChunk(_ string, _ *BifrostResponse) {} + +// GetAccumulatedChunks returns nil, 0, 0. +func (n *NoOpTracer) GetAccumulatedChunks(_ string) (*BifrostResponse, int64, int) { return nil, 0, 0 } + +// CreateStreamAccumulator does nothing. +func (n *NoOpTracer) CreateStreamAccumulator(_ string, _ time.Time) {} + +// CleanupStreamAccumulator does nothing. +func (n *NoOpTracer) CleanupStreamAccumulator(_ string) {} + +// ProcessStreamingChunk returns nil. +func (n *NoOpTracer) ProcessStreamingChunk(_ *BifrostContext, _ string, _ *BifrostResponse, _ *BifrostError) *StreamAccumulatorResult { + return nil +} + +// Stop does nothing. +func (n *NoOpTracer) Stop() {} + +// DefaultTracer returns a no-op tracer for use when tracing is disabled. +func DefaultTracer() Tracer { + return &NoOpTracer{} +} + +// Ensure NoOpTracer implements Tracer at compile time +var _ Tracer = (*NoOpTracer)(nil) + diff --git a/core/utils.go b/core/utils.go index e38df6a27..1c074f7f0 100644 --- a/core/utils.go +++ b/core/utils.go @@ -396,3 +396,8 @@ func isPrivateIP(ip net.IP) bool { } return false } + +// sanitizeSpanName sanitizes a span name to remove capital letters and spaces to make it a valid span name +func sanitizeSpanName(name string) string { + return strings.ToLower(strings.ReplaceAll(name, " ", "-")) +} \ No newline at end of file diff --git a/core/version b/core/version index bb3653fe5..589268e6f 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.2.43 \ No newline at end of file +1.3.0 \ No newline at end of file diff --git a/docs/docs.json b/docs/docs.json index 6d23dcfbc..ce98936f0 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -132,7 +132,8 @@ "pages": [ "plugins/getting-started", "plugins/building-dynamic-binary", - "plugins/writing-plugin" + "plugins/writing-plugin", + "plugins/migration-guide" ] }, { diff --git a/docs/plugins/getting-started.mdx b/docs/plugins/getting-started.mdx index ac30079ab..75ad52e05 100644 --- a/docs/plugins/getting-started.mdx +++ b/docs/plugins/getting-started.mdx @@ -48,12 +48,24 @@ go build -buildmode=plugin -o myplugin.so main.go This generates a `.so` file that exports specific functions matching Bifrost's plugin interface: -- `Init(config any) error` - Initialize the plugin with configuration -- `GetName() string` - Return the plugin name -- `PreHook()` - Intercept requests before they reach providers -- `PostHook()` - Process responses after provider calls -- `TransportInterceptor()` - Modify raw HTTP headers/body (HTTP transport only) -- `Cleanup() error` - Clean up resources on shutdown + + + - `Init(config any) error` - Initialize the plugin with configuration + - `GetName() string` - Return the plugin name + - `PreHook()` - Intercept requests before they reach providers + - `PostHook()` - Process responses after provider calls + - `HTTPTransportMiddleware()` - Middleware for HTTP transport layer (HTTP transport only) + - `Cleanup() error` - Clean up resources on shutdown + + + - `Init(config any) error` - Initialize the plugin with configuration + - `GetName() string` - Return the plugin name + - `PreHook()` - Intercept requests before they reach providers + - `PostHook()` - Process responses after provider calls + - `TransportInterceptor()` - Modify raw HTTP headers/body (HTTP transport only) + - `Cleanup() error` - Clean up resources on shutdown + + ### Platform Requirements @@ -74,10 +86,21 @@ This means if you're running Bifrost on Linux AMD64, you must build your plugin 4. **Cleanup** - Calls `Cleanup()` when Bifrost shuts down Plugins execute in a specific order: -1. `TransportInterceptor` - Modifies raw HTTP requests (HTTP transport only) -2. `PreHook` - Executes in registration order, can short-circuit requests -3. Provider call (if not short-circuited) -4. `PostHook` - Executes in reverse order of PreHooks + + + + 1. `HTTPTransportMiddleware` - HTTP transport middleware (HTTP transport only) + 2. `PreHook` - Executes in registration order, can short-circuit requests + 3. Provider call (if not short-circuited) + 4. `PostHook` - Executes in reverse order of PreHooks + + + 1. `TransportInterceptor` - Modifies raw HTTP requests (HTTP transport only) + 2. `PreHook` - Executes in registration order, can short-circuit requests + 3. Provider call (if not short-circuited) + 4. `PostHook` - Executes in reverse order of PreHooks + + ## Next Steps diff --git a/docs/plugins/migration-guide.mdx b/docs/plugins/migration-guide.mdx new file mode 100644 index 000000000..20313d708 --- /dev/null +++ b/docs/plugins/migration-guide.mdx @@ -0,0 +1,308 @@ +--- +title: "Plugin Migration Guide" +description: "How to migrate your Bifrost plugins from v1.3.x to v1.4.x" +icon: "arrow-up-right-dots" +--- + +## Overview + +Bifrost v1.4.x introduces a new plugin interface for HTTP transport layer interception. This guide helps you migrate existing plugins from the v1.3.x `TransportInterceptor` pattern to the v1.4.x `HTTPTransportMiddleware` pattern. + + +If your plugin doesn't use `TransportInterceptor`, no migration is needed. The `PreHook`, `PostHook`, `Init`, `GetName`, and `Cleanup` functions remain unchanged. + + +## What Changed? + +The HTTP transport interception mechanism changed from a simple function that receives and returns headers/body to a middleware pattern that wraps the entire request handler chain. + +### Key Differences + +| Aspect | v1.3.x (TransportInterceptor) | v1.4.x+ (HTTPTransportMiddleware) | +|--------|-------------------------------|-----------------------------------| +| Function signature | `TransportInterceptor(ctx, url, headers, body)` | `HTTPTransportMiddleware()` | +| Return type | `(headers, body, error)` | `BifrostHTTPMiddleware` | +| Access scope | Headers and body as maps | Full `*fasthttp.RequestCtx` | +| Flow control | Implicit (return modified values) | Explicit (`next(ctx)` call) | +| Capability | Request modification only | Request AND response modification | + +### Why the Change? + +The new middleware pattern provides: + +1. **Full HTTP control** - Access to the complete `*fasthttp.RequestCtx` including method, path, query params, and all headers +2. **Response interception** - Ability to modify responses after they return from downstream handlers +3. **Better composability** - Standard middleware pattern that chains naturally with other handlers +4. **More flexibility** - Can short-circuit requests, add timing, implement retries, etc. + +## Migration Steps + +### Step 1: Update Imports + +Add the `fasthttp` import to your plugin: + +```go +import ( + "fmt" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" // Add this import +) +``` + +### Step 2: Replace the Function + +**Before (v1.3.x):** + +```go +// TransportInterceptor modifies raw HTTP headers and body +func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + // Add custom header + headers["X-Custom-Header"] = "value" + + // Modify body + body["custom_field"] = "custom_value" + + return headers, body, nil +} +``` + +**After (v1.4.x+):** + +```go +// HTTPTransportMiddleware returns a middleware for HTTP transport +func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // Add custom header + ctx.Request.Header.Set("X-Custom-Header", "value") + + // Modify body (if needed) + // Note: Body modification requires parsing and re-serializing + // ctx.Request.SetBody(modifiedBody) + + // Call next handler in chain + next(ctx) + + // Can also modify response here after next() returns + } + } +} +``` + +### Step 3: Update Body Modification Logic + +In v1.3.x, you received the body as a `map[string]any`. In v1.4.x, you work with raw bytes: + +**Before (v1.3.x):** + +```go +func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + // Direct map access + body["model"] = "gpt-4" + return headers, body, nil +} +``` + +**After (v1.4.x+):** + +```go +import "github.com/bytedance/sonic" + +func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // Parse existing body + var body map[string]any + if err := sonic.Unmarshal(ctx.Request.Body(), &body); err == nil { + // Modify body + body["model"] = "gpt-4" + + // Re-serialize and set + if newBody, err := sonic.Marshal(body); err == nil { + ctx.Request.SetBody(newBody) + } + } + + next(ctx) + } + } +} +``` + +### Step 4: Handle Response Modification (New Capability) + +The new pattern allows you to modify responses, which wasn't possible in v1.3.x: + +```go +func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // Before request + startTime := time.Now() + + // Process request + next(ctx) + + // After response - NEW CAPABILITY + duration := time.Since(startTime) + ctx.Response.Header.Set("X-Processing-Time", duration.String()) + } + } +} +``` + +## Common Migration Patterns + +### Adding Headers + +**v1.3.x:** +```go +headers["Authorization"] = "Bearer " + token +return headers, body, nil +``` + +**v1.4.x+:** +```go +ctx.Request.Header.Set("Authorization", "Bearer " + token) +next(ctx) +``` + +### Reading Headers + +**v1.3.x:** +```go +apiKey := headers["X-API-Key"] +``` + +**v1.4.x+:** +```go +apiKey := string(ctx.Request.Header.Peek("X-API-Key")) +``` + +### Conditional Processing + +**v1.3.x:** +```go +func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + if headers["X-Skip-Processing"] == "true" { + return headers, body, nil + } + // Process... + return headers, body, nil +} +``` + +**v1.4.x+:** +```go +func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + if string(ctx.Request.Header.Peek("X-Skip-Processing")) == "true" { + next(ctx) + return + } + // Process... + next(ctx) + } + } +} +``` + +### Error Handling + +**v1.3.x:** +```go +func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + if headers["X-API-Key"] == "" { + return nil, nil, fmt.Errorf("missing API key") + } + return headers, body, nil +} +``` + +**v1.4.x+:** +```go +func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + if len(ctx.Request.Header.Peek("X-API-Key")) == 0 { + ctx.SetStatusCode(401) + ctx.SetBodyString(`{"error": "missing API key"}`) + return // Don't call next - short-circuit the request + } + next(ctx) + } + } +} +``` + +## Testing Your Migration + +1. **Build your updated plugin:** + ```bash + go build -buildmode=plugin -o my-plugin.so main.go + ``` + +2. **Update Bifrost to v1.4.x:** + ```bash + go get github.com/maximhq/bifrost/core@v1.4.0 + ``` + +3. **Test with a simple request:** + ```bash + curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello"}]}' + ``` + +4. **Verify logs show the new middleware being called:** + ``` + HTTPTransportMiddleware called + PreHook called + PostHook called + ``` + +## Troubleshooting + +### Plugin fails to load after migration + +**Error:** `plugin: symbol TransportInterceptor not found` + +This error occurs if Bifrost v1.4.x is looking for the old function. Make sure: +1. You've updated to `HTTPTransportMiddleware` +2. The function signature matches exactly +3. You've rebuilt the plugin with the correct core version + +### Body modification not working + +Make sure you're calling `ctx.Request.SetBody()` with the serialized bytes, not the map directly: + +```go +// Wrong +ctx.Request.SetBody(body) // body is map[string]any + +// Correct +bodyBytes, _ := sonic.Marshal(body) +ctx.Request.SetBody(bodyBytes) +``` + +### Headers not being set + +Remember that `fasthttp` header methods are case-sensitive for custom headers: + +```go +// Set header +ctx.Request.Header.Set("X-Custom-Header", "value") + +// Read header - use Peek for []byte or string conversion +value := string(ctx.Request.Header.Peek("X-Custom-Header")) +``` + +## Need Help? + +- **Discord Community**: [Join our Discord](https://getmax.im/bifrost-discord) +- **GitHub Issues**: [Report bugs or request features](https://github.com/maximhq/bifrost/issues) +- **Writing Plugins Guide**: [Full plugin documentation](./writing-plugin) + + diff --git a/docs/plugins/writing-plugin.mdx b/docs/plugins/writing-plugin.mdx index 8eb71bf5c..cdc0dd2f5 100644 --- a/docs/plugins/writing-plugin.mdx +++ b/docs/plugins/writing-plugin.mdx @@ -62,6 +62,70 @@ require github.com/maximhq/bifrost/core v1.2.38 Create `main.go` with the required plugin functions. Here's the complete hello-world example: + + +```go +package main + +import ( + "fmt" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// Init is called when the plugin is loaded +// config contains the plugin configuration from config.json +func Init(config any) error { + fmt.Println("Init called") + // Initialize your plugin here (database connections, API clients, etc.) + return nil +} + +// GetName returns the plugin's unique identifier +func GetName() string { + return "Hello World Plugin" +} + +// HTTPTransportMiddleware returns a middleware for HTTP transport +// Only called when using HTTP transport (bifrost-http) +func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + fmt.Println("HTTPTransportMiddleware called") + // Modify request headers/body via ctx.Request before calling next + // Call next handler in the chain + next(ctx) + // Can also modify response via ctx.Response after next returns + } + } +} + +// PreHook is called before the request is sent to the provider +// This is where you can modify requests or short-circuit the flow +func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + fmt.Println("PreHook called") + // Modify the request or return a short-circuit to skip provider call + return req, nil, nil +} + +// PostHook is called after receiving a response from the provider +// This is where you can modify responses or handle errors +func PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + fmt.Println("PostHook called") + // Modify the response or error before returning to caller + return resp, bifrostErr, nil +} + +// Cleanup is called when Bifrost shuts down +func Cleanup() error { + fmt.Println("Cleanup called") + // Clean up resources (close connections, flush buffers, etc.) + return nil +} +``` + + ```go package main @@ -115,6 +179,8 @@ func Cleanup() error { return nil } ``` + + ### Understanding Each Function @@ -144,6 +210,23 @@ func Init(config any) error { Returns a unique identifier for your plugin. This name appears in logs and status reports. + + +#### `HTTPTransportMiddleware()` + +**HTTP transport only.** Returns a middleware that wraps the HTTP request handler chain. Use this to: +- Intercept and modify requests before they enter Bifrost core +- Intercept and modify responses before they're returned to clients +- Implement authentication or logging at the transport layer +- Access raw `*fasthttp.RequestCtx` for full HTTP control + +The middleware pattern requires calling `next(ctx)` to pass control to subsequent handlers. + + +This function is **only called** when using `bifrost-http`. It's **not invoked** when using Bifrost as a Go SDK. + + + #### `TransportInterceptor(...)` **HTTP transport only.** Called before requests enter Bifrost core. Use this to: @@ -154,6 +237,8 @@ Returns a unique identifier for your plugin. This name appears in logs and statu This function is **only called** when using `bifrost-http`. It's **not invoked** when using Bifrost as a Go SDK. + + #### `PreHook(...)` @@ -341,11 +426,22 @@ curl -X POST http://localhost:8080/v1/chat/completions \ Check the logs for plugin hook calls: + + +``` +HTTPTransportMiddleware called +PreHook called +PostHook called +``` + + ``` TransportInterceptor called PreHook called PostHook called ``` + + ## Advanced Plugin Patterns diff --git a/examples/plugins/hello-world/go.mod b/examples/plugins/hello-world/go.mod index 14d567bb4..c8eb6d5ab 100644 --- a/examples/plugins/hello-world/go.mod +++ b/examples/plugins/hello-world/go.mod @@ -2,15 +2,21 @@ module github.com/maximhq/bifrost/examples/plugins/hello-world go 1.25.5 -require github.com/maximhq/bifrost/core v1.2.43 +require ( + github.com/maximhq/bifrost/core v1.2.42 + github.com/valyala/fasthttp v1.68.0 +) require ( + github.com/andybalholm/brotli v1.2.0 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect golang.org/x/arch v0.23.0 // indirect golang.org/x/sys v0.39.0 // indirect ) diff --git a/examples/plugins/hello-world/go.sum b/examples/plugins/hello-world/go.sum index 566d73d80..a70cbe0cc 100644 --- a/examples/plugins/hello-world/go.sum +++ b/examples/plugins/hello-world/go.sum @@ -1,3 +1,5 @@ +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPIIE= @@ -10,6 +12,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/maximhq/bifrost/core v1.2.43 h1:NxtvzvLL0Isaf8mD1dlGb9JxT7/PZNPv5NBTnqHG100= @@ -29,6 +33,12 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= diff --git a/examples/plugins/hello-world/main.go b/examples/plugins/hello-world/main.go index fb9c7d769..0872b591e 100644 --- a/examples/plugins/hello-world/main.go +++ b/examples/plugins/hello-world/main.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" ) func Init(config any) error { @@ -15,10 +16,14 @@ func GetName() string { return "Hello World Plugin" } -func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - fmt.Println("TransportInterceptor called") - ctx.SetValue(schemas.BifrostContextKey("hello-world-plugin-transport-interceptor"), "transport-interceptor-value") - return headers, body, nil +func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + fmt.Println("HTTPTransportMiddleware called") + ctx.SetUserValue(schemas.BifrostContextKey("hello-world-plugin-transport-interceptor"), "transport-interceptor-value") + next(ctx) + } + } } func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { diff --git a/framework/changelog.md b/framework/changelog.md index e69de29bb..3140b17d8 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -0,0 +1,25 @@ +- feat: adds new tracing framework for allowing plugins to enable e2e tracing + +### BREAKING CHANGES + +- **DynamicPlugin: TransportInterceptor replaced with HTTPTransportMiddleware** + + The `DynamicPlugin` loader now expects plugins to export `HTTPTransportMiddleware` instead of `TransportInterceptor`. + + **Old symbol lookup (removed in framework v1.2.0):** + ```go + plugin.Lookup("TransportInterceptor") + // Expected: func(ctx *BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + ``` + + **New symbol lookup (framework v1.2.0+):** + ```go + plugin.Lookup("HTTPTransportMiddleware") + // Expected: func() BifrostHTTPMiddleware + ``` + + **Impact on dynamic plugins (.so files):** + - Plugins compiled for core v1.2.x will fail to load with error: `plugin: symbol HTTPTransportMiddleware not found` + - Recompile all dynamic plugins against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for migration instructions. \ No newline at end of file diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index b991a545a..687c60355 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -79,7 +79,6 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { } else { c.HeadersJSON = "{}" } - return nil } diff --git a/framework/plugins/dynamicplugin.go b/framework/plugins/dynamicplugin.go index 7e8a561ce..7f051fd96 100644 --- a/framework/plugins/dynamicplugin.go +++ b/framework/plugins/dynamicplugin.go @@ -21,11 +21,11 @@ type DynamicPlugin struct { filename string plugin *plugin.Plugin - getName func() string - transportInterceptor func(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) - preHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) - postHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) - cleanup func() error + getName func() string + httpTransportMiddleware func() schemas.BifrostHTTPMiddleware + preHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) + postHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) + cleanup func() error } // GetName returns the name of the plugin @@ -33,9 +33,9 @@ func (dp *DynamicPlugin) GetName() string { return dp.getName() } -// TransportInterceptor is not used for dynamic plugins -func (dp *DynamicPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return dp.transportInterceptor(ctx, url, headers, body) +// HTTPTransportMiddleware returns the HTTP transport middleware function for this plugin +func (dp *DynamicPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return dp.httpTransportMiddleware() } // PreHook is not used for dynamic plugins @@ -137,13 +137,13 @@ func loadDynamicPlugin(path string, config any) (schemas.Plugin, error) { if dp.getName, ok = getNameSym.(func() string); !ok { return nil, fmt.Errorf("failed to cast GetName to func() string") } - // Looking up for TransportInterceptor method - transportInterceptorSym, err := plugin.Lookup("TransportInterceptor") + // Looking up for HTTPTransportMiddleware method + httpTransportMiddlewareSym, err := plugin.Lookup("HTTPTransportMiddleware") if err != nil { return nil, err } - if dp.transportInterceptor, ok = transportInterceptorSym.(func(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error)); !ok { - return nil, fmt.Errorf("failed to cast TransportInterceptor to func(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error)") + if dp.httpTransportMiddleware, ok = httpTransportMiddlewareSym.(func() schemas.BifrostHTTPMiddleware); !ok { + return nil, fmt.Errorf("failed to cast HTTPTransportMiddleware to func() fasthttp.RequestHandler") } // Looking up for PreHook method preHookSym, err := plugin.Lookup("PreHook") diff --git a/framework/plugins/dynamicplugin_test.go b/framework/plugins/dynamicplugin_test.go index bfba8d567..2b3757743 100644 --- a/framework/plugins/dynamicplugin_test.go +++ b/framework/plugins/dynamicplugin_test.go @@ -13,6 +13,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" ) const ( @@ -50,26 +51,36 @@ func TestDynamicPluginLifecycle(t *testing.T) { assert.Equal(t, "Hello World Plugin", name, "Plugin name should match") }) - // Test TransportInterceptor - t.Run("TransportInterceptor", func(t *testing.T) { - ctx := context.Background() - url := "http://example.com/api" - headers := map[string]string{ - "Content-Type": "application/json", - "Authorization": "Bearer token123", - } - body := map[string]any{ - "model": "gpt-4", - "messages": []map[string]string{ - {"role": "user", "content": "Hello"}, - }, + // Test HTTPTransportMiddleware + t.Run("HTTPTransportMiddleware", func(t *testing.T) { + // Track if the next handler was called + nextHandlerCalled := false + + // Create a mock next handler + nextHandler := func(ctx *fasthttp.RequestCtx) { + nextHandlerCalled = true } - pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second) - defer cancel() - modifiedHeaders, modifiedBody, err := plugin.TransportInterceptor(pluginCtx, url, headers, body) - require.NoError(t, err, "TransportInterceptor should not return error") - assert.Equal(t, headers, modifiedHeaders, "Headers should be unchanged") - assert.Equal(t, body, modifiedBody, "Body should be unchanged") + + // Get the middleware function + middleware := plugin.HTTPTransportMiddleware() + require.NotNil(t, middleware, "HTTPTransportMiddleware should return a middleware function") + + // Wrap the next handler with the middleware + wrappedHandler := middleware(nextHandler) + require.NotNil(t, wrappedHandler, "Middleware should return a wrapped handler") + + // Create a test request context + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("http://example.com/api") + ctx.Request.Header.SetMethod("POST") + ctx.Request.Header.Set("Content-Type", "application/json") + ctx.Request.Header.Set("Authorization", "Bearer token123") + + // Call the wrapped handler + wrappedHandler(ctx) + + // Verify the next handler was called + assert.True(t, nextHandlerCalled, "Next handler should have been called") }) // Test PreHook diff --git a/framework/streaming/accumulator.go b/framework/streaming/accumulator.go index b19bfde0c..b3581b8c4 100644 --- a/framework/streaming/accumulator.go +++ b/framework/streaming/accumulator.go @@ -101,14 +101,17 @@ func (a *Accumulator) putResponsesStreamChunk(chunk *ResponsesStreamChunk) { a.responsesStreamChunkPool.Put(chunk) } -// CreateStreamAccumulator creates a new stream accumulator for a request +// createStreamAccumulator creates a new stream accumulator for a request +// StartTimestamp is set to current time if not provided via CreateStreamAccumulator func (a *Accumulator) createStreamAccumulator(requestID string) *StreamAccumulator { + now := time.Now() sc := &StreamAccumulator{ RequestID: requestID, ChatStreamChunks: make([]*ChatStreamChunk, 0), ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0), IsComplete: false, - Timestamp: time.Now(), + Timestamp: now, + StartTimestamp: now, // Set default StartTimestamp for proper TTFT/latency calculation } a.streamAccumulators.Store(requestID, sc) return sc @@ -132,6 +135,10 @@ func (a *Accumulator) addChatStreamChunk(requestID string, chunk *ChatStreamChun if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } // Add chunk to the list (chunks arrive in order) accumulator.ChatStreamChunks = append(accumulator.ChatStreamChunks, chunk) // Check if this is the final chunk @@ -152,6 +159,10 @@ func (a *Accumulator) addTranscriptionStreamChunk(requestID string, chunk *Trans if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } // Add chunk to the list (chunks arrive in order) accumulator.TranscriptionStreamChunks = append(accumulator.TranscriptionStreamChunks, chunk) // Check if this is the final chunk @@ -172,6 +183,10 @@ func (a *Accumulator) addAudioStreamChunk(requestID string, chunk *AudioStreamCh if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } // Add chunk to the list (chunks arrive in order) accumulator.AudioStreamChunks = append(accumulator.AudioStreamChunks, chunk) // Check if this is the final chunk @@ -192,6 +207,10 @@ func (a *Accumulator) addResponsesStreamChunk(requestID string, chunk *Responses if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } // Add chunk to the list (chunks arrive in order) accumulator.ResponsesStreamChunks = append(accumulator.ResponsesStreamChunks, chunk) // Check if this is the final chunk diff --git a/framework/streaming/audio.go b/framework/streaming/audio.go index 5123d1ede..28c1d9e8f 100644 --- a/framework/streaming/audio.go +++ b/framework/streaming/audio.go @@ -35,19 +35,27 @@ func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, b } accumulator.mu.Unlock() }() + + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } + data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessage: nil, - ToolCalls: nil, - ErrorDetails: nil, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessage: nil, + ToolCalls: nil, + ErrorDetails: nil, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } completeMessage := a.buildCompleteMessageFromAudioStreamChunks(accumulator.AudioStreamChunks) if !isFinalChunk { diff --git a/framework/streaming/chat.go b/framework/streaming/chat.go index 36f3b9717..42e9ba608 100644 --- a/framework/streaming/chat.go +++ b/framework/streaming/chat.go @@ -153,20 +153,28 @@ func (a *Accumulator) processAccumulatedChatStreamingChunks(requestID string, re } accumulator.mu.Unlock() }() + + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } + // Initialize accumulated data data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessage: nil, - ToolCalls: nil, - ErrorDetails: nil, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessage: nil, + ToolCalls: nil, + ErrorDetails: nil, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } // Build complete message from accumulated chunks completeMessage := a.buildCompleteMessageFromChatStreamChunks(accumulator.ChatStreamChunks) diff --git a/framework/streaming/responses.go b/framework/streaming/responses.go index ca418573c..7e192d346 100644 --- a/framework/streaming/responses.go +++ b/framework/streaming/responses.go @@ -739,20 +739,27 @@ func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID strin accumulator.mu.Unlock() }() + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } + // Initialize accumulated data data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessages: nil, - ToolCalls: nil, - ErrorDetails: respErr, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessages: nil, + ToolCalls: nil, + ErrorDetails: respErr, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } // Build complete messages from accumulated chunks diff --git a/framework/streaming/transcription.go b/framework/streaming/transcription.go index 314f4cdb5..91cb4789b 100644 --- a/framework/streaming/transcription.go +++ b/framework/streaming/transcription.go @@ -41,19 +41,27 @@ func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID s } accumulator.mu.Unlock() }() + + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } + data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessage: nil, - ToolCalls: nil, - ErrorDetails: nil, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessage: nil, + ToolCalls: nil, + ErrorDetails: nil, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } // Build complete message from accumulated chunks completeMessage := a.buildCompleteMessageFromTranscriptionStreamChunks(accumulator.TranscriptionStreamChunks) diff --git a/framework/streaming/types.go b/framework/streaming/types.go index 29bb62dfc..47a9df575 100644 --- a/framework/streaming/types.go +++ b/framework/streaming/types.go @@ -31,6 +31,7 @@ type AccumulatedData struct { Status string Stream bool Latency int64 // in milliseconds + TimeToFirstToken int64 // Time to first token in milliseconds (streaming only) StartTimestamp time.Time EndTimestamp time.Time OutputMessage *schemas.ChatMessage @@ -102,6 +103,7 @@ type ResponsesStreamChunk struct { type StreamAccumulator struct { RequestID string StartTimestamp time.Time + FirstChunkTimestamp time.Time // Timestamp when the first chunk was received (for TTFT calculation) ChatStreamChunks []*ChatStreamChunk ResponsesStreamChunks []*ResponsesStreamChunk TranscriptionStreamChunks []*TranscriptionStreamChunk diff --git a/framework/tracing/helpers.go b/framework/tracing/helpers.go new file mode 100644 index 000000000..edb9a7310 --- /dev/null +++ b/framework/tracing/helpers.go @@ -0,0 +1,83 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "context" + + "github.com/maximhq/bifrost/core/schemas" +) + +// GetTraceID retrieves the trace ID from the context +func GetTraceID(ctx context.Context) string { + if ctx == nil { + return "" + } + traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string) + if !ok { + return "" + } + return traceID +} + +// GetTrace retrieves the current trace from context using the store +func GetTrace(ctx context.Context, store *TraceStore) *schemas.Trace { + traceID := GetTraceID(ctx) + if traceID == "" { + return nil + } + return store.GetTrace(traceID) +} + +// AddSpan adds a new span to the current trace +func AddSpan(ctx context.Context, store *TraceStore, name string, kind schemas.SpanKind) *schemas.Span { + traceID := GetTraceID(ctx) + if traceID == "" { + return nil + } + return store.StartSpan(traceID, name, kind) +} + +// AddChildSpan adds a new child span to the current trace under a specific parent +func AddChildSpan(ctx context.Context, store *TraceStore, parentSpanID, name string, kind schemas.SpanKind) *schemas.Span { + traceID := GetTraceID(ctx) + if traceID == "" { + return nil + } + return store.StartChildSpan(traceID, parentSpanID, name, kind) +} + +// EndSpan completes a span with the given status +func EndSpan(ctx context.Context, store *TraceStore, spanID string, status schemas.SpanStatus, statusMsg string, attrs map[string]any) { + traceID := GetTraceID(ctx) + if traceID == "" { + return + } + store.EndSpan(traceID, spanID, status, statusMsg, attrs) +} + +// SetSpanAttribute sets an attribute on a span +func SetSpanAttribute(ctx context.Context, store *TraceStore, spanID, key string, value any) { + trace := GetTrace(ctx, store) + if trace == nil { + return + } + span := trace.GetSpan(spanID) + if span == nil { + return + } + span.SetAttribute(key, value) +} + +// AddSpanEvent adds an event to a span +func AddSpanEvent(ctx context.Context, store *TraceStore, spanID string, event schemas.SpanEvent) { + trace := GetTrace(ctx, store) + if trace == nil { + return + } + span := trace.GetSpan(spanID) + if span == nil { + return + } + span.AddEvent(event) +} + diff --git a/framework/tracing/llmspan.go b/framework/tracing/llmspan.go new file mode 100644 index 000000000..86bdfca1f --- /dev/null +++ b/framework/tracing/llmspan.go @@ -0,0 +1,1343 @@ +// Package tracing provides distributed tracing utilities for Bifrost. +package tracing + +import ( + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" +) + +// PopulateRequestAttributes extracts common request attributes from a BifrostRequest. +// This is the main entry point for populating request attributes on a span. +func PopulateRequestAttributes(req *schemas.BifrostRequest) map[string]any { + attrs := make(map[string]any) + if req == nil { + return attrs + } + + provider, model, _ := req.GetRequestFields() + attrs[schemas.AttrProviderName] = string(provider) + attrs[schemas.AttrRequestModel] = model + + switch req.RequestType { + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + PopulateChatRequestAttributes(req.ChatRequest, attrs) + case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: + PopulateTextCompletionRequestAttributes(req.TextCompletionRequest, attrs) + case schemas.EmbeddingRequest: + PopulateEmbeddingRequestAttributes(req.EmbeddingRequest, attrs) + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + PopulateTranscriptionRequestAttributes(req.TranscriptionRequest, attrs) + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + PopulateSpeechRequestAttributes(req.SpeechRequest, attrs) + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + PopulateResponsesRequestAttributes(req.ResponsesRequest, attrs) + case schemas.BatchCreateRequest: + PopulateBatchCreateRequestAttributes(req.BatchCreateRequest, attrs) + case schemas.BatchListRequest: + PopulateBatchListRequestAttributes(req.BatchListRequest, attrs) + case schemas.BatchRetrieveRequest: + PopulateBatchRetrieveRequestAttributes(req.BatchRetrieveRequest, attrs) + case schemas.BatchCancelRequest: + PopulateBatchCancelRequestAttributes(req.BatchCancelRequest, attrs) + case schemas.BatchResultsRequest: + PopulateBatchResultsRequestAttributes(req.BatchResultsRequest, attrs) + case schemas.FileUploadRequest: + PopulateFileUploadRequestAttributes(req.FileUploadRequest, attrs) + case schemas.FileListRequest: + PopulateFileListRequestAttributes(req.FileListRequest, attrs) + case schemas.FileRetrieveRequest: + PopulateFileRetrieveRequestAttributes(req.FileRetrieveRequest, attrs) + case schemas.FileDeleteRequest: + PopulateFileDeleteRequestAttributes(req.FileDeleteRequest, attrs) + case schemas.FileContentRequest: + PopulateFileContentRequestAttributes(req.FileContentRequest, attrs) + } + + return attrs +} + +// PopulateResponseAttributes extracts common response attributes from a BifrostResponse. +// This is the main entry point for populating response attributes on a span. +func PopulateResponseAttributes(resp *schemas.BifrostResponse) map[string]any { + attrs := make(map[string]any) + if resp == nil { + return attrs + } + + switch { + case resp.ChatResponse != nil: + PopulateChatResponseAttributes(resp.ChatResponse, attrs) + case resp.TextCompletionResponse != nil: + PopulateTextCompletionResponseAttributes(resp.TextCompletionResponse, attrs) + case resp.EmbeddingResponse != nil: + PopulateEmbeddingResponseAttributes(resp.EmbeddingResponse, attrs) + case resp.TranscriptionResponse != nil: + PopulateTranscriptionResponseAttributes(resp.TranscriptionResponse, attrs) + case resp.SpeechResponse != nil: + PopulateSpeechResponseAttributes(resp.SpeechResponse, attrs) + case resp.ResponsesResponse != nil: + PopulateResponsesResponseAttributes(resp.ResponsesResponse, attrs) + case resp.BatchCreateResponse != nil: + PopulateBatchCreateResponseAttributes(resp.BatchCreateResponse, attrs) + case resp.BatchListResponse != nil: + PopulateBatchListResponseAttributes(resp.BatchListResponse, attrs) + case resp.BatchRetrieveResponse != nil: + PopulateBatchRetrieveResponseAttributes(resp.BatchRetrieveResponse, attrs) + case resp.BatchCancelResponse != nil: + PopulateBatchCancelResponseAttributes(resp.BatchCancelResponse, attrs) + case resp.BatchResultsResponse != nil: + PopulateBatchResultsResponseAttributes(resp.BatchResultsResponse, attrs) + case resp.FileUploadResponse != nil: + PopulateFileUploadResponseAttributes(resp.FileUploadResponse, attrs) + case resp.FileListResponse != nil: + PopulateFileListResponseAttributes(resp.FileListResponse, attrs) + case resp.FileRetrieveResponse != nil: + PopulateFileRetrieveResponseAttributes(resp.FileRetrieveResponse, attrs) + case resp.FileDeleteResponse != nil: + PopulateFileDeleteResponseAttributes(resp.FileDeleteResponse, attrs) + case resp.FileContentResponse != nil: + PopulateFileContentResponseAttributes(resp.FileContentResponse, attrs) + } + + return attrs +} + +// PopulateErrorAttributes extracts error attributes from a BifrostError. +func PopulateErrorAttributes(err *schemas.BifrostError) map[string]any { + attrs := make(map[string]any) + if err == nil || err.Error == nil { + return attrs + } + + attrs[schemas.AttrError] = err.Error.Message + if err.Error.Type != nil { + attrs[schemas.AttrErrorType] = *err.Error.Type + } + if err.Error.Code != nil { + attrs[schemas.AttrErrorCode] = *err.Error.Code + } + + return attrs +} + +// PopulateContextAttributes extracts context-related attributes (virtual keys, retries, etc.) +func PopulateContextAttributes( + attrs map[string]any, + virtualKeyID, virtualKeyName string, + selectedKeyID, selectedKeyName string, + teamID, teamName string, + customerID, customerName string, + numberOfRetries, fallbackIndex int, +) { + if virtualKeyID != "" { + attrs[schemas.AttrVirtualKeyID] = virtualKeyID + attrs[schemas.AttrVirtualKeyName] = virtualKeyName + } + if selectedKeyID != "" { + attrs[schemas.AttrSelectedKeyID] = selectedKeyID + attrs[schemas.AttrSelectedKeyName] = selectedKeyName + } + if teamID != "" { + attrs[schemas.AttrTeamID] = teamID + attrs[schemas.AttrTeamName] = teamName + } + if customerID != "" { + attrs[schemas.AttrCustomerID] = customerID + attrs[schemas.AttrCustomerName] = customerName + } + attrs[schemas.AttrNumberOfRetries] = numberOfRetries + attrs[schemas.AttrFallbackIndex] = fallbackIndex +} + +// =============================================== +// Chat Completion Request/Response +// =============================================== + +// PopulateChatRequestAttributes extracts chat completion request attributes. +func PopulateChatRequestAttributes(req *schemas.BifrostChatRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.MaxCompletionTokens != nil { + attrs[schemas.AttrMaxTokens] = *req.Params.MaxCompletionTokens + } + if req.Params.Temperature != nil { + attrs[schemas.AttrTemperature] = *req.Params.Temperature + } + if req.Params.TopP != nil { + attrs[schemas.AttrTopP] = *req.Params.TopP + } + if req.Params.Stop != nil { + attrs[schemas.AttrStopSequences] = strings.Join(req.Params.Stop, ",") + } + if req.Params.PresencePenalty != nil { + attrs[schemas.AttrPresencePenalty] = *req.Params.PresencePenalty + } + if req.Params.FrequencyPenalty != nil { + attrs[schemas.AttrFrequencyPenalty] = *req.Params.FrequencyPenalty + } + if req.Params.ParallelToolCalls != nil { + attrs[schemas.AttrParallelToolCall] = *req.Params.ParallelToolCalls + } + if req.Params.User != nil { + attrs[schemas.AttrRequestUser] = *req.Params.User + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } + } + + // Extract input messages + if req.Input != nil { + attrs[schemas.AttrMessageCount] = len(req.Input) + messages := extractChatMessages(req.Input) + if len(messages) > 0 { + attrs[schemas.AttrInputMessages] = messages + } + } +} + +// PopulateChatResponseAttributes extracts chat completion response attributes. +func PopulateChatResponseAttributes(resp *schemas.BifrostChatResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrResponseID] = resp.ID + attrs[schemas.AttrResponseModel] = resp.Model + if resp.Object != "" { + attrs[schemas.AttrObject] = resp.Object + } + if resp.SystemFingerprint != "" { + attrs[schemas.AttrSystemFprint] = resp.SystemFingerprint + } + attrs[schemas.AttrCreated] = resp.Created + if resp.ServiceTier != nil { + attrs[schemas.AttrServiceTier] = *resp.ServiceTier + } + + // Extract output messages + outputMessages := extractChatResponseMessages(resp) + if len(outputMessages) > 0 { + attrs[schemas.AttrOutputMessages] = outputMessages + } + + // Extract finish reason from first choice + if len(resp.Choices) > 0 && resp.Choices[0].FinishReason != nil { + attrs[schemas.AttrFinishReason] = *resp.Choices[0].FinishReason + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrPromptTokens] = resp.Usage.PromptTokens + attrs[schemas.AttrCompletionTokens] = resp.Usage.CompletionTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Text Completion Request/Response +// =============================================== + +// PopulateTextCompletionRequestAttributes extracts text completion request attributes. +func PopulateTextCompletionRequestAttributes(req *schemas.BifrostTextCompletionRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.MaxTokens != nil { + attrs[schemas.AttrMaxTokens] = *req.Params.MaxTokens + } + if req.Params.Temperature != nil { + attrs[schemas.AttrTemperature] = *req.Params.Temperature + } + if req.Params.TopP != nil { + attrs[schemas.AttrTopP] = *req.Params.TopP + } + if req.Params.Stop != nil { + attrs[schemas.AttrStopSequences] = strings.Join(req.Params.Stop, ",") + } + if req.Params.PresencePenalty != nil { + attrs[schemas.AttrPresencePenalty] = *req.Params.PresencePenalty + } + if req.Params.FrequencyPenalty != nil { + attrs[schemas.AttrFrequencyPenalty] = *req.Params.FrequencyPenalty + } + if req.Params.BestOf != nil { + attrs[schemas.AttrBestOf] = *req.Params.BestOf + } + if req.Params.Echo != nil { + attrs[schemas.AttrEcho] = *req.Params.Echo + } + if req.Params.LogitBias != nil { + attrs[schemas.AttrLogitBias] = fmt.Sprintf("%v", req.Params.LogitBias) + } + if req.Params.LogProbs != nil { + attrs[schemas.AttrLogProbs] = *req.Params.LogProbs + } + if req.Params.N != nil { + attrs[schemas.AttrN] = *req.Params.N + } + if req.Params.Seed != nil { + attrs[schemas.AttrSeed] = *req.Params.Seed + } + if req.Params.Suffix != nil { + attrs[schemas.AttrSuffix] = *req.Params.Suffix + } + if req.Params.User != nil { + attrs[schemas.AttrRequestUser] = *req.Params.User + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } + } + + // Extract input text + if req.Input != nil { + if req.Input.PromptStr != nil { + attrs[schemas.AttrInputText] = *req.Input.PromptStr + } else if req.Input.PromptArray != nil { + attrs[schemas.AttrInputText] = strings.Join(req.Input.PromptArray, ",") + } + } +} + +// PopulateTextCompletionResponseAttributes extracts text completion response attributes. +func PopulateTextCompletionResponseAttributes(resp *schemas.BifrostTextCompletionResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrResponseID] = resp.ID + attrs[schemas.AttrResponseModel] = resp.Model + if resp.Object != "" { + attrs[schemas.AttrObject] = resp.Object + } + if resp.SystemFingerprint != "" { + attrs[schemas.AttrSystemFprint] = resp.SystemFingerprint + } + + // Extract output text + var outputs []string + for _, choice := range resp.Choices { + if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { + outputs = append(outputs, *choice.TextCompletionResponseChoice.Text) + } + } + if len(outputs) > 0 { + attrs[schemas.AttrOutputMessages] = outputs + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrPromptTokens] = resp.Usage.PromptTokens + attrs[schemas.AttrCompletionTokens] = resp.Usage.CompletionTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Embedding Request/Response +// =============================================== + +// PopulateEmbeddingRequestAttributes extracts embedding request attributes. +func PopulateEmbeddingRequestAttributes(req *schemas.BifrostEmbeddingRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.Dimensions != nil { + attrs[schemas.AttrDimensions] = *req.Params.Dimensions + } + if req.Params.EncodingFormat != nil { + attrs[schemas.AttrEncodingFormat] = *req.Params.EncodingFormat + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } + } + + // Extract input + if req.Input != nil { + if req.Input.Text != nil { + attrs[schemas.AttrInputText] = *req.Input.Text + } else if req.Input.Texts != nil { + attrs[schemas.AttrInputText] = strings.Join(req.Input.Texts, ",") + } else if req.Input.Embedding != nil { + embedding := make([]string, len(req.Input.Embedding)) + for i, v := range req.Input.Embedding { + // Use a float‑safe representation; adjust precision as needed. + embedding[i] = fmt.Sprintf("%g", v) + } + attrs[schemas.AttrInputEmbedding] = strings.Join(embedding, ",") + } + } +} + +// PopulateEmbeddingResponseAttributes extracts embedding response attributes. +func PopulateEmbeddingResponseAttributes(resp *schemas.BifrostEmbeddingResponse, attrs map[string]any) { + if resp == nil { + return + } + // Usage + if resp.Usage != nil { + attrs[schemas.AttrPromptTokens] = resp.Usage.PromptTokens + attrs[schemas.AttrCompletionTokens] = resp.Usage.CompletionTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Transcription Request/Response +// =============================================== + +// PopulateTranscriptionRequestAttributes extracts transcription request attributes. +func PopulateTranscriptionRequestAttributes(req *schemas.BifrostTranscriptionRequest, attrs map[string]any) { + if req == nil || req.Params == nil { + return + } + + if req.Params.Language != nil { + attrs[schemas.AttrLanguage] = *req.Params.Language + } + if req.Params.Prompt != nil { + attrs[schemas.AttrPrompt] = *req.Params.Prompt + } + if req.Params.ResponseFormat != nil { + attrs[schemas.AttrResponseFormat] = *req.Params.ResponseFormat + } + if req.Params.Format != nil { + attrs[schemas.AttrFormat] = *req.Params.Format + } +} + +// PopulateTranscriptionResponseAttributes extracts transcription response attributes. +func PopulateTranscriptionResponseAttributes(resp *schemas.BifrostTranscriptionResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrOutputMessages] = resp.Text + + // Usage + if resp.Usage != nil { + if resp.Usage.InputTokens != nil { + attrs[schemas.AttrInputTokens] = *resp.Usage.InputTokens + } + if resp.Usage.OutputTokens != nil { + attrs[schemas.AttrOutputTokens] = *resp.Usage.OutputTokens + } + if resp.Usage.TotalTokens != nil { + attrs[schemas.AttrTotalTokens] = *resp.Usage.TotalTokens + } + if resp.Usage.InputTokenDetails != nil { + attrs[schemas.AttrInputTokenDetailsText] = resp.Usage.InputTokenDetails.TextTokens + attrs[schemas.AttrInputTokenDetailsAudio] = resp.Usage.InputTokenDetails.AudioTokens + } + } +} + +// =============================================== +// Speech Request/Response +// =============================================== + +// PopulateSpeechRequestAttributes extracts speech request attributes. +func PopulateSpeechRequestAttributes(req *schemas.BifrostSpeechRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.VoiceConfig != nil { + if req.Params.VoiceConfig.Voice != nil { + attrs[schemas.AttrVoice] = *req.Params.VoiceConfig.Voice + } + if len(req.Params.VoiceConfig.MultiVoiceConfig) > 0 { + voices := make([]string, len(req.Params.VoiceConfig.MultiVoiceConfig)) + for i, vc := range req.Params.VoiceConfig.MultiVoiceConfig { + voices[i] = vc.Voice + } + attrs[schemas.AttrMultiVoiceConfig] = strings.Join(voices, ",") + } + } + if req.Params.Instructions != "" { + attrs[schemas.AttrInstructions] = req.Params.Instructions + } + if req.Params.ResponseFormat != "" { + attrs[schemas.AttrResponseFormat] = req.Params.ResponseFormat + } + if req.Params.Speed != nil { + attrs[schemas.AttrSpeed] = *req.Params.Speed + } + } + + if req.Input != nil && req.Input.Input != "" { + attrs[schemas.AttrInputSpeech] = req.Input.Input + } +} + +// PopulateSpeechResponseAttributes extracts speech response attributes. +func PopulateSpeechResponseAttributes(resp *schemas.BifrostSpeechResponse, attrs map[string]any) { + if resp == nil { + return + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrInputTokens] = resp.Usage.InputTokens + attrs[schemas.AttrOutputTokens] = resp.Usage.OutputTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Responses API Request/Response +// =============================================== + +// PopulateResponsesRequestAttributes extracts responses API request attributes. +func PopulateResponsesRequestAttributes(req *schemas.BifrostResponsesRequest, attrs map[string]any) { + if req == nil || req.Params == nil { + return + } + + if req.Params.ParallelToolCalls != nil { + attrs[schemas.AttrParallelToolCall] = *req.Params.ParallelToolCalls + } + if req.Params.PromptCacheKey != nil { + attrs[schemas.AttrPromptCacheKey] = *req.Params.PromptCacheKey + } + if req.Params.Reasoning != nil { + if req.Params.Reasoning.Effort != nil { + attrs[schemas.AttrReasoningEffort] = *req.Params.Reasoning.Effort + } + if req.Params.Reasoning.Summary != nil { + attrs[schemas.AttrReasoningSummary] = *req.Params.Reasoning.Summary + } + if req.Params.Reasoning.GenerateSummary != nil { + attrs[schemas.AttrReasoningGenSummary] = *req.Params.Reasoning.GenerateSummary + } + } + if req.Params.SafetyIdentifier != nil { + attrs[schemas.AttrSafetyIdentifier] = *req.Params.SafetyIdentifier + } + if req.Params.ServiceTier != nil { + attrs[schemas.AttrServiceTier] = *req.Params.ServiceTier + } + if req.Params.Store != nil { + attrs[schemas.AttrStore] = *req.Params.Store + } + if req.Params.Temperature != nil { + attrs[schemas.AttrTemperature] = *req.Params.Temperature + } + if req.Params.Text != nil { + if req.Params.Text.Verbosity != nil { + attrs[schemas.AttrTextVerbosity] = *req.Params.Text.Verbosity + } + if req.Params.Text.Format != nil { + attrs[schemas.AttrTextFormatType] = req.Params.Text.Format.Type + } + } + if req.Params.TopLogProbs != nil { + attrs[schemas.AttrTopLogProbs] = *req.Params.TopLogProbs + } + if req.Params.TopP != nil { + attrs[schemas.AttrTopP] = *req.Params.TopP + } + if req.Params.ToolChoice != nil { + if req.Params.ToolChoice.ResponsesToolChoiceStr != nil && *req.Params.ToolChoice.ResponsesToolChoiceStr != "" { + attrs[schemas.AttrToolChoiceType] = *req.Params.ToolChoice.ResponsesToolChoiceStr + } + if req.Params.ToolChoice.ResponsesToolChoiceStruct != nil && req.Params.ToolChoice.ResponsesToolChoiceStruct.Name != nil { + attrs[schemas.AttrToolChoiceName] = *req.Params.ToolChoice.ResponsesToolChoiceStruct.Name + } + } + if req.Params.Tools != nil { + tools := make([]string, len(req.Params.Tools)) + for i, tool := range req.Params.Tools { + tools[i] = string(tool.Type) + } + attrs[schemas.AttrTools] = strings.Join(tools, ",") + } + if req.Params.Truncation != nil { + attrs[schemas.AttrTruncation] = *req.Params.Truncation + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateResponsesResponseAttributes extracts responses API response attributes. +func PopulateResponsesResponseAttributes(resp *schemas.BifrostResponsesResponse, attrs map[string]any) { + if resp == nil { + return + } + + if resp.ID != nil && *resp.ID != "" { + attrs[schemas.AttrResponseID] = *resp.ID + } + if resp.Model != "" { + attrs[schemas.AttrResponseModel] = resp.Model + } + if resp.ServiceTier != nil { + attrs[schemas.AttrServiceTier] = *resp.ServiceTier + } + + // Extract output messages (includes reasoning) + outputMessages := extractResponsesOutputMessages(resp) + if len(outputMessages) > 0 { + attrs[schemas.AttrOutputMessages] = outputMessages + } + + // Additional response fields + if resp.Include != nil { + attrs[schemas.AttrRespInclude] = strings.Join(resp.Include, ",") + } + if resp.MaxOutputTokens != nil { + attrs[schemas.AttrRespMaxOutputTokens] = *resp.MaxOutputTokens + } + if resp.MaxToolCalls != nil { + attrs[schemas.AttrRespMaxToolCalls] = *resp.MaxToolCalls + } + if resp.Metadata != nil { + attrs[schemas.AttrRespMetadata] = fmt.Sprintf("%v", resp.Metadata) + } + if resp.PreviousResponseID != nil { + attrs[schemas.AttrRespPreviousRespID] = *resp.PreviousResponseID + } + if resp.PromptCacheKey != nil { + attrs[schemas.AttrRespPromptCacheKey] = *resp.PromptCacheKey + } + if resp.Reasoning != nil { + if resp.Reasoning.Summary != nil { + attrs[schemas.AttrRespReasoningText] = *resp.Reasoning.Summary + } + if resp.Reasoning.Effort != nil { + attrs[schemas.AttrRespReasoningEffort] = *resp.Reasoning.Effort + } + if resp.Reasoning.GenerateSummary != nil { + attrs[schemas.AttrRespReasoningGenSum] = *resp.Reasoning.GenerateSummary + } + } + if resp.SafetyIdentifier != nil { + attrs[schemas.AttrRespSafetyIdentifier] = *resp.SafetyIdentifier + } + if resp.Store != nil { + attrs[schemas.AttrRespStore] = *resp.Store + } + if resp.Temperature != nil { + attrs[schemas.AttrRespTemperature] = *resp.Temperature + } + if resp.Text != nil { + if resp.Text.Verbosity != nil { + attrs[schemas.AttrRespTextVerbosity] = *resp.Text.Verbosity + } + if resp.Text.Format != nil { + attrs[schemas.AttrRespTextFormatType] = resp.Text.Format.Type + } + } + if resp.TopLogProbs != nil { + attrs[schemas.AttrRespTopLogProbs] = *resp.TopLogProbs + } + if resp.TopP != nil { + attrs[schemas.AttrRespTopP] = *resp.TopP + } + if resp.ToolChoice != nil { + if resp.ToolChoice.ResponsesToolChoiceStr != nil { + attrs[schemas.AttrRespToolChoiceType] = *resp.ToolChoice.ResponsesToolChoiceStr + } + if resp.ToolChoice.ResponsesToolChoiceStruct != nil && resp.ToolChoice.ResponsesToolChoiceStruct.Name != nil { + attrs[schemas.AttrRespToolChoiceName] = *resp.ToolChoice.ResponsesToolChoiceStruct.Name + } + } + if resp.Truncation != nil { + attrs[schemas.AttrRespTruncation] = *resp.Truncation + } + if resp.Tools != nil { + tools := make([]string, len(resp.Tools)) + for i, tool := range resp.Tools { + tools[i] = string(tool.Type) + } + attrs[schemas.AttrRespTools] = strings.Join(tools, ",") + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrInputTokens] = resp.Usage.InputTokens + attrs[schemas.AttrOutputTokens] = resp.Usage.OutputTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Batch Operations Request/Response +// =============================================== + +// PopulateBatchCreateRequestAttributes extracts batch create request attributes. +func PopulateBatchCreateRequestAttributes(req *schemas.BifrostBatchCreateRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.InputFileID != "" { + attrs[schemas.AttrBatchInputFileID] = req.InputFileID + } + if req.Endpoint != "" { + attrs[schemas.AttrBatchEndpoint] = string(req.Endpoint) + } + if req.CompletionWindow != "" { + attrs[schemas.AttrBatchCompletionWin] = req.CompletionWindow + } + if len(req.Requests) > 0 { + attrs[schemas.AttrBatchRequestsCount] = len(req.Requests) + } + if len(req.Metadata) > 0 { + attrs[schemas.AttrBatchMetadata] = fmt.Sprintf("%v", req.Metadata) + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchListRequestAttributes extracts batch list request attributes. +func PopulateBatchListRequestAttributes(req *schemas.BifrostBatchListRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Limit > 0 { + attrs[schemas.AttrBatchLimit] = req.Limit + } + if req.After != nil { + attrs[schemas.AttrBatchAfter] = *req.After + } + if req.BeforeID != nil { + attrs[schemas.AttrBatchBeforeID] = *req.BeforeID + } + if req.AfterID != nil { + attrs[schemas.AttrBatchAfterID] = *req.AfterID + } + if req.PageToken != nil { + attrs[schemas.AttrBatchPageToken] = *req.PageToken + } + if req.PageSize > 0 { + attrs[schemas.AttrBatchPageSize] = req.PageSize + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchRetrieveRequestAttributes extracts batch retrieve request attributes. +func PopulateBatchRetrieveRequestAttributes(req *schemas.BifrostBatchRetrieveRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.BatchID != "" { + attrs[schemas.AttrBatchID] = req.BatchID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchCancelRequestAttributes extracts batch cancel request attributes. +func PopulateBatchCancelRequestAttributes(req *schemas.BifrostBatchCancelRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.BatchID != "" { + attrs[schemas.AttrBatchID] = req.BatchID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchResultsRequestAttributes extracts batch results request attributes. +func PopulateBatchResultsRequestAttributes(req *schemas.BifrostBatchResultsRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.BatchID != "" { + attrs[schemas.AttrBatchID] = req.BatchID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchCreateResponseAttributes extracts batch create response attributes. +func PopulateBatchCreateResponseAttributes(resp *schemas.BifrostBatchCreateResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.ID + attrs[schemas.AttrBatchStatus] = string(resp.Status) + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + if resp.Endpoint != "" { + attrs[schemas.AttrBatchEndpoint] = resp.Endpoint + } + if resp.InputFileID != "" { + attrs[schemas.AttrBatchInputFileID] = resp.InputFileID + } + if resp.CompletionWindow != "" { + attrs[schemas.AttrBatchCompletionWin] = resp.CompletionWindow + } + if resp.CreatedAt != 0 { + attrs[schemas.AttrBatchCreatedAt] = resp.CreatedAt + } + if resp.ExpiresAt != nil { + attrs[schemas.AttrBatchExpiresAt] = *resp.ExpiresAt + } + if resp.OutputFileID != nil { + attrs[schemas.AttrBatchOutputFileID] = *resp.OutputFileID + } + if resp.ErrorFileID != nil { + attrs[schemas.AttrBatchErrorFileID] = *resp.ErrorFileID + } + attrs[schemas.AttrBatchCountTotal] = resp.RequestCounts.Total + attrs[schemas.AttrBatchCountCompleted] = resp.RequestCounts.Completed + attrs[schemas.AttrBatchCountFailed] = resp.RequestCounts.Failed +} + +// PopulateBatchListResponseAttributes extracts batch list response attributes. +func PopulateBatchListResponseAttributes(resp *schemas.BifrostBatchListResponse, attrs map[string]any) { + if resp == nil { + return + } + + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + attrs[schemas.AttrBatchDataCount] = len(resp.Data) + attrs[schemas.AttrBatchHasMore] = resp.HasMore + if resp.FirstID != nil { + attrs[schemas.AttrBatchFirstID] = *resp.FirstID + } + if resp.LastID != nil { + attrs[schemas.AttrBatchLastID] = *resp.LastID + } +} + +// PopulateBatchRetrieveResponseAttributes extracts batch retrieve response attributes. +func PopulateBatchRetrieveResponseAttributes(resp *schemas.BifrostBatchRetrieveResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.ID + attrs[schemas.AttrBatchStatus] = string(resp.Status) + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + if resp.Endpoint != "" { + attrs[schemas.AttrBatchEndpoint] = resp.Endpoint + } + if resp.InputFileID != "" { + attrs[schemas.AttrBatchInputFileID] = resp.InputFileID + } + if resp.CompletionWindow != "" { + attrs[schemas.AttrBatchCompletionWin] = resp.CompletionWindow + } + if resp.CreatedAt != 0 { + attrs[schemas.AttrBatchCreatedAt] = resp.CreatedAt + } + if resp.ExpiresAt != nil { + attrs[schemas.AttrBatchExpiresAt] = *resp.ExpiresAt + } + if resp.InProgressAt != nil { + attrs[schemas.AttrBatchInProgressAt] = *resp.InProgressAt + } + if resp.FinalizingAt != nil { + attrs[schemas.AttrBatchFinalizingAt] = *resp.FinalizingAt + } + if resp.CompletedAt != nil { + attrs[schemas.AttrBatchCompletedAt] = *resp.CompletedAt + } + if resp.FailedAt != nil { + attrs[schemas.AttrBatchFailedAt] = *resp.FailedAt + } + if resp.ExpiredAt != nil { + attrs[schemas.AttrBatchExpiredAt] = *resp.ExpiredAt + } + if resp.CancellingAt != nil { + attrs[schemas.AttrBatchCancellingAt] = *resp.CancellingAt + } + if resp.CancelledAt != nil { + attrs[schemas.AttrBatchCancelledAt] = *resp.CancelledAt + } + if resp.OutputFileID != nil { + attrs[schemas.AttrBatchOutputFileID] = *resp.OutputFileID + } + if resp.ErrorFileID != nil { + attrs[schemas.AttrBatchErrorFileID] = *resp.ErrorFileID + } + attrs[schemas.AttrBatchCountTotal] = resp.RequestCounts.Total + attrs[schemas.AttrBatchCountCompleted] = resp.RequestCounts.Completed + attrs[schemas.AttrBatchCountFailed] = resp.RequestCounts.Failed +} + +// PopulateBatchCancelResponseAttributes extracts batch cancel response attributes. +func PopulateBatchCancelResponseAttributes(resp *schemas.BifrostBatchCancelResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.ID + attrs[schemas.AttrBatchStatus] = string(resp.Status) + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + if resp.CancellingAt != nil { + attrs[schemas.AttrBatchCancellingAt] = *resp.CancellingAt + } + if resp.CancelledAt != nil { + attrs[schemas.AttrBatchCancelledAt] = *resp.CancelledAt + } + attrs[schemas.AttrBatchCountTotal] = resp.RequestCounts.Total + attrs[schemas.AttrBatchCountCompleted] = resp.RequestCounts.Completed + attrs[schemas.AttrBatchCountFailed] = resp.RequestCounts.Failed +} + +// PopulateBatchResultsResponseAttributes extracts batch results response attributes. +func PopulateBatchResultsResponseAttributes(resp *schemas.BifrostBatchResultsResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.BatchID + attrs[schemas.AttrBatchResultsCount] = len(resp.Results) + attrs[schemas.AttrBatchHasMore] = resp.HasMore + if resp.NextCursor != nil { + attrs[schemas.AttrBatchNextCursor] = *resp.NextCursor + } +} + +// =============================================== +// File Operations Request/Response +// =============================================== + +// PopulateFileUploadRequestAttributes extracts file upload request attributes. +func PopulateFileUploadRequestAttributes(req *schemas.BifrostFileUploadRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Filename != "" { + attrs[schemas.AttrFileFilename] = req.Filename + } + if req.Purpose != "" { + attrs[schemas.AttrFilePurpose] = string(req.Purpose) + } + if len(req.File) > 0 { + attrs[schemas.AttrFileBytes] = len(req.File) + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileListRequestAttributes extracts file list request attributes. +func PopulateFileListRequestAttributes(req *schemas.BifrostFileListRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Purpose != "" { + attrs[schemas.AttrFilePurpose] = string(req.Purpose) + } + if req.Limit > 0 { + attrs[schemas.AttrFileLimit] = req.Limit + } + if req.After != nil { + attrs[schemas.AttrFileAfter] = *req.After + } + if req.Order != nil { + attrs[schemas.AttrFileOrder] = *req.Order + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileRetrieveRequestAttributes extracts file retrieve request attributes. +func PopulateFileRetrieveRequestAttributes(req *schemas.BifrostFileRetrieveRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.FileID != "" { + attrs[schemas.AttrFileID] = req.FileID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileDeleteRequestAttributes extracts file delete request attributes. +func PopulateFileDeleteRequestAttributes(req *schemas.BifrostFileDeleteRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.FileID != "" { + attrs[schemas.AttrFileID] = req.FileID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileContentRequestAttributes extracts file content request attributes. +func PopulateFileContentRequestAttributes(req *schemas.BifrostFileContentRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.FileID != "" { + attrs[schemas.AttrFileID] = req.FileID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileUploadResponseAttributes extracts file upload response attributes. +func PopulateFileUploadResponseAttributes(resp *schemas.BifrostFileUploadResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.ID + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileBytes] = resp.Bytes + attrs[schemas.AttrFileCreatedAt] = resp.CreatedAt + attrs[schemas.AttrFileFilename] = resp.Filename + attrs[schemas.AttrFilePurpose] = string(resp.Purpose) + if resp.Status != "" { + attrs[schemas.AttrFileStatus] = string(resp.Status) + } + if resp.StorageBackend != "" { + attrs[schemas.AttrFileStorageBackend] = string(resp.StorageBackend) + } +} + +// PopulateFileListResponseAttributes extracts file list response attributes. +func PopulateFileListResponseAttributes(resp *schemas.BifrostFileListResponse, attrs map[string]any) { + if resp == nil { + return + } + + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileDataCount] = len(resp.Data) + attrs[schemas.AttrFileHasMore] = resp.HasMore +} + +// PopulateFileRetrieveResponseAttributes extracts file retrieve response attributes. +func PopulateFileRetrieveResponseAttributes(resp *schemas.BifrostFileRetrieveResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.ID + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileBytes] = resp.Bytes + attrs[schemas.AttrFileCreatedAt] = resp.CreatedAt + attrs[schemas.AttrFileFilename] = resp.Filename + attrs[schemas.AttrFilePurpose] = string(resp.Purpose) + if resp.Status != "" { + attrs[schemas.AttrFileStatus] = string(resp.Status) + } + if resp.StorageBackend != "" { + attrs[schemas.AttrFileStorageBackend] = string(resp.StorageBackend) + } +} + +// PopulateFileDeleteResponseAttributes extracts file delete response attributes. +func PopulateFileDeleteResponseAttributes(resp *schemas.BifrostFileDeleteResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.ID + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileDeleted] = resp.Deleted +} + +// PopulateFileContentResponseAttributes extracts file content response attributes. +func PopulateFileContentResponseAttributes(resp *schemas.BifrostFileContentResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.FileID + if resp.ContentType != "" { + attrs[schemas.AttrFileContentType] = resp.ContentType + } + if len(resp.Content) > 0 { + attrs[schemas.AttrFileContentBytes] = len(resp.Content) + } +} + +// =============================================== +// Helper functions for extracting messages +// =============================================== + +// MessageSummary represents a summarized chat message for tracing +type MessageSummary struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCallSummary `json:"tool_calls,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + ReasoningDetails []ReasoningDetailSummary `json:"reasoning_details,omitempty"` + Audio *AudioSummary `json:"audio,omitempty"` + Refusal string `json:"refusal,omitempty"` +} + +// ToolCallSummary represents a summarized tool call for tracing +type ToolCallSummary struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + Args string `json:"args,omitempty"` +} + +// ReasoningDetailSummary represents a summarized reasoning detail for tracing +type ReasoningDetailSummary struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// AudioSummary represents summarized audio data for tracing +type AudioSummary struct { + ID string `json:"id,omitempty"` + Transcript string `json:"transcript,omitempty"` +} + +// extractChatMessages extracts chat messages into a slice of MessageSummary +func extractChatMessages(messages []schemas.ChatMessage) []MessageSummary { + result := make([]MessageSummary, 0, len(messages)) + for _, msg := range messages { + summary := extractMessageSummary(&msg) + result = append(result, summary) + } + return result +} + +// extractChatResponseMessages extracts output messages from chat response +func extractChatResponseMessages(resp *schemas.BifrostChatResponse) []MessageSummary { + if resp == nil { + return nil + } + + result := make([]MessageSummary, 0, len(resp.Choices)) + for _, choice := range resp.Choices { + if choice.ChatNonStreamResponseChoice == nil || choice.ChatNonStreamResponseChoice.Message == nil { + continue + } + msg := choice.ChatNonStreamResponseChoice.Message + summary := extractMessageSummary(msg) + result = append(result, summary) + } + return result +} + +// extractMessageSummary extracts a full MessageSummary from a ChatMessage +func extractMessageSummary(msg *schemas.ChatMessage) MessageSummary { + if msg == nil { + return MessageSummary{} + } + + summary := MessageSummary{ + Role: string(schemas.ChatMessageRoleAssistant), + Content: extractMessageContent(msg.Content), + } + + if msg.Role != "" { + summary.Role = string(msg.Role) + } + + // Extract assistant-specific fields + if msg.ChatAssistantMessage != nil { + am := msg.ChatAssistantMessage + + // Extract refusal + if am.Refusal != nil && *am.Refusal != "" { + summary.Refusal = *am.Refusal + } + + // Extract reasoning + if am.Reasoning != nil && *am.Reasoning != "" { + summary.Reasoning = *am.Reasoning + } + + // Extract reasoning details + if len(am.ReasoningDetails) > 0 { + summary.ReasoningDetails = make([]ReasoningDetailSummary, 0, len(am.ReasoningDetails)) + for _, rd := range am.ReasoningDetails { + detail := ReasoningDetailSummary{ + Type: string(rd.Type), + } + if rd.Text != nil { + detail.Text = *rd.Text + } + summary.ReasoningDetails = append(summary.ReasoningDetails, detail) + } + } + + // Extract audio + if am.Audio != nil { + summary.Audio = &AudioSummary{ + ID: am.Audio.ID, + Transcript: am.Audio.Transcript, + } + } + + // Extract tool calls + if len(am.ToolCalls) > 0 { + summary.ToolCalls = make([]ToolCallSummary, 0, len(am.ToolCalls)) + for _, tc := range am.ToolCalls { + toolCall := ToolCallSummary{ + Type: "function", + } + if tc.ID != nil { + toolCall.ID = *tc.ID + } + if tc.Type != nil { + toolCall.Type = *tc.Type + } + if tc.Function.Name != nil { + toolCall.Name = *tc.Function.Name + } + toolCall.Args = tc.Function.Arguments + summary.ToolCalls = append(summary.ToolCalls, toolCall) + } + } + } + + return summary +} + +// ResponsesMessageSummary extends MessageSummary with reasoning +type ResponsesMessageSummary struct { + Role string `json:"role"` + Content string `json:"content"` + Reasoning string `json:"reasoning,omitempty"` +} + +// extractResponsesOutputMessages extracts output messages from responses API +func extractResponsesOutputMessages(resp *schemas.BifrostResponsesResponse) []ResponsesMessageSummary { + if resp == nil { + return nil + } + + result := make([]ResponsesMessageSummary, 0, len(resp.Output)) + for _, msg := range resp.Output { + if msg.Role == nil { + continue + } + content := "" + if msg.Content != nil { + if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" { + content = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + content += *block.Text + } + } + } + } + // Extract reasoning text + reasoning := "" + if msg.ResponsesReasoning != nil && msg.ResponsesReasoning.Summary != nil { + for _, block := range msg.ResponsesReasoning.Summary { + if block.Text != "" { + reasoning += block.Text + } + } + } + result = append(result, ResponsesMessageSummary{ + Role: string(*msg.Role), + Content: content, + Reasoning: reasoning, + }) + } + return result +} + +// extractMessageContent extracts text content from ChatMessageContent +func extractMessageContent(content *schemas.ChatMessageContent) string { + if content == nil { + return "" + } + + if content.ContentStr != nil { + return *content.ContentStr + } + + if content.ContentBlocks != nil { + var builder strings.Builder + for _, block := range content.ContentBlocks { + if block.Text != nil { + builder.WriteString(*block.Text) + } + } + return builder.String() + } + + return "" +} + +// =============================================== +// Cost Calculation +// =============================================== + +// PopulateCostAttribute calculates and adds the cost attribute for a response. +// The pricingManager is optional; if nil, no cost attribute is added. +func PopulateCostAttribute( + resp *schemas.BifrostResponse, + pricingManager *modelcatalog.ModelCatalog, + attrs map[string]any, +) { + if pricingManager == nil || resp == nil { + return + } + cost := pricingManager.CalculateCostWithCacheDebug(resp) + attrs[schemas.AttrUsageCost] = cost +} diff --git a/framework/tracing/propagation.go b/framework/tracing/propagation.go new file mode 100644 index 000000000..0536b7a68 --- /dev/null +++ b/framework/tracing/propagation.go @@ -0,0 +1,192 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "strings" + + "github.com/valyala/fasthttp" +) + +// normalizeTraceID normalizes a trace ID to W3C-compliant format. +// Strips hyphens and ensures 32 lowercase hex characters. +// Returns empty string if input cannot be normalized to a valid trace ID. +func normalizeTraceID(traceID string) string { + // Remove hyphens (handles UUID format) + normalized := strings.ReplaceAll(traceID, "-", "") + normalized = strings.ToLower(normalized) + + // Validate length - must be exactly 32 hex chars + if len(normalized) != 32 { + return "" + } + + // Validate hex characters + if !isHex(normalized) { + return "" + } + + return normalized +} + +// normalizeSpanID normalizes a span ID to W3C-compliant format. +// Strips hyphens and ensures 16 lowercase hex characters. +// If input is longer (e.g., UUID format), takes first 16 hex chars. +// Returns empty string if input cannot be normalized to a valid span ID. +func normalizeSpanID(spanID string) string { + // Remove hyphens (handles UUID format) + normalized := strings.ReplaceAll(spanID, "-", "") + normalized = strings.ToLower(normalized) + + // If longer than 16 chars, truncate (e.g., full UUID -> first 16 hex chars) + if len(normalized) > 16 { + normalized = normalized[:16] + } + + // Validate length - must be exactly 16 hex chars + if len(normalized) != 16 { + return "" + } + + // Validate hex characters + if !isHex(normalized) { + return "" + } + + return normalized +} + +// W3C Trace Context header names +const ( + TraceParentHeader = "traceparent" + TraceStateHeader = "tracestate" +) + +// W3CTraceContext holds parsed W3C trace context values +type W3CTraceContext struct { + TraceID string // 32 hex characters + ParentID string // 16 hex characters (span ID of parent) + TraceFlags string // 2 hex characters + TraceState string // Optional vendor-specific trace state +} + +// ExtractParentID extracts the parent trace ID from W3C traceparent header +// Returns empty string if header is not present or invalid +func ExtractParentID(header *fasthttp.RequestHeader) string { + traceParent := string(header.Peek(TraceParentHeader)) + if traceParent == "" { + return "" + } + ctx := ParseTraceparent(traceParent) + if ctx == nil { + return "" + } + return ctx.TraceID +} + +// ExtractTraceContext extracts full W3C trace context from headers +func ExtractTraceContext(header *fasthttp.RequestHeader) *W3CTraceContext { + traceparent := string(header.Peek(TraceParentHeader)) + if traceparent == "" { + return nil + } + + ctx := ParseTraceparent(traceparent) + if ctx == nil { + return nil + } + + // Also extract tracestate if present + ctx.TraceState = string(header.Peek(TraceStateHeader)) + + return ctx +} + +// ParseTraceparent parses a W3C traceparent header value +// Format: version-traceid-parentid-traceflags +// Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01 +func ParseTraceparent(traceparent string) *W3CTraceContext { + parts := strings.Split(traceparent, "-") + if len(parts) != 4 { + return nil + } + + version := parts[0] + traceID := parts[1] + parentID := parts[2] + traceFlags := parts[3] + + // Validate version (only 00 is currently supported) + if version != "00" { + return nil + } + + // Validate trace ID (32 hex characters) + if len(traceID) != 32 || !isHex(traceID) { + return nil + } + + // Validate parent ID (16 hex characters) + if len(parentID) != 16 || !isHex(parentID) { + return nil + } + + // Validate trace flags (2 hex characters) + if len(traceFlags) != 2 || !isHex(traceFlags) { + return nil + } + + return &W3CTraceContext{ + TraceID: traceID, + ParentID: parentID, + TraceFlags: traceFlags, + } +} + +// FormatTraceparent formats a W3C traceparent header value. +// It normalizes trace ID and span ID to W3C-compliant format: +// - trace ID: 32 lowercase hex characters +// - span ID: 16 lowercase hex characters +// Returns empty string if IDs cannot be normalized to valid format. +func FormatTraceparent(traceID, spanID, traceFlags string) string { + normalizedTraceID := normalizeTraceID(traceID) + normalizedSpanID := normalizeSpanID(spanID) + + if normalizedTraceID == "" || normalizedSpanID == "" { + return "" + } + + // Normalize and validate traceFlags + traceFlags = strings.ToLower(traceFlags) + if len(traceFlags) != 2 || !isHex(traceFlags) { + traceFlags = "00" // Default: not sampled + } + + return "00-" + normalizedTraceID + "-" + normalizedSpanID + "-" + traceFlags +} + +// InjectTraceContext injects W3C trace context headers into outgoing request +func InjectTraceContext(header *fasthttp.RequestHeader, traceID, spanID, traceFlags, traceState string) { + if traceID == "" || spanID == "" { + return + } + + traceparent := FormatTraceparent(traceID, spanID, traceFlags) + if traceparent == "" { + return // IDs could not be normalized to valid W3C format + } + header.Set(TraceParentHeader, traceparent) + + if traceState != "" { + header.Set(TraceStateHeader, traceState) + } +} + +// isHex checks if a string contains only hexadecimal characters +func isHex(s string) bool { + for _, c := range s { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) { + return false + } + } + return true +} diff --git a/framework/tracing/store.go b/framework/tracing/store.go new file mode 100644 index 000000000..9d7c84e3e --- /dev/null +++ b/framework/tracing/store.go @@ -0,0 +1,385 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "encoding/hex" + "sync" + "time" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" +) + +// DeferredSpanInfo stores information about a deferred span for streaming requests +type DeferredSpanInfo struct { + SpanID string + StartTime time.Time + Tracer schemas.Tracer // Reference to tracer for completing the span + RequestID string // Request ID for accumulator lookup + FirstChunkTime time.Time // Timestamp of first chunk (for TTFT calculation) + AccumulatedChunks []*schemas.BifrostResponse // Accumulated streaming chunks + mu sync.Mutex // Mutex for thread-safe chunk accumulation +} + +// TraceStore manages traces with thread-safe access and object pooling +type TraceStore struct { + traces sync.Map // map[traceID]*schemas.Trace - thread-safe concurrent access + deferredSpans sync.Map // map[traceID]*DeferredSpanInfo - deferred spans for streaming requests + tracePool sync.Pool // Reuse Trace objects to reduce allocations + spanPool sync.Pool // Reuse Span objects to reduce allocations + logger schemas.Logger + + ttl time.Duration + cleanupTicker *time.Ticker + stopCleanup chan struct{} + cleanupWg sync.WaitGroup + stopOnce sync.Once // Ensures Stop() cleanup runs only once +} + +// NewTraceStore creates a new TraceStore with the given TTL for cleanup +func NewTraceStore(ttl time.Duration, logger schemas.Logger) *TraceStore { + store := &TraceStore{ + ttl: ttl, + logger: logger, + tracePool: sync.Pool{ + New: func() any { + return &schemas.Trace{ + Spans: make([]*schemas.Span, 0, 16), // Pre-allocate capacity + Attributes: make(map[string]any), + } + }, + }, + spanPool: sync.Pool{ + New: func() any { + return &schemas.Span{ + Attributes: make(map[string]any), + Events: make([]schemas.SpanEvent, 0, 4), // Pre-allocate capacity + } + }, + }, + stopCleanup: make(chan struct{}), + } + + // Start background cleanup goroutine + store.startCleanup() + + return store +} + +// CreateTrace creates a new trace and stores it, returns trace ID only +func (s *TraceStore) CreateTrace(parentID string) string { + trace := s.tracePool.Get().(*schemas.Trace) + + // Reset and initialize the trace + trace.TraceID = generateTraceID() + trace.ParentID = parentID + trace.StartTime = time.Now() + trace.EndTime = time.Time{} + trace.RootSpan = nil + + // Reset slices but keep capacity + if trace.Spans != nil { + trace.Spans = trace.Spans[:0] + } else { + trace.Spans = make([]*schemas.Span, 0, 16) + } + + // Reset attributes + if trace.Attributes == nil { + trace.Attributes = make(map[string]any) + } else { + clear(trace.Attributes) + } + + s.traces.Store(trace.TraceID, trace) + return trace.TraceID +} + +// GetTrace retrieves a trace by ID +func (s *TraceStore) GetTrace(traceID string) *schemas.Trace { + if val, ok := s.traces.Load(traceID); ok { + return val.(*schemas.Trace) + } + return nil +} + +// CompleteTrace marks the trace as complete, removes it from store, and returns it for flushing +func (s *TraceStore) CompleteTrace(traceID string) *schemas.Trace { + // Clear any deferred span for this trace + s.deferredSpans.Delete(traceID) + + if val, ok := s.traces.LoadAndDelete(traceID); ok { + trace := val.(*schemas.Trace) + trace.EndTime = time.Now() + return trace + } + return nil +} + +// StoreDeferredSpan stores a span ID for later completion (used for streaming requests) +func (s *TraceStore) StoreDeferredSpan(traceID, spanID string) { + s.deferredSpans.Store(traceID, &DeferredSpanInfo{ + SpanID: spanID, + StartTime: time.Now(), + }) +} + +// GetDeferredSpan retrieves the deferred span info for a trace ID +func (s *TraceStore) GetDeferredSpan(traceID string) *DeferredSpanInfo { + if val, ok := s.deferredSpans.Load(traceID); ok { + return val.(*DeferredSpanInfo) + } + return nil +} + +// ClearDeferredSpan removes the deferred span info for a trace ID +func (s *TraceStore) ClearDeferredSpan(traceID string) { + s.deferredSpans.Delete(traceID) +} + +// AppendStreamingChunk adds a streaming chunk to the deferred span's accumulated data +func (s *TraceStore) AppendStreamingChunk(traceID string, chunk *schemas.BifrostResponse) { + if chunk == nil { + return + } + info := s.GetDeferredSpan(traceID) + if info == nil { + return + } + info.mu.Lock() + defer info.mu.Unlock() + + // Track first chunk time for TTFT calculation + if info.FirstChunkTime.IsZero() { + info.FirstChunkTime = time.Now() + } + + // Append chunk to accumulated list + info.AccumulatedChunks = append(info.AccumulatedChunks, chunk) +} + +// GetAccumulatedData returns the accumulated chunks and TTFT for a deferred span +func (s *TraceStore) GetAccumulatedData(traceID string) ([]*schemas.BifrostResponse, int64) { + info := s.GetDeferredSpan(traceID) + if info == nil { + return nil, 0 + } + info.mu.Lock() + defer info.mu.Unlock() + + // Calculate TTFT in milliseconds + var ttftMs int64 + if !info.StartTime.IsZero() && !info.FirstChunkTime.IsZero() { + ttftMs = info.FirstChunkTime.Sub(info.StartTime).Milliseconds() + } + + return info.AccumulatedChunks, ttftMs +} + +// ReleaseTrace returns the trace and its spans to the pools for reuse +func (s *TraceStore) ReleaseTrace(trace *schemas.Trace) { + if trace == nil { + return + } + + // Return all spans to the pool + for _, span := range trace.Spans { + s.releaseSpan(span) + } + + // Reset the trace + trace.Reset() + + // Return trace to pool + s.tracePool.Put(trace) +} + +// StartSpan creates a new span and adds it to the trace +func (s *TraceStore) StartSpan(traceID, name string, kind schemas.SpanKind) *schemas.Span { + trace := s.GetTrace(traceID) + if trace == nil { + return nil + } + + span := s.spanPool.Get().(*schemas.Span) + + // Reset and initialize the span + span.SpanID = generateSpanID() + span.TraceID = traceID + span.Name = name + span.Kind = kind + span.StartTime = time.Now() + span.EndTime = time.Time{} + span.Status = schemas.SpanStatusUnset + span.StatusMsg = "" + + // Reset slices but keep capacity + if span.Events != nil { + span.Events = span.Events[:0] + } else { + span.Events = make([]schemas.SpanEvent, 0, 4) + } + + // Reset attributes + if span.Attributes == nil { + span.Attributes = make(map[string]any) + } else { + clear(span.Attributes) + } + + // Set parent ID to root span if it exists, otherwise this is root + if trace.RootSpan != nil { + span.ParentID = trace.RootSpan.SpanID + } else { + span.ParentID = "" + trace.RootSpan = span + } + + // Add span to trace + trace.AddSpan(span) + + return span +} + +// StartChildSpan creates a new span as a child of the specified parent span +func (s *TraceStore) StartChildSpan(traceID, parentSpanID, name string, kind schemas.SpanKind) *schemas.Span { + trace := s.GetTrace(traceID) + if trace == nil { + return nil + } + + span := s.spanPool.Get().(*schemas.Span) + + // Reset and initialize the span + span.SpanID = generateSpanID() + span.ParentID = parentSpanID + span.TraceID = traceID + span.Name = name + span.Kind = kind + span.StartTime = time.Now() + span.EndTime = time.Time{} + span.Status = schemas.SpanStatusUnset + span.StatusMsg = "" + + // Reset slices but keep capacity + if span.Events != nil { + span.Events = span.Events[:0] + } else { + span.Events = make([]schemas.SpanEvent, 0, 4) + } + + // Reset attributes + if span.Attributes == nil { + span.Attributes = make(map[string]any) + } else { + clear(span.Attributes) + } + + // Add span to trace + trace.AddSpan(span) + + return span +} + +// EndSpan marks a span as complete with the given status and attributes +func (s *TraceStore) EndSpan(traceID, spanID string, status schemas.SpanStatus, statusMsg string, attrs map[string]any) { + trace := s.GetTrace(traceID) + if trace == nil { + return + } + + span := trace.GetSpan(spanID) + if span == nil { + return + } + + span.End(status, statusMsg) + + // Add any final attributes + for k, v := range attrs { + span.SetAttribute(k, v) + } +} + +// releaseSpan returns a span to the pool +func (s *TraceStore) releaseSpan(span *schemas.Span) { + if span == nil { + return + } + span.Reset() + s.spanPool.Put(span) +} + +// startCleanup starts the background cleanup goroutine +func (s *TraceStore) startCleanup() { + if s.ttl <= 0 { + return + } + + // Cleanup interval is TTL / 2 + cleanupInterval := s.ttl / 2 + if cleanupInterval < time.Minute { + cleanupInterval = time.Minute + } + + s.cleanupTicker = time.NewTicker(cleanupInterval) + s.cleanupWg.Add(1) + + go func() { + defer s.cleanupWg.Done() + for { + select { + case <-s.cleanupTicker.C: + s.cleanupOldTraces() + case <-s.stopCleanup: + return + } + } + }() +} + +// cleanupOldTraces removes traces that have exceeded the TTL +func (s *TraceStore) cleanupOldTraces() { + cutoff := time.Now().Add(-s.ttl) + count := 0 + + s.traces.Range(func(key, value any) bool { + trace := value.(*schemas.Trace) + if trace.StartTime.Before(cutoff) { + if deleted, ok := s.traces.LoadAndDelete(key); ok { + s.ReleaseTrace(deleted.(*schemas.Trace)) + count++ + } + } + return true + }) + + if count > 0 && s.logger != nil { + s.logger.Debug("tracing: cleaned up %d orphaned traces", count) + } +} + +// Stop stops the cleanup goroutine and releases resources +func (s *TraceStore) Stop() { + s.stopOnce.Do(func() { + if s.cleanupTicker != nil { + s.cleanupTicker.Stop() + } + close(s.stopCleanup) + s.cleanupWg.Wait() + }) +} + +// generateTraceID generates a W3C-compliant trace ID. +// Returns 32 lowercase hex characters (128-bit UUID without hyphens). +func generateTraceID() string { + u := uuid.New() + return hex.EncodeToString(u[:]) +} + +// generateSpanID generates a W3C-compliant span ID. +// Returns 16 lowercase hex characters (first 64 bits of a UUID). +func generateSpanID() string { + u := uuid.New() + return hex.EncodeToString(u[:8]) +} diff --git a/framework/tracing/tracer.go b/framework/tracing/tracer.go new file mode 100644 index 000000000..985bd552f --- /dev/null +++ b/framework/tracing/tracer.go @@ -0,0 +1,551 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "context" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" + "github.com/maximhq/bifrost/framework/streaming" +) + +// Tracer implements schemas.Tracer using TraceStore. +// It provides the bridge between the core Tracer interface and the +// framework's TraceStore implementation. +// It also embeds a streaming.Accumulator for centralized streaming chunk accumulation. +type Tracer struct { + store *TraceStore + accumulator *streaming.Accumulator +} + +// NewTracer creates a new Tracer wrapping the given TraceStore. +// The accumulator is embedded for centralized streaming chunk accumulation. +func NewTracer(store *TraceStore, pricingManager *modelcatalog.ModelCatalog, logger schemas.Logger) *Tracer { + return &Tracer{ + store: store, + accumulator: streaming.NewAccumulator(pricingManager, logger), + } +} + +// CreateTrace creates a new trace with optional parent ID and returns the trace ID. +func (t *Tracer) CreateTrace(parentID string) string { + return t.store.CreateTrace(parentID) +} + +// EndTrace completes a trace and returns the trace data for observation/export. +// The returned trace should be released after use by calling ReleaseTrace. +func (t *Tracer) EndTrace(traceID string) *schemas.Trace { + trace := t.store.CompleteTrace(traceID) + if trace == nil { + return nil + } + // Note: Caller is responsible for releasing the trace after plugin processing + // by calling ReleaseTrace on the store or letting GC handle it + return trace +} + +// ReleaseTrace returns the trace to the pool for reuse. +// Should be called after EndTrace when the trace data is no longer needed. +func (t *Tracer) ReleaseTrace(trace *schemas.Trace) { + t.store.ReleaseTrace(trace) +} + +// spanHandle is the concrete implementation of schemas.SpanHandle for Tracer. +// It contains the trace and span IDs needed to reference the span in the store. +type spanHandle struct { + traceID string + spanID string +} + +// StartSpan creates a new span as a child of the current span in context. +// It reads the trace ID and parent span ID from context, creates the span, +// and returns an updated context with the new span ID. +func (t *Tracer) StartSpan(ctx context.Context, name string, kind schemas.SpanKind) (context.Context, schemas.SpanHandle) { + traceID := GetTraceID(ctx) + if traceID == "" { + return ctx, nil + } + // Get parent span ID from context + parentSpanID, _ := ctx.Value(schemas.BifrostContextKeySpanID).(string) + var span *schemas.Span + if parentSpanID != "" { + span = t.store.StartChildSpan(traceID, parentSpanID, name, kind) + } else { + span = t.store.StartSpan(traceID, name, kind) + } + if span == nil { + return ctx, nil + } + // Update context with new span ID + newCtx := context.WithValue(ctx, schemas.BifrostContextKeySpanID, span.SpanID) + return newCtx, &spanHandle{traceID: traceID, spanID: span.SpanID} +} + +// EndSpan completes a span with the given status and message. +func (t *Tracer) EndSpan(handle schemas.SpanHandle, status schemas.SpanStatus, statusMsg string) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + t.store.EndSpan(h.traceID, h.spanID, status, statusMsg, nil) +} + +// SetAttribute sets an attribute on the span identified by the handle. +func (t *Tracer) SetAttribute(handle schemas.SpanHandle, key string, value any) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span != nil { + span.SetAttribute(key, value) + } +} + +// AddEvent adds a timestamped event to the span identified by the handle. +func (t *Tracer) AddEvent(handle schemas.SpanHandle, name string, attrs map[string]any) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span != nil { + span.AddEvent(schemas.SpanEvent{ + Name: name, + Timestamp: time.Now(), + Attributes: attrs, + }) + } +} + +// PopulateLLMRequestAttributes populates all LLM-specific request attributes on the span. +func (t *Tracer) PopulateLLMRequestAttributes(handle schemas.SpanHandle, req *schemas.BifrostRequest) { + h, ok := handle.(*spanHandle) + if !ok || h == nil || req == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span == nil { + return + } + + for k, v := range PopulateRequestAttributes(req) { + span.SetAttribute(k, v) + } +} + +// PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span. +func (t *Tracer) PopulateLLMResponseAttributes(handle schemas.SpanHandle, resp *schemas.BifrostResponse, err *schemas.BifrostError) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span == nil { + return + } + for k, v := range PopulateResponseAttributes(resp) { + span.SetAttribute(k, v) + } + for k, v := range PopulateErrorAttributes(err) { + span.SetAttribute(k, v) + } +} + +// StoreDeferredSpan stores a span handle for later completion (used for streaming requests). +// The span handle is stored keyed by trace ID so it can be retrieved when the stream completes. +func (t *Tracer) StoreDeferredSpan(traceID string, handle schemas.SpanHandle) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + t.store.StoreDeferredSpan(traceID, h.spanID) +} + +// GetDeferredSpanHandle retrieves a deferred span handle by trace ID. +// Returns nil if no deferred span exists for the given trace ID. +func (t *Tracer) GetDeferredSpanHandle(traceID string) schemas.SpanHandle { + info := t.store.GetDeferredSpan(traceID) + if info == nil { + return nil + } + return &spanHandle{traceID: traceID, spanID: info.SpanID} +} + +// ClearDeferredSpan removes the deferred span handle for a trace ID. +// Should be called after the deferred span has been completed. +func (t *Tracer) ClearDeferredSpan(traceID string) { + t.store.ClearDeferredSpan(traceID) +} + +// GetDeferredSpanID returns the span ID for the deferred span. +// Returns empty string if no deferred span exists. +func (t *Tracer) GetDeferredSpanID(traceID string) string { + info := t.store.GetDeferredSpan(traceID) + if info == nil { + return "" + } + return info.SpanID +} + +// AddStreamingChunk accumulates a streaming chunk for the deferred span. +// This stores the full BifrostResponse chunk for later reconstruction. +// Note: This method still uses the store for backward compatibility with existing code. +// For new code, prefer using ProcessStreamingChunk which uses the embedded accumulator. +func (t *Tracer) AddStreamingChunk(traceID string, response *schemas.BifrostResponse) { + if traceID == "" || response == nil { + return + } + t.store.AppendStreamingChunk(traceID, response) +} + +// GetAccumulatedChunks returns the accumulated BifrostResponse, TTFT, and chunk count for a deferred span. +// It reconstructs a complete response from all accumulated streaming chunks. +// Note: This method still uses the store for backward compatibility with existing code. +// For new code, prefer using ProcessStreamingChunk which uses the embedded accumulator. +func (t *Tracer) GetAccumulatedChunks(traceID string) (*schemas.BifrostResponse, int64, int) { + chunks, ttftMs := t.store.GetAccumulatedData(traceID) + if len(chunks) == 0 { + return nil, 0, 0 + } + + // Build complete response from accumulated chunks + return buildCompleteResponseFromChunks(chunks), ttftMs, len(chunks) +} + +// buildCompleteResponseFromChunks reconstructs a complete BifrostResponse from streaming chunks. +// This accumulates content, tool calls, reasoning, audio, and other fields. +// Note: This is kept for backward compatibility with existing code that uses AddStreamingChunk/GetAccumulatedChunks. +func buildCompleteResponseFromChunks(chunks []*schemas.BifrostResponse) *schemas.BifrostResponse { + if len(chunks) == 0 { + return nil + } + + // Use the last chunk as a base (it typically has final usage stats, finish reason, etc.) + lastChunk := chunks[len(chunks)-1] + if lastChunk.ChatResponse == nil { + return nil + } + + result := &schemas.BifrostResponse{ + ChatResponse: &schemas.BifrostChatResponse{ + ID: lastChunk.ChatResponse.ID, + Object: lastChunk.ChatResponse.Object, + Model: lastChunk.ChatResponse.Model, + Created: lastChunk.ChatResponse.Created, + Usage: lastChunk.ChatResponse.Usage, + ExtraFields: lastChunk.ChatResponse.ExtraFields, + Choices: make([]schemas.BifrostResponseChoice, 0), + }, + } + + // Track accumulated content per choice index + type choiceAccumulator struct { + content string + refusal string + reasoning string + reasoningDetails []schemas.ChatReasoningDetails + toolCalls map[int]schemas.ChatAssistantMessageToolCall // keyed by tool call index + audio *schemas.ChatAudioMessageAudio + role schemas.ChatMessageRole + finishReason *string + } + + choiceMap := make(map[int]*choiceAccumulator) + + // Process chunks in order + for _, chunk := range chunks { + if chunk.ChatResponse == nil { + continue + } + for _, choice := range chunk.ChatResponse.Choices { + if choice.ChatStreamResponseChoice == nil || choice.ChatStreamResponseChoice.Delta == nil { + continue + } + delta := choice.ChatStreamResponseChoice.Delta + idx := choice.Index + + // Get or create accumulator for this choice + acc, ok := choiceMap[idx] + if !ok { + acc = &choiceAccumulator{ + role: schemas.ChatMessageRoleAssistant, + toolCalls: make(map[int]schemas.ChatAssistantMessageToolCall), + } + choiceMap[idx] = acc + } + + // Accumulate content + if delta.Content != nil { + acc.content += *delta.Content + } + + // Role (usually in first chunk) + if delta.Role != nil { + acc.role = schemas.ChatMessageRole(*delta.Role) + } + + // Refusal + if delta.Refusal != nil { + acc.refusal += *delta.Refusal + } + + // Reasoning + if delta.Reasoning != nil { + acc.reasoning += *delta.Reasoning + } + + // Reasoning details (merge by index) + for _, rd := range delta.ReasoningDetails { + found := false + for i := range acc.reasoningDetails { + if acc.reasoningDetails[i].Index == rd.Index { + // Accumulate text + if rd.Text != nil { + if acc.reasoningDetails[i].Text == nil { + acc.reasoningDetails[i].Text = rd.Text + } else { + newText := *acc.reasoningDetails[i].Text + *rd.Text + acc.reasoningDetails[i].Text = &newText + } + } + // Update type if present + if rd.Type != "" { + acc.reasoningDetails[i].Type = rd.Type + } + found = true + break + } + } + if !found { + acc.reasoningDetails = append(acc.reasoningDetails, rd) + } + } + + // Audio + if delta.Audio != nil { + if acc.audio == nil { + acc.audio = &schemas.ChatAudioMessageAudio{ + ID: delta.Audio.ID, + Data: delta.Audio.Data, + ExpiresAt: delta.Audio.ExpiresAt, + Transcript: delta.Audio.Transcript, + } + } else { + acc.audio.Data += delta.Audio.Data + acc.audio.Transcript += delta.Audio.Transcript + if delta.Audio.ID != "" { + acc.audio.ID = delta.Audio.ID + } + if delta.Audio.ExpiresAt != 0 { + acc.audio.ExpiresAt = delta.Audio.ExpiresAt + } + } + } + + // Tool calls (merge by index) + for _, tc := range delta.ToolCalls { + tcIdx := int(tc.Index) + existing, ok := acc.toolCalls[tcIdx] + if !ok { + // New tool call + acc.toolCalls[tcIdx] = tc + } else { + // Merge: accumulate arguments, update other fields + if tc.ID != nil { + existing.ID = tc.ID + } + if tc.Type != nil { + existing.Type = tc.Type + } + if tc.Function.Name != nil { + existing.Function.Name = tc.Function.Name + } + existing.Function.Arguments += tc.Function.Arguments + acc.toolCalls[tcIdx] = existing + } + } + + // Finish reason (from BifrostResponseChoice, not ChatStreamResponseChoice) + if choice.FinishReason != nil { + acc.finishReason = choice.FinishReason + } + } + } + + // Build final choices from accumulated data + // Sort choice indices for deterministic output + choiceIndices := make([]int, 0, len(choiceMap)) + for idx := range choiceMap { + choiceIndices = append(choiceIndices, idx) + } + + for _, idx := range choiceIndices { + accum := choiceMap[idx] + + // Build message + msg := &schemas.ChatMessage{ + Role: accum.role, + } + + // Set content + if accum.content != "" { + msg.Content = &schemas.ChatMessageContent{ + ContentStr: &accum.content, + } + } + + // Build assistant message fields + if accum.refusal != "" || accum.reasoning != "" || len(accum.reasoningDetails) > 0 || + accum.audio != nil || len(accum.toolCalls) > 0 { + msg.ChatAssistantMessage = &schemas.ChatAssistantMessage{} + + if accum.refusal != "" { + msg.ChatAssistantMessage.Refusal = &accum.refusal + } + if accum.reasoning != "" { + msg.ChatAssistantMessage.Reasoning = &accum.reasoning + } + if len(accum.reasoningDetails) > 0 { + msg.ChatAssistantMessage.ReasoningDetails = accum.reasoningDetails + } + if accum.audio != nil { + msg.ChatAssistantMessage.Audio = accum.audio + } + if len(accum.toolCalls) > 0 { + // Sort tool calls by index + tcIndices := make([]int, 0, len(accum.toolCalls)) + for tcIdx := range accum.toolCalls { + tcIndices = append(tcIndices, tcIdx) + } + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0, len(accum.toolCalls)) + for _, tcIdx := range tcIndices { + toolCalls = append(toolCalls, accum.toolCalls[tcIdx]) + } + msg.ChatAssistantMessage.ToolCalls = toolCalls + } + } + + // Build choice + choice := schemas.BifrostResponseChoice{ + Index: idx, + FinishReason: accum.finishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: msg, + }, + } + result.ChatResponse.Choices = append(result.ChatResponse.Choices, choice) + } + + return result +} + +// CreateStreamAccumulator creates a new stream accumulator for the given trace ID. +// This should be called at the start of a streaming request. +func (t *Tracer) CreateStreamAccumulator(traceID string, startTime time.Time) { + if traceID == "" || t.accumulator == nil { + return + } + t.accumulator.CreateStreamAccumulator(traceID, startTime) +} + +// CleanupStreamAccumulator removes the stream accumulator for the given trace ID. +// This should be called after the streaming request is complete. +func (t *Tracer) CleanupStreamAccumulator(traceID string) { + if traceID == "" || t.accumulator == nil { + return + } + _ = t.accumulator.CleanupStreamAccumulator(traceID) +} + +// ProcessStreamingChunk processes a streaming chunk and accumulates it. +// Returns the accumulated result. IsFinal will be true when the stream is complete. +// This method is used by plugins to access accumulated streaming data. +// The ctx parameter must contain the stream end indicator for proper final chunk detection. +func (t *Tracer) ProcessStreamingChunk(ctx *schemas.BifrostContext, traceID string, result *schemas.BifrostResponse, err *schemas.BifrostError) *schemas.StreamAccumulatorResult { + if traceID == "" || t.accumulator == nil || ctx == nil { + return nil + } + + // Create a new context that wraps the original but overrides requestID with traceID. + // BifrostContextKeyRequestID is a reserved key, so we can't use SetValue. + // Instead, we create a new parent context with the traceID as requestID, + // then create a BifrostContext that inherits the stream end indicator from the original. + parent := context.WithValue(ctx, schemas.BifrostContextKeyRequestID, traceID) + // Copy stream end indicator from the original context's parent (if available) + if streamEnd := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); streamEnd != nil { + parent = context.WithValue(parent, schemas.BifrostContextKeyStreamEndIndicator, streamEnd) + } + accumCtx := schemas.NewBifrostContext(parent, time.Time{}) + + processedResp, processErr := t.accumulator.ProcessStreamingResponse(accumCtx, result, err) + if processErr != nil || processedResp == nil { + return nil + } + + // Convert ProcessedStreamResponse to StreamAccumulatorResult + accResult := &schemas.StreamAccumulatorResult{ + IsFinal: processedResp.Type == streaming.StreamResponseTypeFinal, + RequestID: processedResp.RequestID, + Model: processedResp.Model, + Provider: processedResp.Provider, + } + + if processedResp.Data != nil { + accResult.Status = processedResp.Data.Status + accResult.Latency = processedResp.Data.Latency + accResult.TimeToFirstToken = processedResp.Data.TimeToFirstToken + accResult.OutputMessage = processedResp.Data.OutputMessage + accResult.OutputMessages = processedResp.Data.OutputMessages + accResult.TokenUsage = processedResp.Data.TokenUsage + accResult.Cost = processedResp.Data.Cost + accResult.ErrorDetails = processedResp.Data.ErrorDetails + accResult.AudioOutput = processedResp.Data.AudioOutput + accResult.TranscriptionOutput = processedResp.Data.TranscriptionOutput + accResult.FinishReason = processedResp.Data.FinishReason + accResult.RawResponse = processedResp.Data.RawResponse + } + + if processedResp.RawRequest != nil { + accResult.RawRequest = *processedResp.RawRequest + } + + return accResult +} + +// GetAccumulator returns the embedded streaming accumulator. +// This is useful for plugins that need direct access to accumulator methods. +func (t *Tracer) GetAccumulator() *streaming.Accumulator { + return t.accumulator +} + +// Stop stops the tracer and releases its resources. +// This stops the internal TraceStore's cleanup goroutine. +func (t *Tracer) Stop() { + if t.store != nil { + t.store.Stop() + } + if t.accumulator != nil { + t.accumulator.Cleanup() + } +} + +// Ensure Tracer implements schemas.Tracer at compile time +var _ schemas.Tracer = (*Tracer)(nil) diff --git a/framework/version b/framework/version index 76700a794..867e52437 100644 --- a/framework/version +++ b/framework/version @@ -1 +1 @@ -1.1.53 \ No newline at end of file +1.2.0 \ No newline at end of file diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index d05a247e3..921beb3a8 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -1,3 +1,21 @@ - refactor: extracted governance store into an interface for extensibility - refactor: extended the way governance store handles rate limits -- chore: added e2e tests for governance plugin \ No newline at end of file +- chore: added e2e tests for governance plugin +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/governance/go.mod b/plugins/governance/go.mod index 51cf7dd5b..a862ae26e 100644 --- a/plugins/governance/go.mod +++ b/plugins/governance/go.mod @@ -7,7 +7,9 @@ require gorm.io/gorm v1.31.1 require ( github.com/maximhq/bifrost/core v1.2.42 github.com/maximhq/bifrost/framework v1.1.52 + github.com/bytedance/sonic v1.14.2 github.com/stretchr/testify v1.11.1 + github.com/valyala/fasthttp v1.68.0 ) require ( @@ -35,7 +37,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/gopkg v0.1.3 // indirect - github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect @@ -89,7 +90,6 @@ require ( github.com/spf13/cast v1.10.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.68.0 // indirect github.com/weaviate/weaviate v1.34.5 // indirect github.com/weaviate/weaviate-go-client/v5 v5.6.0 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect diff --git a/plugins/governance/go.sum b/plugins/governance/go.sum index c84f5eeaa..fdca0889e 100644 --- a/plugins/governance/go.sum +++ b/plugins/governance/go.sum @@ -166,10 +166,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.43 h1:NxtvzvLL0Isaf8mD1dlGb9JxT7/PZNPv5NBTnqHG100= -github.com/maximhq/bifrost/core v1.2.43/go.mod h1:1msCedjIgC8d9TNJyB1z7s+348vh2Bd0u66qAPgpZoA= -github.com/maximhq/bifrost/framework v1.1.53 h1:2mNTaO4TdUVM/Kye2hpdzCoFdzb9EeMDY8jSzD5zDYA= -github.com/maximhq/bifrost/framework v1.1.53/go.mod h1:tBFZc0wBZPeM+gu/j3XU/P1KhjaRK0/zjpPb+h4W9DQ= +github.com/maximhq/bifrost/core v1.2.42 h1:0G5TD4sZWlT8CwteobFpXmnALGzFQ6lsrDAQl8tr7/k= +github.com/maximhq/bifrost/framework v1.1.52 h1:n36FUjcnXoNQaVVYdkMcBMf6VnthloWaq1rFomdqVVA= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 4ff96dce0..5966e2c2d 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -10,11 +10,13 @@ import ( "strings" "sync" + "github.com/bytedance/sonic" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/modelcatalog" + "github.com/valyala/fasthttp" ) // PluginName is the name of the governance plugin @@ -228,63 +230,87 @@ func (p *GovernancePlugin) GetName() string { return PluginName } -// TransportInterceptor intercepts requests before they are processed (governance decision point) -// Parameters: -// - ctx: The Bifrost context -// - url: The URL of the request -// - headers: The request headers -// - body: The request body -// -// Returns: -// - map[string]string: The updated request headers -// - map[string]any: The updated request body -// - error: Any error that occurred during processing -func (p *GovernancePlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { +func parseVirtualKey(ctx *fasthttp.RequestCtx) *string { var virtualKeyValue string - var err error - - for header, value := range headers { - headerStr := strings.ToLower(header) - if headerStr == string(schemas.BifrostContextKeyVirtualKey) { - virtualKeyValue = string(value) - break - } - if headerStr == "authorization" { - valueStr := string(value) - // Only accept Bearer token format: "Bearer ..." - if strings.HasPrefix(strings.ToLower(valueStr), "bearer ") { - authHeaderValue := strings.TrimSpace(valueStr[7:]) // Remove "Bearer " prefix - if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), VirtualKeyPrefix) { - virtualKeyValue = authHeaderValue - break - } + vkHeader := ctx.Request.Header.Peek("x-bf-vk") + if string(vkHeader) != "" { + return bifrost.Ptr(string(vkHeader)) + } + authHeader := string(ctx.Request.Header.Peek("Authorization")) + if authHeader != "" { + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix + if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), VirtualKeyPrefix) { + virtualKeyValue = authHeaderValue } } - if (headerStr == "x-api-key" || headerStr == "x-goog-api-key") && strings.HasPrefix(strings.ToLower(string(value)), VirtualKeyPrefix) { - virtualKeyValue = string(value) - break - } } - if virtualKeyValue == "" { - return headers, body, nil + if virtualKeyValue != "" { + return bifrost.Ptr(virtualKeyValue) } - - virtualKey, ok := p.store.GetVirtualKey(virtualKeyValue) - if !ok || virtualKey == nil || !virtualKey.IsActive { - return headers, body, nil + xAPIKey := string(ctx.Request.Header.Peek("x-api-key")) + if xAPIKey != "" && strings.HasPrefix(strings.ToLower(xAPIKey), VirtualKeyPrefix) { + return bifrost.Ptr(xAPIKey) } - - body, err = p.loadBalanceProvider(body, virtualKey) - if err != nil { - return headers, body, err + // Checking x-goog-api-key header + xGoogleAPIKey := string(ctx.Request.Header.Peek("x-goog-api-key")) + if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) { + return bifrost.Ptr(xGoogleAPIKey) } + return nil +} - headers, err = p.addMCPIncludeTools(headers, virtualKey) - if err != nil { - return headers, body, err +// HTTPTransportMiddleware intercepts requests before they are processed (governance decision point) +func (p *GovernancePlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + virtualKeyValue := parseVirtualKey(ctx) + if virtualKeyValue == nil { + next(ctx) + return + } + // Get the virtual key from the store + virtualKey, ok := p.store.GetVirtualKey(*virtualKeyValue) + if !ok || virtualKey == nil || !virtualKey.IsActive { + next(ctx) + return + } + headers, err := p.addMCPIncludeTools(nil, virtualKey) + if err != nil { + p.logger.Error("failed to add MCP include tools: %v", err) + next(ctx) + return + } + for header, value := range headers { + ctx.Request.Header.Set(header, value) + } + if ctx.Request.Body() == nil { + next(ctx) + return + } + var payload map[string]any + err = sonic.Unmarshal(ctx.Request.Body(), &payload) + if err != nil { + p.logger.Error("failed to marshal request body to check for virtual key: %v", err) + next(ctx) + return + } + payload, err = p.loadBalanceProvider(payload, virtualKey) + if err != nil { + p.logger.Error("failed to load balance provider: %v", err) + next(ctx) + return + } + body, err := sonic.Marshal(payload) + if err != nil { + p.logger.Error("failed to marshal request body to check for virtual key: %v", err) + next(ctx) + return + } + ctx.Request.SetBody(body) + next(ctx) + } } - - return headers, body, nil } // loadBalanceProvider loads balances the provider for the request diff --git a/plugins/jsonparser/changelog.md b/plugins/jsonparser/changelog.md index e69de29bb..218e0a55a 100644 --- a/plugins/jsonparser/changelog.md +++ b/plugins/jsonparser/changelog.md @@ -0,0 +1,18 @@ +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/jsonparser/main.go b/plugins/jsonparser/main.go index 7281790e9..d4dbcb5c0 100644 --- a/plugins/jsonparser/main.go +++ b/plugins/jsonparser/main.go @@ -83,24 +83,16 @@ func (p *JsonParserPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -// Parameters: -// - ctx: The Bifrost context -// - url: The URL of the request -// - headers: The request headers -// - body: The request body -// Returns: -// - map[string]string: The updated request headers -// - map[string]any: The updated request body -// - error: Any error that occurred during processing -func (p *JsonParserPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportMiddleware is not used for this plugin +func (p *JsonParserPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return nil } // PreHook is not used for this plugin as we only process responses // Parameters: // - ctx: The Bifrost context // - req: The Bifrost request +// // Returns: // - *schemas.BifrostRequest: The processed request // - *schemas.PluginShortCircuit: The plugin short circuit if the request is not allowed @@ -114,6 +106,7 @@ func (p *JsonParserPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif // - ctx: The Bifrost context // - result: The Bifrost response to be processed // - err: The Bifrost error to be processed +// // Returns: // - *schemas.BifrostResponse: The processed response // - *schemas.BifrostError: The processed error diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index e69de29bb..b53788e12 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -0,0 +1,19 @@ +- feat: logging now uses central accumulator vs its own; reducing total memory consumption during runtime +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 1e52b34e8..f2133e7d9 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -105,10 +105,9 @@ type LoggerPlugin struct { logger schemas.Logger logCallback LogCallback droppedRequests atomic.Int64 - cleanupTicker *time.Ticker // Ticker for cleaning up old processing logs - logMsgPool sync.Pool // Pool for reusing LogMessage structs - updateDataPool sync.Pool // Pool for reusing UpdateLogData structs - accumulator *streaming.Accumulator // Accumulator for streaming chunks + cleanupTicker *time.Ticker // Ticker for cleaning up old processing logs + logMsgPool sync.Pool // Pool for reusing LogMessage structs + updateDataPool sync.Pool // Pool for reusing UpdateLogData structs } // Init creates new logger plugin with given log store @@ -140,7 +139,6 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, logsStore return &UpdateLogData{} }, }, - accumulator: streaming.NewAccumulator(pricingManager, logger), } // Prewarm the pools for better performance at startup @@ -193,19 +191,9 @@ func (p *LoggerPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -// Parameters: -// - ctx: The Bifrost context -// - url: The URL of the request -// - headers: The request headers -// - body: The request body -// -// Returns: -// - map[string]string: The updated request headers -// - map[string]any: The updated request body -// - error: Any error that occurred during processing -func (p *LoggerPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportMiddleware is not used for this plugin +func (p *LoggerPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return nil } // PreHook is called before a request is processed - FULLY ASYNC, NO DATABASE I/O @@ -234,9 +222,14 @@ func (p *LoggerPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bifrost createdTimestamp := time.Now().UTC() - // If request type is streaming we create a stream accumulator + // If request type is streaming we create a stream accumulator via the tracer if bifrost.IsStreamRequestType(req.RequestType) { - p.accumulator.CreateStreamAccumulator(requestID, createdTimestamp) + if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil { + // Use traceID for the central accumulator + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.CreateStreamAccumulator(traceID, createdTimestamp) + } + } } provider, model, _ := req.GetRequestFields() @@ -273,7 +266,7 @@ func (p *LoggerPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bifrost initialData.SpeechInput = req.SpeechRequest.Input case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: initialData.Params = req.TranscriptionRequest.Params - initialData.TranscriptionInput = req.TranscriptionRequest.Input + initialData.TranscriptionInput = req.TranscriptionRequest.Input } } @@ -388,10 +381,7 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif // If response is nil, and there is an error, we update log with error if result == nil && bifrostErr != nil { - // If request type is streaming, then we trigger cleanup as well - if bifrost.IsStreamRequestType(requestType) { - p.accumulator.CleanupStreamAccumulator(requestID) - } + // Note: Stream accumulator cleanup is handled by the tracing middleware logMsg.Operation = LogOperationUpdate logMsg.UpdateData = &UpdateLogData{ Status: "error", @@ -430,10 +420,22 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif if bifrost.IsStreamRequestType(requestType) { p.logger.Debug("[logging] processing streaming response") - streamResponse, err := p.accumulator.ProcessStreamingResponse(ctx, result, bifrostErr) - if err != nil { - p.logger.Debug("failed to process streaming response: %v", err) - } else if streamResponse != nil && streamResponse.Type == streaming.StreamResponseTypeFinal { + // Process streaming response via tracer's central accumulator + var streamResponse *streaming.ProcessedStreamResponse + if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + // Pass the context so the accumulator can detect final chunks via StreamEndIndicator + accResult := tracer.ProcessStreamingChunk(ctx, traceID, result, bifrostErr) + if accResult != nil { + // Convert StreamAccumulatorResult to ProcessedStreamResponse + streamResponse = convertToProcessedStreamResponse(accResult, requestType) + } + } + } + + if streamResponse == nil { + p.logger.Debug("failed to process streaming response: tracer or traceID not available") + } else if streamResponse.Type == streaming.StreamResponseTypeFinal { // Prepare final log data logMsg.Operation = LogOperationStreamUpdate logMsg.StreamResponse = streamResponse @@ -464,6 +466,7 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif } p.mu.Unlock() } + // Note: Stream accumulator cleanup is handled by the tracing middleware } } else { // Handle regular response @@ -618,7 +621,7 @@ func (p *LoggerPlugin) Cleanup() error { close(p.done) // Wait for the background worker to finish processing remaining items p.wg.Wait() - p.accumulator.Cleanup() + // Note: Accumulator cleanup is handled by the tracer, not the logging plugin // GORM handles connection cleanup automatically return nil } diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 93859bc7f..bc68f580a 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -10,6 +10,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/streaming" ) // KeyPair represents an ID-Name pair for keys @@ -224,3 +225,76 @@ func getIntFromContext(ctx context.Context, key any) int { } return 0 } + +// convertToProcessedStreamResponse converts a StreamAccumulatorResult to ProcessedStreamResponse +// for use with the logging plugin's streaming log update functionality. +func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, requestType schemas.RequestType) *streaming.ProcessedStreamResponse { + if result == nil { + return nil + } + + // Determine stream type from request type + var streamType streaming.StreamType + switch requestType { + case schemas.TextCompletionStreamRequest: + streamType = streaming.StreamTypeText + case schemas.ChatCompletionStreamRequest: + streamType = streaming.StreamTypeChat + case schemas.ResponsesStreamRequest: + streamType = streaming.StreamTypeResponses + case schemas.SpeechStreamRequest: + streamType = streaming.StreamTypeAudio + case schemas.TranscriptionStreamRequest: + streamType = streaming.StreamTypeTranscription + default: + streamType = streaming.StreamTypeChat + } + + // Determine response type + var responseType streaming.StreamResponseType + if result.IsFinal { + responseType = streaming.StreamResponseTypeFinal + } else { + responseType = streaming.StreamResponseTypeDelta + } + + // Build accumulated data + data := &streaming.AccumulatedData{ + RequestID: result.RequestID, + Model: result.Model, + Status: result.Status, + Stream: true, + Latency: result.Latency, + TimeToFirstToken: result.TimeToFirstToken, + OutputMessage: result.OutputMessage, + OutputMessages: result.OutputMessages, + ErrorDetails: result.ErrorDetails, + TokenUsage: result.TokenUsage, + Cost: result.Cost, + AudioOutput: result.AudioOutput, + TranscriptionOutput: result.TranscriptionOutput, + FinishReason: result.FinishReason, + RawResponse: result.RawResponse, + } + + // Handle tool calls if present + if result.OutputMessage != nil && result.OutputMessage.ChatAssistantMessage != nil { + data.ToolCalls = result.OutputMessage.ChatAssistantMessage.ToolCalls + } + + resp := &streaming.ProcessedStreamResponse{ + Type: responseType, + RequestID: result.RequestID, + StreamType: streamType, + Provider: result.Provider, + Model: result.Model, + Data: data, + } + + if result.RawRequest != nil { + rawReq := result.RawRequest + resp.RawRequest = &rawReq + } + + return resp +} diff --git a/plugins/maxim/changelog.md b/plugins/maxim/changelog.md index e69de29bb..218e0a55a 100644 --- a/plugins/maxim/changelog.md +++ b/plugins/maxim/changelog.md @@ -0,0 +1,18 @@ +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index c78a49a81..410ba7bba 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -121,9 +121,9 @@ func (plugin *Plugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (plugin *Plugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportMiddleware is not used for this plugin +func (plugin *Plugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return nil } // getEffectiveLogRepoID determines which single log repo ID to use based on priority: diff --git a/plugins/mocker/changelog.md b/plugins/mocker/changelog.md index e69de29bb..218e0a55a 100644 --- a/plugins/mocker/changelog.md +++ b/plugins/mocker/changelog.md @@ -0,0 +1,18 @@ +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go index d15dfacdb..32cfdb0a6 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ -478,9 +478,9 @@ func (p *MockerPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (p *MockerPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportMiddleware is not used for this plugin +func (p *MockerPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return nil } // PreHook intercepts requests and applies mocking rules based on configuration diff --git a/plugins/otel/changelog.md b/plugins/otel/changelog.md index e69de29bb..1e802fed8 100644 --- a/plugins/otel/changelog.md +++ b/plugins/otel/changelog.md @@ -0,0 +1,19 @@ +- feat: otel now uses central accumulator reducing the total amount of memory consumed in runtime +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/otel/converter.go b/plugins/otel/converter.go index d113b2bf5..9decd7c01 100644 --- a/plugins/otel/converter.go +++ b/plugins/otel/converter.go @@ -4,10 +4,8 @@ import ( "encoding/hex" "fmt" "strings" - "time" "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/modelcatalog" commonpb "go.opentelemetry.io/proto/otlp/common/v1" resourcepb "go.opentelemetry.io/proto/otlp/resource/v1" tracepb "go.opentelemetry.io/proto/otlp/trace/v1" @@ -71,1021 +69,217 @@ func hexToBytes(hexStr string, length int) []byte { return bytes } -// getSpeechRequestParams handles the speech request -func getSpeechRequestParams(req *schemas.BifrostSpeechRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.VoiceConfig != nil { - if req.Params.VoiceConfig.Voice != nil { - params = append(params, kvStr("gen_ai.request.voice", *req.Params.VoiceConfig.Voice)) - } - if len(req.Params.VoiceConfig.MultiVoiceConfig) > 0 { - multiVoiceConfigParams := []*KeyValue{} - for _, voiceConfig := range req.Params.VoiceConfig.MultiVoiceConfig { - multiVoiceConfigParams = append(multiVoiceConfigParams, kvStr("gen_ai.request.voice", voiceConfig.Voice)) - } - params = append(params, kvAny("gen_ai.request.multi_voice_config", arrValue(listValue(multiVoiceConfigParams...)))) - } - } - params = append(params, kvStr("gen_ai.request.instructions", req.Params.Instructions)) - params = append(params, kvStr("gen_ai.request.response_format", req.Params.ResponseFormat)) - if req.Params.Speed != nil { - params = append(params, kvDbl("gen_ai.request.speed", *req.Params.Speed)) - } - } - if req.Input != nil { - params = append(params, kvStr("gen_ai.input.speech", req.Input.Input)) - } - return params -} - -// getEmbeddingRequestParams handles the embedding request -func getEmbeddingRequestParams(req *schemas.BifrostEmbeddingRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.Dimensions != nil { - params = append(params, kvInt("gen_ai.request.dimensions", int64(*req.Params.Dimensions))) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - if req.Params.EncodingFormat != nil { - params = append(params, kvStr("gen_ai.request.encoding_format", *req.Params.EncodingFormat)) - } - } - if req.Input.Text != nil { - params = append(params, kvStr("gen_ai.input.text", *req.Input.Text)) - } - if req.Input.Texts != nil { - params = append(params, kvStr("gen_ai.input.text", strings.Join(req.Input.Texts, ","))) +// convertTraceToResourceSpan converts a Bifrost trace to OTEL ResourceSpan +func (p *OtelPlugin) convertTraceToResourceSpan(trace *schemas.Trace) *ResourceSpan { + otelSpans := make([]*Span, 0, len(trace.Spans)) + for _, span := range trace.Spans { + otelSpans = append(otelSpans, p.convertSpanToOTELSpan(trace.TraceID, span)) } - if req.Input.Embedding != nil { - embedding := make([]string, len(req.Input.Embedding)) - for i, v := range req.Input.Embedding { - embedding[i] = fmt.Sprintf("%d", v) - } - params = append(params, kvStr("gen_ai.input.embedding", strings.Join(embedding, ","))) - } - return params -} - -// getTextCompletionRequestParams handles the text completion request -func getTextCompletionRequestParams(req *schemas.BifrostTextCompletionRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.MaxTokens != nil { - params = append(params, kvInt("gen_ai.request.max_tokens", int64(*req.Params.MaxTokens))) - } - if req.Params.Temperature != nil { - params = append(params, kvDbl("gen_ai.request.temperature", *req.Params.Temperature)) - } - if req.Params.TopP != nil { - params = append(params, kvDbl("gen_ai.request.top_p", *req.Params.TopP)) - } - if req.Params.Stop != nil { - params = append(params, kvStr("gen_ai.request.stop_sequences", strings.Join(req.Params.Stop, ","))) - } - if req.Params.PresencePenalty != nil { - params = append(params, kvDbl("gen_ai.request.presence_penalty", *req.Params.PresencePenalty)) - } - if req.Params.FrequencyPenalty != nil { - params = append(params, kvDbl("gen_ai.request.frequency_penalty", *req.Params.FrequencyPenalty)) - } - if req.Params.BestOf != nil { - params = append(params, kvInt("gen_ai.request.best_of", int64(*req.Params.BestOf))) - } - if req.Params.Echo != nil { - params = append(params, kvBool("gen_ai.request.echo", *req.Params.Echo)) - } - if req.Params.LogitBias != nil { - params = append(params, kvStr("gen_ai.request.logit_bias", fmt.Sprintf("%v", req.Params.LogitBias))) - } - if req.Params.LogProbs != nil { - params = append(params, kvInt("gen_ai.request.logprobs", int64(*req.Params.LogProbs))) - } - if req.Params.N != nil { - params = append(params, kvInt("gen_ai.request.n", int64(*req.Params.N))) - } - if req.Params.Seed != nil { - params = append(params, kvInt("gen_ai.request.seed", int64(*req.Params.Seed))) - } - if req.Params.Suffix != nil { - params = append(params, kvStr("gen_ai.request.suffix", *req.Params.Suffix)) - } - if req.Params.User != nil { - params = append(params, kvStr("gen_ai.request.user", *req.Params.User)) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - } - if req.Input.PromptStr != nil { - params = append(params, kvStr("gen_ai.input.text", *req.Input.PromptStr)) - } - if req.Input.PromptArray != nil { - params = append(params, kvStr("gen_ai.input.text", strings.Join(req.Input.PromptArray, ","))) - } - return params -} - -// getChatRequestParams handles the chat completion request -func getChatRequestParams(req *schemas.BifrostChatRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.MaxCompletionTokens != nil { - params = append(params, kvInt("gen_ai.request.max_tokens", int64(*req.Params.MaxCompletionTokens))) - } - if req.Params.Temperature != nil { - params = append(params, kvDbl("gen_ai.request.temperature", *req.Params.Temperature)) - } - if req.Params.TopP != nil { - params = append(params, kvDbl("gen_ai.request.top_p", *req.Params.TopP)) - } - if req.Params.Stop != nil { - params = append(params, kvStr("gen_ai.request.stop_sequences", strings.Join(req.Params.Stop, ","))) - } - if req.Params.PresencePenalty != nil { - params = append(params, kvDbl("gen_ai.request.presence_penalty", *req.Params.PresencePenalty)) - } - if req.Params.FrequencyPenalty != nil { - params = append(params, kvDbl("gen_ai.request.frequency_penalty", *req.Params.FrequencyPenalty)) - } - if req.Params.ParallelToolCalls != nil { - params = append(params, kvBool("gen_ai.request.parallel_tool_calls", *req.Params.ParallelToolCalls)) - } - if req.Params.User != nil { - params = append(params, kvStr("gen_ai.request.user", *req.Params.User)) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - } - // Handling chat completion - if req.Input != nil { - messages := []*AnyValue{} - for _, message := range req.Input { - if message.Content == nil { - continue - } - switch message.Role { - case schemas.ChatMessageRoleUser: - kvs := []*KeyValue{kvStr("role", "user")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleAssistant: - kvs := []*KeyValue{kvStr("role", "assistant")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleSystem: - kvs := []*KeyValue{kvStr("role", "system")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleTool: - kvs := []*KeyValue{kvStr("role", "tool")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleDeveloper: - kvs := []*KeyValue{kvStr("role", "developer")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - } - } - params = append(params, kvAny("gen_ai.input.messages", arrValue(messages...))) - } - return params -} -// getTranscriptionRequestParams handles the transcription request -func getTranscriptionRequestParams(req *schemas.BifrostTranscriptionRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.Language != nil { - params = append(params, kvStr("gen_ai.request.language", *req.Params.Language)) - } - if req.Params.Prompt != nil { - params = append(params, kvStr("gen_ai.request.prompt", *req.Params.Prompt)) - } - if req.Params.ResponseFormat != nil { - params = append(params, kvStr("gen_ai.request.response_format", *req.Params.ResponseFormat)) - } - if req.Params.Format != nil { - params = append(params, kvStr("gen_ai.request.format", *req.Params.Format)) - } - } - return params -} - -// getResponsesRequestParams handles the responses request -func getResponsesRequestParams(req *schemas.BifrostResponsesRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.ParallelToolCalls != nil { - params = append(params, kvBool("gen_ai.request.parallel_tool_calls", *req.Params.ParallelToolCalls)) - } - if req.Params.PromptCacheKey != nil { - params = append(params, kvStr("gen_ai.request.prompt_cache_key", *req.Params.PromptCacheKey)) - } - if req.Params.Reasoning != nil { - if req.Params.Reasoning.Effort != nil { - params = append(params, kvStr("gen_ai.request.reasoning_effort", *req.Params.Reasoning.Effort)) - } - if req.Params.Reasoning.Summary != nil { - params = append(params, kvStr("gen_ai.request.reasoning_summary", *req.Params.Reasoning.Summary)) - } - if req.Params.Reasoning.GenerateSummary != nil { - params = append(params, kvStr("gen_ai.request.reasoning_generate_summary", *req.Params.Reasoning.GenerateSummary)) - } - } - if req.Params.SafetyIdentifier != nil { - params = append(params, kvStr("gen_ai.request.safety_identifier", *req.Params.SafetyIdentifier)) - } - if req.Params.ServiceTier != nil { - params = append(params, kvStr("gen_ai.request.service_tier", *req.Params.ServiceTier)) - } - if req.Params.Store != nil { - params = append(params, kvBool("gen_ai.request.store", *req.Params.Store)) - } - if req.Params.Temperature != nil { - params = append(params, kvDbl("gen_ai.request.temperature", *req.Params.Temperature)) - } - if req.Params.Text != nil { - if req.Params.Text.Verbosity != nil { - params = append(params, kvStr("gen_ai.request.text", *req.Params.Text.Verbosity)) - } - if req.Params.Text.Format != nil { - params = append(params, kvStr("gen_ai.request.text_format_type", req.Params.Text.Format.Type)) - } - - } - if req.Params.TopLogProbs != nil { - params = append(params, kvInt("gen_ai.request.top_logprobs", int64(*req.Params.TopLogProbs))) - } - if req.Params.TopP != nil { - params = append(params, kvDbl("gen_ai.request.top_p", *req.Params.TopP)) - } - if req.Params.ToolChoice != nil { - if req.Params.ToolChoice.ResponsesToolChoiceStr != nil && *req.Params.ToolChoice.ResponsesToolChoiceStr != "" { - params = append(params, kvStr("gen_ai.request.tool_choice_type", *req.Params.ToolChoice.ResponsesToolChoiceStr)) - } - if req.Params.ToolChoice.ResponsesToolChoiceStruct != nil && req.Params.ToolChoice.ResponsesToolChoiceStruct.Name != nil { - params = append(params, kvStr("gen_ai.request.tool_choice_name", *req.Params.ToolChoice.ResponsesToolChoiceStruct.Name)) - } - - } - if req.Params.Tools != nil { - tools := make([]string, len(req.Params.Tools)) - for i, tool := range req.Params.Tools { - tools[i] = string(tool.Type) - } - params = append(params, kvStr("gen_ai.request.tools", strings.Join(tools, ","))) - } - if req.Params.Truncation != nil { - params = append(params, kvStr("gen_ai.request.truncation", *req.Params.Truncation)) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - } - return params -} - -// getFileUploadRequestParams handles the file upload request -func getFileUploadRequestParams(req *schemas.BifrostFileUploadRequest) []*KeyValue { - params := []*KeyValue{} - if req.Filename != "" { - params = append(params, kvStr("gen_ai.file.filename", req.Filename)) - } - if req.Purpose != "" { - params = append(params, kvStr("gen_ai.file.purpose", string(req.Purpose))) - } - if len(req.File) > 0 { - params = append(params, kvInt("gen_ai.file.bytes", int64(len(req.File)))) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params -} - -// getFileListRequestParams handles the file list request -func getFileListRequestParams(req *schemas.BifrostFileListRequest) []*KeyValue { - params := []*KeyValue{} - if req.Purpose != "" { - params = append(params, kvStr("gen_ai.file.purpose", string(req.Purpose))) - } - if req.Limit > 0 { - params = append(params, kvInt("gen_ai.file.limit", int64(req.Limit))) - } - if req.After != nil { - params = append(params, kvStr("gen_ai.file.after", *req.After)) - } - if req.Order != nil { - params = append(params, kvStr("gen_ai.file.order", *req.Order)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } + return &ResourceSpan{ + Resource: &resourcepb.Resource{ + Attributes: p.getResourceAttributes(), + }, + ScopeSpans: []*ScopeSpan{{ + Scope: p.getInstrumentationScope(), + Spans: otelSpans, + }}, } - return params } -// getFileRetrieveRequestParams handles the file retrieve request -func getFileRetrieveRequestParams(req *schemas.BifrostFileRetrieveRequest) []*KeyValue { - params := []*KeyValue{} - if req.FileID != "" { - params = append(params, kvStr("gen_ai.file.file_id", req.FileID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } +// convertSpanToOTELSpan converts a single Bifrost span to OTEL format +func (p *OtelPlugin) convertSpanToOTELSpan(traceID string, span *schemas.Span) *Span { + otelSpan := &Span{ + TraceId: hexToBytes(traceID, 16), + SpanId: hexToBytes(span.SpanID, 8), + Name: span.Name, + Kind: convertSpanKind(span.Kind), + StartTimeUnixNano: uint64(span.StartTime.UnixNano()), + EndTimeUnixNano: uint64(span.EndTime.UnixNano()), + Attributes: convertAttributesToKeyValues(span.Attributes), + Status: convertSpanStatus(span.Status, span.StatusMsg), + Events: convertSpanEvents(span.Events), } - return params -} -// getFileDeleteRequestParams handles the file delete request -func getFileDeleteRequestParams(req *schemas.BifrostFileDeleteRequest) []*KeyValue { - params := []*KeyValue{} - if req.FileID != "" { - params = append(params, kvStr("gen_ai.file.file_id", req.FileID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } + // Set parent span ID if present + if span.ParentID != "" { + otelSpan.ParentSpanId = hexToBytes(span.ParentID, 8) } - return params -} -// getFileContentRequestParams handles the file content request -func getFileContentRequestParams(req *schemas.BifrostFileContentRequest) []*KeyValue { - params := []*KeyValue{} - if req.FileID != "" { - params = append(params, kvStr("gen_ai.file.file_id", req.FileID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params + return otelSpan } -// getBatchCreateRequestParams handles the batch create request -func getBatchCreateRequestParams(req *schemas.BifrostBatchCreateRequest) []*KeyValue { - params := []*KeyValue{} - if req.InputFileID != "" { - params = append(params, kvStr("gen_ai.batch.input_file_id", req.InputFileID)) - } - if req.Endpoint != "" { - params = append(params, kvStr("gen_ai.batch.endpoint", string(req.Endpoint))) - } - if req.CompletionWindow != "" { - params = append(params, kvStr("gen_ai.batch.completion_window", req.CompletionWindow)) - } - if len(req.Requests) > 0 { - params = append(params, kvInt("gen_ai.batch.requests_count", int64(len(req.Requests)))) - } - if len(req.Metadata) > 0 { - params = append(params, kvStr("gen_ai.batch.metadata", fmt.Sprintf("%v", req.Metadata))) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params +// getResourceAttributes returns the resource attributes for the OTEL span +func (p *OtelPlugin) getResourceAttributes() []*KeyValue { + attrs := []*KeyValue{ + kvStr("service.name", p.serviceName), + kvStr("service.version", p.bifrostVersion), + kvStr("telemetry.sdk.name", "bifrost"), + kvStr("telemetry.sdk.language", "go"), + } + // Add environment attributes + attrs = append(attrs, p.attributesFromEnvironment...) + return attrs } -// getBatchListRequestParams handles the batch list request -func getBatchListRequestParams(req *schemas.BifrostBatchListRequest) []*KeyValue { - params := []*KeyValue{} - if req.Limit > 0 { - params = append(params, kvInt("gen_ai.batch.limit", int64(req.Limit))) - } - if req.After != nil { - params = append(params, kvStr("gen_ai.batch.after", *req.After)) - } - if req.BeforeID != nil { - params = append(params, kvStr("gen_ai.batch.before_id", *req.BeforeID)) +// getInstrumentationScope returns the instrumentation scope for OTEL +func (p *OtelPlugin) getInstrumentationScope() *commonpb.InstrumentationScope { + return &commonpb.InstrumentationScope{ + Name: p.serviceName, + Version: p.bifrostVersion, } - if req.AfterID != nil { - params = append(params, kvStr("gen_ai.batch.after_id", *req.AfterID)) - } - if req.PageToken != nil { - params = append(params, kvStr("gen_ai.batch.page_token", *req.PageToken)) - } - if req.PageSize > 0 { - params = append(params, kvInt("gen_ai.batch.page_size", int64(req.PageSize))) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params } -// getBatchRetrieveRequestParams handles the batch retrieve request -func getBatchRetrieveRequestParams(req *schemas.BifrostBatchRetrieveRequest) []*KeyValue { - params := []*KeyValue{} - if req.BatchID != "" { - params = append(params, kvStr("gen_ai.batch.batch_id", req.BatchID)) +// convertAttributesToKeyValues converts map[string]any to OTEL KeyValue slice +func convertAttributesToKeyValues(attrs map[string]any) []*KeyValue { + if attrs == nil { + return nil } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) + kvs := make([]*KeyValue, 0, len(attrs)) + for k, v := range attrs { + kv := anyToKeyValue(k, v) + if kv != nil { + kvs = append(kvs, kv) } } - return params + return kvs } -// getBatchCancelRequestParams handles the batch cancel request -func getBatchCancelRequestParams(req *schemas.BifrostBatchCancelRequest) []*KeyValue { - params := []*KeyValue{} - if req.BatchID != "" { - params = append(params, kvStr("gen_ai.batch.batch_id", req.BatchID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } +// anyToKeyValue converts any Go value to OTEL KeyValue +func anyToKeyValue(key string, value any) *KeyValue { + if value == nil { + return nil + } + switch v := value.(type) { + case string: + if v == "" { + return nil + } + return kvStr(key, v) + case int: + return kvInt(key, int64(v)) + case int32: + return kvInt(key, int64(v)) + case int64: + return kvInt(key, v) + case uint: + return kvInt(key, int64(v)) + case uint32: + return kvInt(key, int64(v)) + case uint64: + return kvInt(key, int64(v)) + case float32: + return kvDbl(key, float64(v)) + case float64: + return kvDbl(key, v) + case bool: + return kvBool(key, v) + case []string: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, s := range v { + vals[i] = &AnyValue{Value: &StringValue{StringValue: s}} + } + return kvAny(key, arrValue(vals...)) + case []int: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, n := range v { + vals[i] = &AnyValue{Value: &IntValue{IntValue: int64(n)}} + } + return kvAny(key, arrValue(vals...)) + case []int64: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, n := range v { + vals[i] = &AnyValue{Value: &IntValue{IntValue: n}} + } + return kvAny(key, arrValue(vals...)) + case []float64: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, n := range v { + vals[i] = &AnyValue{Value: &DoubleValue{DoubleValue: n}} + } + return kvAny(key, arrValue(vals...)) + case map[string]any: + if len(v) == 0 { + return nil + } + kvList := make([]*KeyValue, 0, len(v)) + for k, val := range v { + kv := anyToKeyValue(k, val) + if kv != nil { + kvList = append(kvList, kv) + } + } + return kvAny(key, listValue(kvList...)) + default: + // For any other type, convert to string + return kvStr(key, fmt.Sprintf("%v", v)) } - return params } -// getBatchResultsRequestParams handles the batch results request -func getBatchResultsRequestParams(req *schemas.BifrostBatchResultsRequest) []*KeyValue { - params := []*KeyValue{} - if req.BatchID != "" { - params = append(params, kvStr("gen_ai.batch.batch_id", req.BatchID)) +// convertSpanKind maps Bifrost SpanKind to OTEL SpanKind +func convertSpanKind(kind schemas.SpanKind) tracepb.Span_SpanKind { + switch kind { + case schemas.SpanKindLLMCall: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindHTTPRequest: + return tracepb.Span_SPAN_KIND_SERVER + case schemas.SpanKindPlugin: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindInternal: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindRetry: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindFallback: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindMCPTool: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindEmbedding: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindSpeech: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindTranscription: + return tracepb.Span_SPAN_KIND_CLIENT + default: + return tracepb.Span_SPAN_KIND_UNSPECIFIED } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params } -// createResourceSpan creates a new resource span for a Bifrost request -func (p *OtelPlugin) createResourceSpan(traceID, spanID string, timestamp time.Time, req *schemas.BifrostRequest) *ResourceSpan { - provider, model, _ := req.GetRequestFields() - - // preparing parameters - params := []*KeyValue{} - spanName := "span" - params = append(params, kvStr("gen_ai.provider.name", string(provider))) - params = append(params, kvStr("gen_ai.request.model", model)) - // Preparing parameters - switch req.RequestType { - case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: - spanName = "gen_ai.text" - params = append(params, getTextCompletionRequestParams(req.TextCompletionRequest)...) - case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: - spanName = "gen_ai.chat" - params = append(params, getChatRequestParams(req.ChatRequest)...) - case schemas.EmbeddingRequest: - spanName = "gen_ai.embedding" - params = append(params, getEmbeddingRequestParams(req.EmbeddingRequest)...) - case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: - spanName = "gen_ai.transcription" - params = append(params, getTranscriptionRequestParams(req.TranscriptionRequest)...) - case schemas.SpeechRequest, schemas.SpeechStreamRequest: - spanName = "gen_ai.speech" - params = append(params, getSpeechRequestParams(req.SpeechRequest)...) - case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: - spanName = "gen_ai.responses" - params = append(params, getResponsesRequestParams(req.ResponsesRequest)...) - case schemas.BatchCreateRequest: - spanName = "gen_ai.batch.create" - params = append(params, getBatchCreateRequestParams(req.BatchCreateRequest)...) - case schemas.BatchListRequest: - spanName = "gen_ai.batch.list" - params = append(params, getBatchListRequestParams(req.BatchListRequest)...) - case schemas.BatchRetrieveRequest: - spanName = "gen_ai.batch.retrieve" - params = append(params, getBatchRetrieveRequestParams(req.BatchRetrieveRequest)...) - case schemas.BatchCancelRequest: - spanName = "gen_ai.batch.cancel" - params = append(params, getBatchCancelRequestParams(req.BatchCancelRequest)...) - case schemas.BatchResultsRequest: - spanName = "gen_ai.batch.results" - params = append(params, getBatchResultsRequestParams(req.BatchResultsRequest)...) - case schemas.FileUploadRequest: - spanName = "gen_ai.file.upload" - params = append(params, getFileUploadRequestParams(req.FileUploadRequest)...) - case schemas.FileListRequest: - spanName = "gen_ai.file.list" - params = append(params, getFileListRequestParams(req.FileListRequest)...) - case schemas.FileRetrieveRequest: - spanName = "gen_ai.file.retrieve" - params = append(params, getFileRetrieveRequestParams(req.FileRetrieveRequest)...) - case schemas.FileDeleteRequest: - spanName = "gen_ai.file.delete" - params = append(params, getFileDeleteRequestParams(req.FileDeleteRequest)...) - case schemas.FileContentRequest: - spanName = "gen_ai.file.content" - params = append(params, getFileContentRequestParams(req.FileContentRequest)...) - } - attributes := append(p.attributesFromEnvironment, kvStr("service.name", p.serviceName), kvStr("service.version", p.bifrostVersion)) - // Preparing final resource span - return &ResourceSpan{ - Resource: &resourcepb.Resource{ - Attributes: attributes, - }, - ScopeSpans: []*ScopeSpan{ - { - Scope: &commonpb.InstrumentationScope{ - Name: "bifrost-otel-plugin", - }, - Spans: []*Span{ - { - TraceId: hexToBytes(traceID, 16), - SpanId: hexToBytes(spanID, 8), - Kind: tracepb.Span_SPAN_KIND_SERVER, - StartTimeUnixNano: uint64(timestamp.UnixNano()), - EndTimeUnixNano: uint64(timestamp.UnixNano()), - Name: spanName, - Attributes: params, - }, - }, - }, - }, +// convertSpanStatus maps Bifrost SpanStatus to OTEL Status +func convertSpanStatus(status schemas.SpanStatus, msg string) *tracepb.Status { + switch status { + case schemas.SpanStatusOk: + return &tracepb.Status{Code: tracepb.Status_STATUS_CODE_OK} + case schemas.SpanStatusError: + return &tracepb.Status{Code: tracepb.Status_STATUS_CODE_ERROR, Message: msg} + default: + return &tracepb.Status{Code: tracepb.Status_STATUS_CODE_UNSET} } } -// completeResourceSpan completes a resource span for a Bifrost response -func completeResourceSpan( - span *ResourceSpan, - timestamp time.Time, - resp *schemas.BifrostResponse, - bifrostErr *schemas.BifrostError, - pricingManager *modelcatalog.ModelCatalog, - virtualKeyID string, - virtualKeyName string, - selectedKeyID string, - selectedKeyName string, - numberOfRetries int, - fallbackIndex int, - teamID string, - teamName string, - customerID string, - customerName string, -) *ResourceSpan { - params := []*KeyValue{} - - if resp != nil { - switch { // Accumulator wont return stream type responses - case resp.TextCompletionResponse != nil: - params = append(params, kvStr("gen_ai.text.id", resp.TextCompletionResponse.ID)) - params = append(params, kvStr("gen_ai.text.model", resp.TextCompletionResponse.Model)) - params = append(params, kvStr("gen_ai.text.object", resp.TextCompletionResponse.Object)) - params = append(params, kvStr("gen_ai.text.system_fingerprint", resp.TextCompletionResponse.SystemFingerprint)) - outputMessages := []*AnyValue{} - for _, choice := range resp.TextCompletionResponse.Choices { - if choice.TextCompletionResponseChoice == nil { - continue - } - kvs := []*KeyValue{kvStr("role", string(schemas.ChatMessageRoleAssistant))} - if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { - kvs = append(kvs, kvStr("content", *choice.TextCompletionResponseChoice.Text)) - } - outputMessages = append(outputMessages, listValue(kvs...)) - } - params = append(params, kvAny("gen_ai.text.output_messages", arrValue(outputMessages...))) - if resp.TextCompletionResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.prompt_tokens", int64(resp.TextCompletionResponse.Usage.PromptTokens))) - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(resp.TextCompletionResponse.Usage.CompletionTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.TextCompletionResponse.Usage.TotalTokens))) - } - // Computing cost - if pricingManager != nil { - cost := pricingManager.CalculateCostWithCacheDebug(resp) - params = append(params, kvDbl("gen_ai.usage.cost", cost)) - } - case resp.ChatResponse != nil: - params = append(params, kvStr("gen_ai.chat.id", resp.ChatResponse.ID)) - params = append(params, kvStr("gen_ai.chat.model", resp.ChatResponse.Model)) - params = append(params, kvStr("gen_ai.chat.object", resp.ChatResponse.Object)) - params = append(params, kvStr("gen_ai.chat.system_fingerprint", resp.ChatResponse.SystemFingerprint)) - params = append(params, kvStr("gen_ai.chat.created", fmt.Sprintf("%d", resp.ChatResponse.Created))) - if resp.ChatResponse.ServiceTier != nil { - params = append(params, kvStr("gen_ai.chat.service_tier", *resp.ChatResponse.ServiceTier)) - } - outputMessages := []*AnyValue{} - for _, choice := range resp.ChatResponse.Choices { - var role string - if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil && choice.ChatNonStreamResponseChoice.Message.Role != "" { - role = string(choice.ChatNonStreamResponseChoice.Message.Role) - } else { - role = string(schemas.ChatMessageRoleAssistant) - } - kvs := []*KeyValue{kvStr("role", role)} - - if choice.ChatNonStreamResponseChoice != nil && - choice.ChatNonStreamResponseChoice.Message != nil && - choice.ChatNonStreamResponseChoice.Message.Content != nil { - if choice.ChatNonStreamResponseChoice.Message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *choice.ChatNonStreamResponseChoice.Message.Content.ContentStr)) - } else if choice.ChatNonStreamResponseChoice.Message.Content.ContentBlocks != nil { - blockText := "" - for _, block := range choice.ChatNonStreamResponseChoice.Message.Content.ContentBlocks { - if block.Text != nil { - blockText += *block.Text - } - } - kvs = append(kvs, kvStr("content", blockText)) - } - } - outputMessages = append(outputMessages, listValue(kvs...)) - } - params = append(params, kvAny("gen_ai.chat.output_messages", arrValue(outputMessages...))) - if resp.ChatResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.prompt_tokens", int64(resp.ChatResponse.Usage.PromptTokens))) - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(resp.ChatResponse.Usage.CompletionTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.ChatResponse.Usage.TotalTokens))) - } - // Computing cost - if pricingManager != nil { - cost := pricingManager.CalculateCostWithCacheDebug(resp) - params = append(params, kvDbl("gen_ai.usage.cost", cost)) - } - case resp.ResponsesResponse != nil: - outputMessages := []*AnyValue{} - for _, message := range resp.ResponsesResponse.Output { - if message.Role == nil { - continue - } - kvs := []*KeyValue{kvStr("role", string(*message.Role))} - if message.Content != nil { - if message.Content.ContentStr != nil && *message.Content.ContentStr != "" { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } else if message.Content.ContentBlocks != nil { - blockText := "" - for _, block := range message.Content.ContentBlocks { - if block.Text != nil { - blockText += *block.Text - } - } - kvs = append(kvs, kvStr("content", blockText)) - } - } - if message.ResponsesReasoning != nil && message.ResponsesReasoning.Summary != nil { - reasoningText := "" - for _, block := range message.ResponsesReasoning.Summary { - if block.Text != "" { - reasoningText += block.Text - } - } - kvs = append(kvs, kvStr("reasoning", reasoningText)) - } - outputMessages = append(outputMessages, listValue(kvs...)) - - } - params = append(params, kvAny("gen_ai.responses.output_messages", arrValue(outputMessages...))) - - responsesResponse := resp.ResponsesResponse - if responsesResponse.Include != nil { - params = append(params, kvStr("gen_ai.responses.include", strings.Join(responsesResponse.Include, ","))) - } - if responsesResponse.MaxOutputTokens != nil { - params = append(params, kvInt("gen_ai.responses.max_output_tokens", int64(*responsesResponse.MaxOutputTokens))) - } - if responsesResponse.MaxToolCalls != nil { - params = append(params, kvInt("gen_ai.responses.max_tool_calls", int64(*responsesResponse.MaxToolCalls))) - } - if responsesResponse.Metadata != nil { - params = append(params, kvStr("gen_ai.responses.metadata", fmt.Sprintf("%v", responsesResponse.Metadata))) - } - if responsesResponse.PreviousResponseID != nil { - params = append(params, kvStr("gen_ai.responses.previous_response_id", *responsesResponse.PreviousResponseID)) - } - if responsesResponse.PromptCacheKey != nil { - params = append(params, kvStr("gen_ai.responses.prompt_cache_key", *responsesResponse.PromptCacheKey)) - } - if responsesResponse.Reasoning != nil { - if responsesResponse.Reasoning.Summary != nil { - params = append(params, kvStr("gen_ai.responses.reasoning", *responsesResponse.Reasoning.Summary)) - } - if responsesResponse.Reasoning.Effort != nil { - params = append(params, kvStr("gen_ai.responses.reasoning_effort", *responsesResponse.Reasoning.Effort)) - } - if responsesResponse.Reasoning.GenerateSummary != nil { - params = append(params, kvStr("gen_ai.responses.reasoning_generate_summary", *responsesResponse.Reasoning.GenerateSummary)) - } - } - if responsesResponse.SafetyIdentifier != nil { - params = append(params, kvStr("gen_ai.responses.safety_identifier", *responsesResponse.SafetyIdentifier)) - } - if responsesResponse.ServiceTier != nil { - params = append(params, kvStr("gen_ai.responses.service_tier", *responsesResponse.ServiceTier)) - } - if responsesResponse.Store != nil { - params = append(params, kvBool("gen_ai.responses.store", *responsesResponse.Store)) - } - if responsesResponse.Temperature != nil { - params = append(params, kvDbl("gen_ai.responses.temperature", *responsesResponse.Temperature)) - } - if responsesResponse.Text != nil { - if responsesResponse.Text.Verbosity != nil { - params = append(params, kvStr("gen_ai.responses.text", *responsesResponse.Text.Verbosity)) - } - if responsesResponse.Text.Format != nil { - params = append(params, kvStr("gen_ai.responses.text_format_type", responsesResponse.Text.Format.Type)) - } - } - if responsesResponse.TopLogProbs != nil { - params = append(params, kvInt("gen_ai.responses.top_logprobs", int64(*responsesResponse.TopLogProbs))) - } - if responsesResponse.TopP != nil { - params = append(params, kvDbl("gen_ai.responses.top_p", *responsesResponse.TopP)) - } - if responsesResponse.ToolChoice != nil { - if responsesResponse.ToolChoice.ResponsesToolChoiceStruct != nil && responsesResponse.ToolChoice.ResponsesToolChoiceStr != nil { - params = append(params, kvStr("gen_ai.responses.tool_choice_type", *responsesResponse.ToolChoice.ResponsesToolChoiceStr)) - } - if responsesResponse.ToolChoice.ResponsesToolChoiceStruct != nil && responsesResponse.ToolChoice.ResponsesToolChoiceStruct.Name != nil { - params = append(params, kvStr("gen_ai.responses.tool_choice_name", *responsesResponse.ToolChoice.ResponsesToolChoiceStruct.Name)) - } - } - if responsesResponse.Truncation != nil { - params = append(params, kvStr("gen_ai.responses.truncation", *responsesResponse.Truncation)) - } - if responsesResponse.Tools != nil { - tools := make([]string, len(responsesResponse.Tools)) - for i, tool := range responsesResponse.Tools { - tools[i] = string(tool.Type) - } - params = append(params, kvStr("gen_ai.responses.tools", strings.Join(tools, ","))) - } - case resp.EmbeddingResponse != nil: - if resp.EmbeddingResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.prompt_tokens", int64(resp.EmbeddingResponse.Usage.PromptTokens))) - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(resp.EmbeddingResponse.Usage.CompletionTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.EmbeddingResponse.Usage.TotalTokens))) - } - case resp.SpeechResponse != nil: - if resp.SpeechResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.input_tokens", int64(resp.SpeechResponse.Usage.InputTokens))) - params = append(params, kvInt("gen_ai.usage.output_tokens", int64(resp.SpeechResponse.Usage.OutputTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.SpeechResponse.Usage.TotalTokens))) - } - case resp.TranscriptionResponse != nil: - outputMessages := []*AnyValue{} - kvs := []*KeyValue{kvStr("text", resp.TranscriptionResponse.Text)} - outputMessages = append(outputMessages, listValue(kvs...)) - params = append(params, kvAny("gen_ai.transcribe.output_messages", arrValue(outputMessages...))) - if resp.TranscriptionResponse.Usage != nil { - if resp.TranscriptionResponse.Usage.InputTokens != nil { - params = append(params, kvInt("gen_ai.usage.input_tokens", int64(*resp.TranscriptionResponse.Usage.InputTokens))) - } - if resp.TranscriptionResponse.Usage.OutputTokens != nil { - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(*resp.TranscriptionResponse.Usage.OutputTokens))) - } - if resp.TranscriptionResponse.Usage.TotalTokens != nil { - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(*resp.TranscriptionResponse.Usage.TotalTokens))) - } - if resp.TranscriptionResponse.Usage.InputTokenDetails != nil { - params = append(params, kvInt("gen_ai.usage.input_token_details.text_tokens", int64(resp.TranscriptionResponse.Usage.InputTokenDetails.TextTokens))) - params = append(params, kvInt("gen_ai.usage.input_token_details.audio_tokens", int64(resp.TranscriptionResponse.Usage.InputTokenDetails.AudioTokens))) - } - } - case resp.BatchCreateResponse != nil: - params = append(params, kvStr("gen_ai.batch.id", resp.BatchCreateResponse.ID)) - params = append(params, kvStr("gen_ai.batch.status", string(resp.BatchCreateResponse.Status))) - if resp.BatchCreateResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchCreateResponse.Object)) - } - if resp.BatchCreateResponse.Endpoint != "" { - params = append(params, kvStr("gen_ai.batch.endpoint", resp.BatchCreateResponse.Endpoint)) - } - if resp.BatchCreateResponse.InputFileID != "" { - params = append(params, kvStr("gen_ai.batch.input_file_id", resp.BatchCreateResponse.InputFileID)) - } - if resp.BatchCreateResponse.CompletionWindow != "" { - params = append(params, kvStr("gen_ai.batch.completion_window", resp.BatchCreateResponse.CompletionWindow)) - } - if resp.BatchCreateResponse.CreatedAt != 0 { - params = append(params, kvInt("gen_ai.batch.created_at", resp.BatchCreateResponse.CreatedAt)) - } - if resp.BatchCreateResponse.ExpiresAt != nil { - params = append(params, kvInt("gen_ai.batch.expires_at", *resp.BatchCreateResponse.ExpiresAt)) - } - if resp.BatchCreateResponse.OutputFileID != nil { - params = append(params, kvStr("gen_ai.batch.output_file_id", *resp.BatchCreateResponse.OutputFileID)) - } - if resp.BatchCreateResponse.ErrorFileID != nil { - params = append(params, kvStr("gen_ai.batch.error_file_id", *resp.BatchCreateResponse.ErrorFileID)) - } - params = append(params, kvInt("gen_ai.batch.request_counts.total", int64(resp.BatchCreateResponse.RequestCounts.Total))) - params = append(params, kvInt("gen_ai.batch.request_counts.completed", int64(resp.BatchCreateResponse.RequestCounts.Completed))) - params = append(params, kvInt("gen_ai.batch.request_counts.failed", int64(resp.BatchCreateResponse.RequestCounts.Failed))) - case resp.BatchListResponse != nil: - if resp.BatchListResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchListResponse.Object)) - } - params = append(params, kvInt("gen_ai.batch.data_count", int64(len(resp.BatchListResponse.Data)))) - params = append(params, kvBool("gen_ai.batch.has_more", resp.BatchListResponse.HasMore)) - if resp.BatchListResponse.FirstID != nil { - params = append(params, kvStr("gen_ai.batch.first_id", *resp.BatchListResponse.FirstID)) - } - if resp.BatchListResponse.LastID != nil { - params = append(params, kvStr("gen_ai.batch.last_id", *resp.BatchListResponse.LastID)) - } - case resp.BatchRetrieveResponse != nil: - params = append(params, kvStr("gen_ai.batch.id", resp.BatchRetrieveResponse.ID)) - params = append(params, kvStr("gen_ai.batch.status", string(resp.BatchRetrieveResponse.Status))) - if resp.BatchRetrieveResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchRetrieveResponse.Object)) - } - if resp.BatchRetrieveResponse.Endpoint != "" { - params = append(params, kvStr("gen_ai.batch.endpoint", resp.BatchRetrieveResponse.Endpoint)) - } - if resp.BatchRetrieveResponse.InputFileID != "" { - params = append(params, kvStr("gen_ai.batch.input_file_id", resp.BatchRetrieveResponse.InputFileID)) - } - if resp.BatchRetrieveResponse.CompletionWindow != "" { - params = append(params, kvStr("gen_ai.batch.completion_window", resp.BatchRetrieveResponse.CompletionWindow)) - } - if resp.BatchRetrieveResponse.CreatedAt != 0 { - params = append(params, kvInt("gen_ai.batch.created_at", resp.BatchRetrieveResponse.CreatedAt)) - } - if resp.BatchRetrieveResponse.ExpiresAt != nil { - params = append(params, kvInt("gen_ai.batch.expires_at", *resp.BatchRetrieveResponse.ExpiresAt)) - } - if resp.BatchRetrieveResponse.InProgressAt != nil { - params = append(params, kvInt("gen_ai.batch.in_progress_at", *resp.BatchRetrieveResponse.InProgressAt)) - } - if resp.BatchRetrieveResponse.FinalizingAt != nil { - params = append(params, kvInt("gen_ai.batch.finalizing_at", *resp.BatchRetrieveResponse.FinalizingAt)) - } - if resp.BatchRetrieveResponse.CompletedAt != nil { - params = append(params, kvInt("gen_ai.batch.completed_at", *resp.BatchRetrieveResponse.CompletedAt)) - } - if resp.BatchRetrieveResponse.FailedAt != nil { - params = append(params, kvInt("gen_ai.batch.failed_at", *resp.BatchRetrieveResponse.FailedAt)) - } - if resp.BatchRetrieveResponse.ExpiredAt != nil { - params = append(params, kvInt("gen_ai.batch.expired_at", *resp.BatchRetrieveResponse.ExpiredAt)) - } - if resp.BatchRetrieveResponse.CancellingAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelling_at", *resp.BatchRetrieveResponse.CancellingAt)) - } - if resp.BatchRetrieveResponse.CancelledAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelled_at", *resp.BatchRetrieveResponse.CancelledAt)) - } - if resp.BatchRetrieveResponse.OutputFileID != nil { - params = append(params, kvStr("gen_ai.batch.output_file_id", *resp.BatchRetrieveResponse.OutputFileID)) - } - if resp.BatchRetrieveResponse.ErrorFileID != nil { - params = append(params, kvStr("gen_ai.batch.error_file_id", *resp.BatchRetrieveResponse.ErrorFileID)) - } - params = append(params, kvInt("gen_ai.batch.request_counts.total", int64(resp.BatchRetrieveResponse.RequestCounts.Total))) - params = append(params, kvInt("gen_ai.batch.request_counts.completed", int64(resp.BatchRetrieveResponse.RequestCounts.Completed))) - params = append(params, kvInt("gen_ai.batch.request_counts.failed", int64(resp.BatchRetrieveResponse.RequestCounts.Failed))) - case resp.BatchCancelResponse != nil: - params = append(params, kvStr("gen_ai.batch.id", resp.BatchCancelResponse.ID)) - params = append(params, kvStr("gen_ai.batch.status", string(resp.BatchCancelResponse.Status))) - if resp.BatchCancelResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchCancelResponse.Object)) - } - if resp.BatchCancelResponse.CancellingAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelling_at", *resp.BatchCancelResponse.CancellingAt)) - } - if resp.BatchCancelResponse.CancelledAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelled_at", *resp.BatchCancelResponse.CancelledAt)) - } - params = append(params, kvInt("gen_ai.batch.request_counts.total", int64(resp.BatchCancelResponse.RequestCounts.Total))) - params = append(params, kvInt("gen_ai.batch.request_counts.completed", int64(resp.BatchCancelResponse.RequestCounts.Completed))) - params = append(params, kvInt("gen_ai.batch.request_counts.failed", int64(resp.BatchCancelResponse.RequestCounts.Failed))) - case resp.BatchResultsResponse != nil: - params = append(params, kvStr("gen_ai.batch.batch_id", resp.BatchResultsResponse.BatchID)) - params = append(params, kvInt("gen_ai.batch.results_count", int64(len(resp.BatchResultsResponse.Results)))) - params = append(params, kvBool("gen_ai.batch.has_more", resp.BatchResultsResponse.HasMore)) - if resp.BatchResultsResponse.NextCursor != nil { - params = append(params, kvStr("gen_ai.batch.next_cursor", *resp.BatchResultsResponse.NextCursor)) - } - case resp.FileUploadResponse != nil: - params = append(params, kvStr("gen_ai.file.id", resp.FileUploadResponse.ID)) - if resp.FileUploadResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileUploadResponse.Object)) - } - params = append(params, kvInt("gen_ai.file.bytes", resp.FileUploadResponse.Bytes)) - params = append(params, kvInt("gen_ai.file.created_at", resp.FileUploadResponse.CreatedAt)) - params = append(params, kvStr("gen_ai.file.filename", resp.FileUploadResponse.Filename)) - params = append(params, kvStr("gen_ai.file.purpose", string(resp.FileUploadResponse.Purpose))) - if resp.FileUploadResponse.Status != "" { - params = append(params, kvStr("gen_ai.file.status", string(resp.FileUploadResponse.Status))) - } - if resp.FileUploadResponse.StorageBackend != "" { - params = append(params, kvStr("gen_ai.file.storage_backend", string(resp.FileUploadResponse.StorageBackend))) - } - case resp.FileListResponse != nil: - if resp.FileListResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileListResponse.Object)) - } - params = append(params, kvInt("gen_ai.file.data_count", int64(len(resp.FileListResponse.Data)))) - params = append(params, kvBool("gen_ai.file.has_more", resp.FileListResponse.HasMore)) - case resp.FileRetrieveResponse != nil: - params = append(params, kvStr("gen_ai.file.id", resp.FileRetrieveResponse.ID)) - if resp.FileRetrieveResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileRetrieveResponse.Object)) - } - params = append(params, kvInt("gen_ai.file.bytes", resp.FileRetrieveResponse.Bytes)) - params = append(params, kvInt("gen_ai.file.created_at", resp.FileRetrieveResponse.CreatedAt)) - params = append(params, kvStr("gen_ai.file.filename", resp.FileRetrieveResponse.Filename)) - params = append(params, kvStr("gen_ai.file.purpose", string(resp.FileRetrieveResponse.Purpose))) - if resp.FileRetrieveResponse.Status != "" { - params = append(params, kvStr("gen_ai.file.status", string(resp.FileRetrieveResponse.Status))) - } - if resp.FileRetrieveResponse.StorageBackend != "" { - params = append(params, kvStr("gen_ai.file.storage_backend", string(resp.FileRetrieveResponse.StorageBackend))) - } - case resp.FileDeleteResponse != nil: - params = append(params, kvStr("gen_ai.file.id", resp.FileDeleteResponse.ID)) - if resp.FileDeleteResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileDeleteResponse.Object)) - } - params = append(params, kvBool("gen_ai.file.deleted", resp.FileDeleteResponse.Deleted)) - case resp.FileContentResponse != nil: - params = append(params, kvStr("gen_ai.file.file_id", resp.FileContentResponse.FileID)) - if resp.FileContentResponse.ContentType != "" { - params = append(params, kvStr("gen_ai.file.content_type", resp.FileContentResponse.ContentType)) - } - if len(resp.FileContentResponse.Content) > 0 { - params = append(params, kvInt("gen_ai.file.content_bytes", int64(len(resp.FileContentResponse.Content)))) - } - } +// convertSpanEvents converts Bifrost span events to OTEL events +func convertSpanEvents(events []schemas.SpanEvent) []*Event { + if len(events) == 0 { + return nil } - - // This is a fallback for worst case scenario where latency is not available - status := tracepb.Status_STATUS_CODE_OK - if bifrostErr != nil { - status = tracepb.Status_STATUS_CODE_ERROR - if bifrostErr.Error != nil { - if bifrostErr.Error.Type != nil { - params = append(params, kvStr("gen_ai.error.type", *bifrostErr.Error.Type)) - } - if bifrostErr.Error.Code != nil { - params = append(params, kvStr("gen_ai.error.code", *bifrostErr.Error.Code)) - } + otelEvents := make([]*Event, len(events)) + for i, event := range events { + otelEvents[i] = &Event{ + TimeUnixNano: uint64(event.Timestamp.UnixNano()), + Name: event.Name, + Attributes: convertAttributesToKeyValues(event.Attributes), } - params = append(params, kvStr("gen_ai.error", bifrostErr.Error.Message)) - } - // Adding request metadata to the span for backward compatibility - if virtualKeyID != "" { - params = append(params, kvStr("gen_ai.virtual_key_id", virtualKeyID)) - params = append(params, kvStr("gen_ai.virtual_key_name", virtualKeyName)) - } - if selectedKeyID != "" { - params = append(params, kvStr("gen_ai.selected_key_id", selectedKeyID)) - params = append(params, kvStr("gen_ai.selected_key_name", selectedKeyName)) - } - if teamID != "" { - params = append(params, kvStr("gen_ai.team_id", teamID)) - params = append(params, kvStr("gen_ai.team_name", teamName)) - } - if customerID != "" { - params = append(params, kvStr("gen_ai.customer_id", customerID)) - params = append(params, kvStr("gen_ai.customer_name", customerName)) } - params = append(params, kvInt("gen_ai.number_of_retries", int64(numberOfRetries))) - params = append(params, kvInt("gen_ai.fallback_index", int64(fallbackIndex))) - span.ScopeSpans[0].Spans[0].Attributes = append(span.ScopeSpans[0].Spans[0].Attributes, params...) - span.ScopeSpans[0].Spans[0].Status = &tracepb.Status{Code: status} - span.ScopeSpans[0].Spans[0].EndTimeUnixNano = uint64(timestamp.UnixNano()) - // Attaching virtual keys as resource attributes as well - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("virtual_key_id", virtualKeyID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("virtual_key_name", virtualKeyName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("selected_key_id", selectedKeyID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("selected_key_name", selectedKeyName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("team_id", teamID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("team_name", teamName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("customer_id", customerID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("customer_name", customerName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvInt("number_of_retries", int64(numberOfRetries))) - span.Resource.Attributes = append(span.Resource.Attributes, kvInt("fallback_index", int64(fallbackIndex))) - return span + return otelEvents } diff --git a/plugins/otel/main.go b/plugins/otel/main.go index 40d7ec243..31aab3dc0 100644 --- a/plugins/otel/main.go +++ b/plugins/otel/main.go @@ -6,29 +6,16 @@ import ( "fmt" "os" "strings" - "sync" - "time" "github.com/bytedance/sonic" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/modelcatalog" - "github.com/maximhq/bifrost/framework/streaming" commonpb "go.opentelemetry.io/proto/otlp/common/v1" ) // logger is the logger for the OTEL plugin var logger schemas.Logger -// ContextKey is a custom type for context keys to prevent collisions -type ContextKey string - -// Context keys for otel plugin -const ( - TraceIDKey ContextKey = "plugin-otel-trace-id" - SpanIDKey ContextKey = "plugin-otel-span-id" -) - // OTELResponseAttributesEnvKey is the environment variable key for the OTEL resource attributes // We check if this is present in the environment variables and if so, we will use it to set the attributes for all spans at the resource level const OTELResponseAttributesEnvKey = "OTEL_RESOURCE_ATTRIBUTES" @@ -65,7 +52,9 @@ type Config struct { TLSCACert string `json:"tls_ca_cert"` } -// OtelPlugin is the plugin for OpenTelemetry +// OtelPlugin is the plugin for OpenTelemetry. +// It implements the ObservabilityPlugin interface to receive completed traces +// from the tracing middleware and forward them to an OTEL collector. type OtelPlugin struct { ctx context.Context cancel context.CancelFunc @@ -80,14 +69,9 @@ type OtelPlugin struct { attributesFromEnvironment []*commonpb.KeyValue - ongoingSpans *TTLSyncMap - client OtelClient pricingManager *modelcatalog.ModelCatalog - accumulator *streaming.Accumulator // Accumulator for streaming chunks - - emitWg sync.WaitGroup // Track in-flight emissions } // Init function for the OTEL plugin @@ -100,7 +84,7 @@ func Init(ctx context.Context, config *Config, _logger schemas.Logger, pricingMa logger.Warn("otel plugin requires model catalog to calculate cost, all cost calculations will be skipped.") } var err error - // If headers are present , and any of them start with env., we will replace the value with the environment variable + // If headers are present, and any of them start with env., we will replace the value with the environment variable if config.Headers != nil { for key, value := range config.Headers { if newValue, ok := strings.CutPrefix(value, "env."); ok { @@ -132,11 +116,8 @@ func Init(ctx context.Context, config *Config, _logger schemas.Logger, pricingMa url: config.CollectorURL, traceType: config.TraceType, headers: config.Headers, - ongoingSpans: NewTTLSyncMap(20*time.Minute, 1*time.Minute), protocol: config.Protocol, pricingManager: pricingManager, - accumulator: streaming.NewAccumulator(pricingManager, logger), - emitWg: sync.WaitGroup{}, bifrostVersion: bifrostVersion, attributesFromEnvironment: attributesFromEnvironment, } @@ -164,9 +145,9 @@ func (p *OtelPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (p *OtelPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportMiddleware is not used for this plugin +func (p *OtelPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return nil } // ValidateConfig function for the OTEL plugin @@ -205,139 +186,53 @@ func (p *OtelPlugin) ValidateConfig(config any) (*Config, error) { return &otelConfig, nil } -// PreHook function for the OTEL plugin -func (p *OtelPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - if p.client == nil { - logger.Warn("otel client is not initialized") - return req, nil, nil - } - traceIDValue := ctx.Value(schemas.BifrostContextKeyRequestID) - if traceIDValue == nil { - logger.Warn("trace id not found in context") - return req, nil, nil - } - traceID, ok := traceIDValue.(string) - if !ok { - logger.Warn("trace id not found in context") - return req, nil, nil - } - spanID := fmt.Sprintf("%s-root-span", traceID) - createdTimestamp := time.Now() - if bifrost.IsStreamRequestType(req.RequestType) { - p.accumulator.CreateStreamAccumulator(traceID, createdTimestamp) - } - p.ongoingSpans.Set(traceID, p.createResourceSpan(traceID, spanID, time.Now(), req)) +// PreHook is a no-op - tracing is handled via the Inject method. +// The OTEL plugin receives completed traces from TracingMiddleware. +func (p *OtelPlugin) PreHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { return req, nil, nil } -// PostHook function for the OTEL plugin -func (p *OtelPlugin) PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - traceIDValue := ctx.Value(schemas.BifrostContextKeyRequestID) - if traceIDValue == nil { - logger.Warn("trace id not found in context") - return resp, bifrostErr, nil +// PostHook is a no-op - tracing is handled via the Inject method. +// The OTEL plugin receives completed traces from TracingMiddleware. +func (p *OtelPlugin) PostHook(_ *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} + +// Inject receives a completed trace and sends it to the OTEL collector. +// Implements schemas.ObservabilityPlugin interface. +// This method is called asynchronously by TracingMiddleware after the response +// has been written to the client. +func (p *OtelPlugin) Inject(ctx context.Context, trace *schemas.Trace) error { + if trace == nil { + return nil } - traceID, ok := traceIDValue.(string) - if !ok { - logger.Warn("trace id not found in context") - return resp, bifrostErr, nil + if p.client == nil { + logger.Warn("otel client is not initialized") + return nil } - virtualKeyID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-virtual-key-id")) - virtualKeyName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-virtual-key-name")) - - selectedKeyID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeySelectedKeyID) - selectedKeyName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeySelectedKeyName) + // Convert schemas.Trace to OTEL ResourceSpan + resourceSpan := p.convertTraceToResourceSpan(trace) - numberOfRetries := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyNumberOfRetries) - fallbackIndex := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyFallbackIndex) - - teamID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-team-id")) - teamName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-team-name")) - customerID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-id")) - customerName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-name")) + // Emit to collector + if err := p.client.Emit(ctx, []*ResourceSpan{resourceSpan}); err != nil { + logger.Error("failed to emit trace %s: %v", trace.TraceID, err) + return err + } - // Track every PostHook emission, stream and non-stream. - p.emitWg.Add(1) - go func() { - defer p.emitWg.Done() - span, ok := p.ongoingSpans.Get(traceID) - if !ok { - logger.Warn("span not found in ongoing spans") - return - } - requestType, _, _ := bifrost.GetResponseFields(resp, bifrostErr) - if span, ok := span.(*ResourceSpan); ok { - // We handle streaming responses differently, we will use the accumulator to process the response and then emit the final response - if bifrost.IsStreamRequestType(requestType) { - streamResponse, err := p.accumulator.ProcessStreamingResponse(ctx, resp, bifrostErr) - if err != nil { - logger.Debug("failed to process streaming response: %v", err) - } - if streamResponse != nil && streamResponse.Type == streaming.StreamResponseTypeFinal { - defer p.ongoingSpans.Delete(traceID) - if err := p.client.Emit(p.ctx, []*ResourceSpan{completeResourceSpan( - span, - time.Now(), - streamResponse.ToBifrostResponse(), - bifrostErr, - p.pricingManager, - virtualKeyID, - virtualKeyName, - selectedKeyID, - selectedKeyName, - numberOfRetries, - fallbackIndex, - teamID, - teamName, - customerID, - customerName, - )}); err != nil { - logger.Error("failed to emit response span for request %s: %v", traceID, err) - } - } - return - } - defer p.ongoingSpans.Delete(traceID) - rs := completeResourceSpan( - span, - time.Now(), - resp, - bifrostErr, - p.pricingManager, - virtualKeyID, - virtualKeyName, - selectedKeyID, - selectedKeyName, - numberOfRetries, - fallbackIndex, - teamID, - teamName, - customerID, - customerName, - ) - if err := p.client.Emit(p.ctx, []*ResourceSpan{rs}); err != nil { - logger.Error("failed to emit response span for request %s: %v", traceID, err) - } - } - }() - return resp, bifrostErr, nil + return nil } // Cleanup function for the OTEL plugin func (p *OtelPlugin) Cleanup() error { - p.emitWg.Wait() if p.cancel != nil { p.cancel() } - if p.ongoingSpans != nil { - p.ongoingSpans.Stop() - } - if p.accumulator != nil { - p.accumulator.Cleanup() - } if p.client != nil { return p.client.Close() } return nil } + +// Compile-time check that OtelPlugin implements ObservabilityPlugin +var _ schemas.ObservabilityPlugin = (*OtelPlugin)(nil) diff --git a/plugins/otel/ttlsyncmap.go b/plugins/otel/ttlsyncmap.go deleted file mode 100644 index d54999d1b..000000000 --- a/plugins/otel/ttlsyncmap.go +++ /dev/null @@ -1,184 +0,0 @@ -package otel - -import ( - "sync" - "time" -) - -// TTLSyncMap is a thread-safe map with automatic cleanup of expired entries -type TTLSyncMap struct { - data sync.Map - ttl time.Duration - cleanupTicker *time.Ticker - stopCleanup chan struct{} - cleanupWg sync.WaitGroup - stopOnce sync.Once -} - -// entry stores the value along with its expiration time -type entry struct { - value interface{} - expiresAt time.Time -} - -// NewTTLSyncMap creates a new TTL sync map with the specified TTL and cleanup interval -// ttl: time to live for each entry -// cleanupInterval: how often to check for expired entries (should be <= ttl) -func NewTTLSyncMap(ttl time.Duration, cleanupInterval time.Duration) *TTLSyncMap { - if ttl <= 0 { - ttl = time.Minute - } - if cleanupInterval <= 0 { - cleanupInterval = ttl / 2 - if cleanupInterval <= 0 { - cleanupInterval = time.Minute - } - } - - m := &TTLSyncMap{ - ttl: ttl, - cleanupTicker: time.NewTicker(cleanupInterval), - stopCleanup: make(chan struct{}), - } - - // Start the cleanup goroutine - m.cleanupWg.Add(1) - go m.startCleanup() - - return m -} - -// Set stores a key-value pair with TTL -func (m *TTLSyncMap) Set(key, value interface{}) { - m.data.Store(key, &entry{ - value: value, - expiresAt: time.Now().Add(m.ttl), - }) -} - -// Get retrieves a value by key, returns (value, true) if found and not expired, -// (nil, false) otherwise -func (m *TTLSyncMap) Get(key interface{}) (interface{}, bool) { - val, ok := m.data.Load(key) - if !ok { - return nil, false - } - - e := val.(*entry) - if time.Now().After(e.expiresAt) { - // Entry has expired, delete it - m.data.Delete(key) - return nil, false - } - - return e.value, true -} - -// Delete removes a key-value pair from the map -func (m *TTLSyncMap) Delete(key interface{}) { - m.data.Delete(key) -} - -// Refresh updates the expiration time of an existing entry -func (m *TTLSyncMap) Refresh(key interface{}) bool { - val, ok := m.data.Load(key) - if !ok { - return false - } - e, _ := val.(*entry) - if e == nil || time.Now().After(e.expiresAt) { - m.data.Delete(key) - return false - } - m.data.Store(key, &entry{ - value: e.value, - expiresAt: time.Now().Add(m.ttl), - }) - return true -} - -// GetOrSet retrieves a value by key if it exists and is not expired, -// otherwise sets the new value and returns it -func (m *TTLSyncMap) GetOrSet(key, value interface{}) (actual interface{}, loaded bool) { - actual, loaded = m.Get(key) - if !loaded { - m.Set(key, value) - actual = value - } - return actual, loaded -} - -// Range calls f sequentially for each key and value present in the map. -// If f returns false, range stops the iteration. -// Only non-expired entries are included. -func (m *TTLSyncMap) Range(f func(key, value interface{}) bool) { - now := time.Now() - m.data.Range(func(key, val interface{}) bool { - e := val.(*entry) - if now.After(e.expiresAt) { - // Skip expired entry and delete it - m.data.Delete(key) - return true - } - return f(key, e.value) - }) -} - -// Len returns the number of non-expired entries in the map -func (m *TTLSyncMap) Len() int { - count := 0 - m.Range(func(_, _ interface{}) bool { - count++ - return true - }) - return count -} - -// startCleanup runs in a background goroutine to periodically remove expired entries -func (m *TTLSyncMap) startCleanup() { - defer m.cleanupWg.Done() - - for { - select { - case <-m.cleanupTicker.C: - m.cleanup() - case <-m.stopCleanup: - return - } - } -} - -// cleanup removes all expired entries from the map -func (m *TTLSyncMap) cleanup() { - now := time.Now() - m.data.Range(func(key, val interface{}) bool { - e := val.(*entry) - if now.After(e.expiresAt) { - m.data.Delete(key) - } - return true - }) - if m.Len() > 10000 { - logger.Warn("[otel] map cleanup done. current size: %d entries", m.Len()) - } else { - logger.Debug("[otel] map cleanup done. current size: %d entries", m.Len()) - } -} - -// Stop stops the cleanup goroutine and releases resources -// Call this when you're done with the map to prevent goroutine leaks -func (m *TTLSyncMap) Stop() { - m.stopOnce.Do(func() { - close(m.stopCleanup) - m.cleanupTicker.Stop() - m.cleanupWg.Wait() - }) -} - -// Clear removes all entries from the map -func (m *TTLSyncMap) Clear() { - m.data.Range(func(key, _ interface{}) bool { - m.data.Delete(key) - return true - }) -} diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md index e69de29bb..218e0a55a 100644 --- a/plugins/semanticcache/changelog.md +++ b/plugins/semanticcache/changelog.md @@ -0,0 +1,18 @@ +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 0da4592ff..3c47b00b0 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -335,9 +335,9 @@ func (plugin *Plugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (plugin *Plugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportMiddleware is not used for this plugin +func (plugin *Plugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return nil } // PreHook is called before a request is processed by Bifrost. @@ -377,7 +377,7 @@ func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostR ctx.SetValue(requestIDKey, requestID) ctx.SetValue(requestModelKey, model) ctx.SetValue(requestProviderKey, provider) - + performDirectSearch, performSemanticSearch := true, true if (*ctx).Value(CacheTypeKey) != nil { cacheTypeVal, ok := (*ctx).Value(CacheTypeKey).(CacheType) diff --git a/plugins/telemetry/changelog.md b/plugins/telemetry/changelog.md index e69de29bb..218e0a55a 100644 --- a/plugins/telemetry/changelog.md +++ b/plugins/telemetry/changelog.md @@ -0,0 +1,18 @@ +- chore: upgraded versions of core to 1.3.0 and framework to 1.2.0 + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor replaced with HTTPTransportMiddleware** + + This plugin now implements `HTTPTransportMiddleware()` instead of `TransportInterceptor()` to comply with core v1.3.0. + + **What changed:** + - Old: `TransportInterceptor(ctx, url, headers, body) (headers, body, error)` + - New: `HTTPTransportMiddleware() BifrostHTTPMiddleware` + + **For plugin consumers:** + - If you import this plugin directly, no code changes are required + - If you extend this plugin, update your implementation to use `HTTPTransportMiddleware()` + - Recompile any code that depends on this plugin against core v1.3.0+ and framework v1.2.0+ + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for details. \ No newline at end of file diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index e227bac27..d80eadad5 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -276,9 +276,9 @@ func (p *PrometheusPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (p *PrometheusPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportMiddleware is not used for this plugin +func (p *PrometheusPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { + return nil } // PreHook records the start time of the request in the context. @@ -464,7 +464,7 @@ func (p *PrometheusPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas return result, bifrostErr, nil } -// PrometheusMiddleware wraps a FastHTTP handler to collect Prometheus metrics. +// HTTPMiddleware wraps a FastHTTP handler to collect Prometheus metrics. // It tracks: // - Total number of requests // - Request duration diff --git a/transports/bifrost-http/handlers/cache.go b/transports/bifrost-http/handlers/cache.go index a91d04aa5..df3ef62d6 100644 --- a/transports/bifrost-http/handlers/cache.go +++ b/transports/bifrost-http/handlers/cache.go @@ -23,7 +23,7 @@ func NewCacheHandler(plugin schemas.Plugin) *CacheHandler { } } -func (h *CacheHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *CacheHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.DELETE("/api/cache/clear/{requestId}", lib.ChainMiddlewares(h.clearCache, middlewares...)) r.DELETE("/api/cache/clear-by-key/{cacheKey}", lib.ChainMiddlewares(h.clearCacheByKey, middlewares...)) } diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index 9ca02170a..7031eb116 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -70,7 +70,7 @@ func NewConfigHandler(configManager ConfigManager, store *lib.Config) *ConfigHan // RegisterRoutes registers the configuration-related routes. // It adds the `PUT /api/config` endpoint. -func (h *ConfigHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *ConfigHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/api/config", lib.ChainMiddlewares(h.getConfig, middlewares...)) r.PUT("/api/config", lib.ChainMiddlewares(h.updateConfig, middlewares...)) r.GET("/api/version", lib.ChainMiddlewares(h.getVersion, middlewares...)) diff --git a/transports/bifrost-http/handlers/devpprof.go b/transports/bifrost-http/handlers/devpprof.go new file mode 100644 index 000000000..eb8dc516b --- /dev/null +++ b/transports/bifrost-http/handlers/devpprof.go @@ -0,0 +1,376 @@ +package handlers + +import ( + "bytes" + "os" + "runtime" + "runtime/pprof" + "sort" + "sync" + "time" + + "github.com/fasthttp/router" + "github.com/google/pprof/profile" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +const ( + // Collection interval for metrics + metricsCollectionInterval = 10 * time.Second + // Number of data points to keep (5 minutes / 10 seconds = 30 points) + historySize = 30 + // Top allocations to return + topAllocationsCount = 5 +) + +// MemoryStats represents memory statistics at a point in time +type MemoryStats struct { + Alloc uint64 `json:"alloc"` + TotalAlloc uint64 `json:"total_alloc"` + HeapInuse uint64 `json:"heap_inuse"` + HeapObjects uint64 `json:"heap_objects"` + Sys uint64 `json:"sys"` +} + +// CPUStats represents CPU statistics +type CPUStats struct { + UsagePercent float64 `json:"usage_percent"` + UserTime float64 `json:"user_time"` + SystemTime float64 `json:"system_time"` +} + +// RuntimeStats represents runtime statistics +type RuntimeStats struct { + NumGoroutine int `json:"num_goroutine"` + NumGC uint32 `json:"num_gc"` + GCPauseNs uint64 `json:"gc_pause_ns"` + NumCPU int `json:"num_cpu"` + GOMAXPROCS int `json:"gomaxprocs"` +} + +// AllocationInfo represents a single allocation site +type AllocationInfo struct { + Function string `json:"function"` + File string `json:"file"` + Line int `json:"line"` + Bytes int64 `json:"bytes"` + Count int64 `json:"count"` +} + +// HistoryPoint represents a single point in the metrics history +type HistoryPoint struct { + Timestamp string `json:"timestamp"` + Alloc uint64 `json:"alloc"` + HeapInuse uint64 `json:"heap_inuse"` + Goroutines int `json:"goroutines"` + GCPauseNs uint64 `json:"gc_pause_ns"` + CPUPercent float64 `json:"cpu_percent"` +} + +// PprofData represents the complete pprof response +type PprofData struct { + Timestamp string `json:"timestamp"` + Memory MemoryStats `json:"memory"` + CPU CPUStats `json:"cpu"` + Runtime RuntimeStats `json:"runtime"` + TopAllocations []AllocationInfo `json:"top_allocations"` + History []HistoryPoint `json:"history"` +} + +// cpuSample holds a CPU time sample for calculating usage +type cpuSample struct { + timestamp time.Time + userTime time.Duration + systemTime time.Duration +} + +// MetricsCollector collects and stores runtime metrics +type MetricsCollector struct { + mu sync.RWMutex + history []HistoryPoint + stopCh chan struct{} + started bool + lastCPUSample cpuSample + currentCPU CPUStats +} + +// DevPprofHandler handles development profiling endpoints +type DevPprofHandler struct { + collector *MetricsCollector +} + +// Global collector instance +var globalCollector *MetricsCollector +var collectorOnce sync.Once + +// IsDevMode checks if dev mode is enabled via environment variable +func IsDevMode() bool { + return os.Getenv("BIFROST_UI_DEV") == "true" +} + +// getOrCreateCollector returns the global metrics collector, creating it if needed +func getOrCreateCollector() *MetricsCollector { + collectorOnce.Do(func() { + globalCollector = &MetricsCollector{ + history: make([]HistoryPoint, 0, historySize), + stopCh: make(chan struct{}), + } + }) + return globalCollector +} + +// NewDevPprofHandler creates a new dev pprof handler +func NewDevPprofHandler() *DevPprofHandler { + return &DevPprofHandler{ + collector: getOrCreateCollector(), + } +} + +// Start begins the background metrics collection +func (c *MetricsCollector) Start() { + c.mu.Lock() + if c.started { + c.mu.Unlock() + return + } + c.stopCh = make(chan struct{}) + c.started = true + c.mu.Unlock() + + go c.collectLoop() +} + +// Stop stops the background metrics collection +func (c *MetricsCollector) Stop() { + c.mu.Lock() + defer c.mu.Unlock() + if !c.started { + return + } + close(c.stopCh) + c.stopCh = nil + c.started = false +} + +func (c *MetricsCollector) collectLoop() { + // Initialize CPU sample + c.lastCPUSample = getCPUSample() + + // Wait a bit before first collection to get accurate CPU reading + time.Sleep(100 * time.Millisecond) + + // Collect immediately on start + c.collect() + + ticker := time.NewTicker(metricsCollectionInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.collect() + case <-c.stopCh: + return + } + } +} + +// calculateCPUUsage calculates CPU usage percentage between two samples +func calculateCPUUsage(prev, curr cpuSample, numCPU int) CPUStats { + elapsed := curr.timestamp.Sub(prev.timestamp) + if elapsed <= 0 { + return CPUStats{} + } + + userDelta := curr.userTime - prev.userTime + systemDelta := curr.systemTime - prev.systemTime + totalCPUTime := userDelta + systemDelta + + // Calculate percentage: (CPU time used / wall time) * 100 + // Normalized by number of CPUs to get 0-100% range + cpuPercent := (float64(totalCPUTime) / float64(elapsed)) * 100.0 + + // Cap at 100% * numCPU (in case of measurement errors) + maxPercent := float64(numCPU) * 100.0 + if cpuPercent > maxPercent { + cpuPercent = maxPercent + } + + return CPUStats{ + UsagePercent: cpuPercent, + UserTime: userDelta.Seconds(), + SystemTime: systemDelta.Seconds(), + } +} + +func (c *MetricsCollector) collect() { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + // Get current CPU sample and calculate usage + currentSample := getCPUSample() + cpuStats := calculateCPUUsage(c.lastCPUSample, currentSample, runtime.NumCPU()) + c.lastCPUSample = currentSample + + point := HistoryPoint{ + Timestamp: time.Now().Format(time.RFC3339), + Alloc: memStats.Alloc, + HeapInuse: memStats.HeapInuse, + Goroutines: runtime.NumGoroutine(), + GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256], + CPUPercent: cpuStats.UsagePercent, + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Store current CPU stats for API response + c.currentCPU = cpuStats + + // Append to history, maintaining ring buffer behavior + if len(c.history) >= historySize { + // Shift left by one and append + copy(c.history, c.history[1:]) + c.history[len(c.history)-1] = point + } else { + c.history = append(c.history, point) + } +} + +func (c *MetricsCollector) getHistory() []HistoryPoint { + c.mu.RLock() + defer c.mu.RUnlock() + + // Return a copy to avoid race conditions + result := make([]HistoryPoint, len(c.history)) + copy(result, c.history) + return result +} + +func (c *MetricsCollector) getCPUStats() CPUStats { + c.mu.RLock() + defer c.mu.RUnlock() + return c.currentCPU +} + +// getTopAllocations analyzes heap profile to find top allocation sites +func getTopAllocations() []AllocationInfo { + // Write heap profile to buffer + var buf bytes.Buffer + if err := pprof.WriteHeapProfile(&buf); err != nil { + return []AllocationInfo{} + } + + // Parse the protobuf profile + p, err := profile.Parse(&buf) + if err != nil { + return []AllocationInfo{} + } + + // Find the indices for alloc_objects and alloc_space sample types + var allocObjectsIdx, allocSpaceIdx int + for i, st := range p.SampleType { + switch st.Type { + case "alloc_objects": + allocObjectsIdx = i + case "alloc_space": + allocSpaceIdx = i + } + } + + // Aggregate allocations by function (top of stack = allocation site) + allocMap := make(map[string]*AllocationInfo) + + for _, sample := range p.Sample { + if len(sample.Location) == 0 { + continue + } + loc := sample.Location[0] // Top of stack = allocation site + if len(loc.Line) == 0 { + continue + } + line := loc.Line[0] + fn := line.Function + if fn == nil { + continue + } + + key := fn.Name + if existing, ok := allocMap[key]; ok { + existing.Bytes += sample.Value[allocSpaceIdx] + existing.Count += sample.Value[allocObjectsIdx] + } else { + allocMap[key] = &AllocationInfo{ + Function: fn.Name, + File: fn.Filename, + Line: int(line.Line), + Bytes: sample.Value[allocSpaceIdx], + Count: sample.Value[allocObjectsIdx], + } + } + } + + // Convert map to slice + allocations := make([]AllocationInfo, 0, len(allocMap)) + for _, alloc := range allocMap { + allocations = append(allocations, *alloc) + } + + // Sort by bytes descending + sort.Slice(allocations, func(i, j int) bool { + return allocations[i].Bytes > allocations[j].Bytes + }) + + // Return top N allocations + if len(allocations) > topAllocationsCount { + allocations = allocations[:topAllocationsCount] + } + + return allocations +} + +// RegisterRoutes registers the dev pprof routes +func (h *DevPprofHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + // Start the collector when routes are registered + h.collector.Start() + + r.GET("/api/dev/pprof", lib.ChainMiddlewares(h.getPprof, middlewares...)) +} + +// getPprof handles GET /api/dev/pprof +func (h *DevPprofHandler) getPprof(ctx *fasthttp.RequestCtx) { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + data := PprofData{ + Timestamp: time.Now().Format(time.RFC3339), + Memory: MemoryStats{ + Alloc: memStats.Alloc, + TotalAlloc: memStats.TotalAlloc, + HeapInuse: memStats.HeapInuse, + HeapObjects: memStats.HeapObjects, + Sys: memStats.Sys, + }, + CPU: h.collector.getCPUStats(), + Runtime: RuntimeStats{ + NumGoroutine: runtime.NumGoroutine(), + NumGC: memStats.NumGC, + GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256], + NumCPU: runtime.NumCPU(), + GOMAXPROCS: runtime.GOMAXPROCS(0), + }, + TopAllocations: getTopAllocations(), + History: h.collector.getHistory(), + } + + SendJSON(ctx, data) +} + +// Cleanup stops the metrics collector +func (h *DevPprofHandler) Cleanup() { + if h.collector != nil { + h.collector.Stop() + } +} diff --git a/transports/bifrost-http/handlers/devpprof_unix.go b/transports/bifrost-http/handlers/devpprof_unix.go new file mode 100644 index 000000000..5c9b72b2f --- /dev/null +++ b/transports/bifrost-http/handlers/devpprof_unix.go @@ -0,0 +1,27 @@ +//go:build !windows +// +build !windows + +package handlers + +import ( + "syscall" + "time" +) + +// getCPUSample gets the current CPU time sample using syscall +func getCPUSample() cpuSample { + var rusage syscall.Rusage + if err := syscall.Getrusage(syscall.RUSAGE_SELF, &rusage); err != nil { + return cpuSample{timestamp: time.Now()} + } + + userTime := time.Duration(rusage.Utime.Sec)*time.Second + time.Duration(rusage.Utime.Usec)*time.Microsecond + systemTime := time.Duration(rusage.Stime.Sec)*time.Second + time.Duration(rusage.Stime.Usec)*time.Microsecond + + return cpuSample{ + timestamp: time.Now(), + userTime: userTime, + systemTime: systemTime, + } +} + diff --git a/transports/bifrost-http/handlers/devpprof_windows.go b/transports/bifrost-http/handlers/devpprof_windows.go new file mode 100644 index 000000000..1e8a805c4 --- /dev/null +++ b/transports/bifrost-http/handlers/devpprof_windows.go @@ -0,0 +1,13 @@ +//go:build windows +// +build windows + +package handlers + +import "time" + +// getCPUSample returns a zeroed CPU sample on Windows +// Windows does not support syscall.Getrusage +func getCPUSample() cpuSample { + return cpuSample{timestamp: time.Now()} +} + diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index 05dd0d2a3..bb493ed2a 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -12,6 +12,7 @@ import ( "github.com/fasthttp/router" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/plugins/governance" @@ -154,7 +155,7 @@ type UpdateCustomerRequest struct { } // RegisterRoutes registers all governance-related routes for the new hierarchical system -func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Virtual Key CRUD operations r.GET("/api/governance/virtual-keys", lib.ChainMiddlewares(h.getVirtualKeys, middlewares...)) r.POST("/api/governance/virtual-keys", lib.ChainMiddlewares(h.createVirtualKey, middlewares...)) diff --git a/transports/bifrost-http/handlers/health.go b/transports/bifrost-http/handlers/health.go index 9c4103354..315315ccf 100644 --- a/transports/bifrost-http/handlers/health.go +++ b/transports/bifrost-http/handlers/health.go @@ -6,6 +6,7 @@ import ( "time" "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -23,7 +24,7 @@ func NewHealthHandler(config *lib.Config) *HealthHandler { } // RegisterRoutes registers the health-related routes. -func (h *HealthHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *HealthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/health", lib.ChainMiddlewares(h.getHealth, middlewares...)) } diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 9b95a33be..46f3c0756 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -412,7 +412,7 @@ const ( ) // RegisterRoutes registers all completion-related routes -func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Model endpoints r.GET("/v1/models", lib.ChainMiddlewares(h.listModels, middlewares...)) @@ -1195,11 +1195,25 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge return } + // Signal to tracing middleware that trace completion should be deferred + // The streaming callback will complete the trace after the stream ends + ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) + + // Get the trace completer function for use in the streaming callback + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + var includeEventType bool // Use streaming response writer ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { - defer w.Flush() + defer func() { + w.Flush() + // Complete the trace after streaming finishes + // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL + if traceCompleter != nil { + traceCompleter() + } + }() // Process streaming responses for chunk := range stream { diff --git a/transports/bifrost-http/handlers/integrations.go b/transports/bifrost-http/handlers/integrations.go index 60625475a..de5b9fbcf 100644 --- a/transports/bifrost-http/handlers/integrations.go +++ b/transports/bifrost-http/handlers/integrations.go @@ -5,6 +5,7 @@ package handlers import ( "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/lib" ) @@ -33,7 +34,7 @@ func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStor } // RegisterRoutes registers all integration routes for AI provider compatibility endpoints -func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Register routes for each integration extension for _, extension := range h.extensions { extension.RegisterRoutes(r, middlewares...) diff --git a/transports/bifrost-http/handlers/logging.go b/transports/bifrost-http/handlers/logging.go index cde2df171..7fe7f1574 100644 --- a/transports/bifrost-http/handlers/logging.go +++ b/transports/bifrost-http/handlers/logging.go @@ -39,7 +39,7 @@ func NewLoggingHandler(logManager logging.LogManager, redactedKeysManager Redact } // RegisterRoutes registers all logging-related routes -func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Log retrieval with filtering, search, and pagination r.GET("/api/logs", lib.ChainMiddlewares(h.getLogs, middlewares...)) r.GET("/api/logs/stats", lib.ChainMiddlewares(h.getLogsStats, middlewares...)) diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 06d7e9843..bc33d5bb3 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -40,7 +40,7 @@ func NewMCPHandler(mcpManager MCPManager, client *bifrost.Bifrost, store *lib.Co } // RegisterRoutes registers all MCP-related routes -func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // MCP tool execution endpoint r.POST("/v1/mcp/tool/execute", lib.ChainMiddlewares(h.executeTool, middlewares...)) r.GET("/api/mcp/clients", lib.ChainMiddlewares(h.getMCPClients, middlewares...)) @@ -113,7 +113,7 @@ func (h *MCPHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.store.GetHeaderFilterConfig()) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") diff --git a/transports/bifrost-http/handlers/mcp_server.go b/transports/bifrost-http/handlers/mcpserver.go similarity index 98% rename from transports/bifrost-http/handlers/mcp_server.go rename to transports/bifrost-http/handlers/mcpserver.go index 1bdf5b6ca..b4c7d3268 100644 --- a/transports/bifrost-http/handlers/mcp_server.go +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -70,7 +70,7 @@ func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MC } // RegisterRoutes registers the MCP server route -func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // MCP server endpoint - supports both POST (JSON-RPC) and GET (SSE) r.POST("/mcp", lib.ChainMiddlewares(h.handleMCPServer, middlewares...)) r.GET("/mcp", lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...)) @@ -85,7 +85,7 @@ func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderFilterConfig()) defer cancel() // Use mcp-go server to handle the request @@ -124,7 +124,7 @@ func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { ctx.Response.Header.Set("Connection", "keep-alive") // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderFilterConfig()) // Use streaming response writer ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 544d84fb2..edc5ed38f 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -3,21 +3,22 @@ package handlers import ( "context" "encoding/base64" - "encoding/json" "fmt" "slices" "strings" + "sync/atomic" "time" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/encrypt" + "github.com/maximhq/bifrost/framework/tracing" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) // CorsMiddleware handles CORS headers for localhost and configured allowed origins -func CorsMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { +func CorsMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { logger.Debug("CorsMiddleware: %s", string(ctx.Path())) @@ -45,104 +46,25 @@ func CorsMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { } } -// TransportInterceptorMiddleware collects all plugin interceptors and calls them one by one -func TransportInterceptorMiddleware(config *lib.Config, enterpriseOverrides lib.EnterpriseOverrides) lib.BifrostHTTPMiddleware { +// TransportInterceptorMiddleware collects all plugin HTTP transport middleware and chains them. +func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - // Get plugins from config - lock-free read plugins := config.GetLoadedPlugins() if len(plugins) == 0 { next(ctx) return } - if enterpriseOverrides == nil { - next(ctx) - return - } - // If governance plugin is not loaded, skip interception - hasGovernance := false - for _, p := range plugins { - if p.GetName() == enterpriseOverrides.GetGovernancePluginName() { - hasGovernance = true - break + pluginsMiddlewareChain := []schemas.BifrostHTTPMiddleware{} + for _, plugin := range plugins { + middleware := plugin.HTTPTransportMiddleware() + // Collect plugin HTTP transport middleware + if middleware == nil { + continue } + pluginsMiddlewareChain = append(pluginsMiddlewareChain, middleware) } - if !hasGovernance { - next(ctx) - return - } - - // Parse headers - headers := make(map[string]string) - originalHeaderNames := make([]string, 0, 16) - ctx.Request.Header.All()(func(key, value []byte) bool { - name := string(key) - headers[name] = string(value) - originalHeaderNames = append(originalHeaderNames, name) - - return true - }) - requestBody := make(map[string]any) - // Only read body if Content-Type is JSON to avoid consuming multipart/form-data streams - contentType := string(ctx.Request.Header.Peek("Content-Type")) - isJSONRequest := strings.HasPrefix(contentType, "application/json") - - // Only run interceptors for JSON requests - if isJSONRequest { - bodyBytes := ctx.Request.Body() - if len(bodyBytes) > 0 { - if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - // If body is not valid JSON, log warning and continue without interception - logger.Warn(fmt.Sprintf("[transportInterceptor]: Failed to unmarshal request body: %v, skipping interceptor", err)) - next(ctx) - return - } - } - for _, plugin := range plugins { - // Call TransportInterceptor on all plugins - pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second) - modifiedHeaders, modifiedBody, err := plugin.TransportInterceptor(pluginCtx, string(ctx.Request.URI().RequestURI()), headers, requestBody) - cancel() - if err != nil { - logger.Warn(fmt.Sprintf("TransportInterceptor: Plugin '%s' returned error: %v", plugin.GetName(), err)) - // Continue with unmodified headers/body - continue - } - // Update headers and body with modifications - if modifiedHeaders != nil { - headers = modifiedHeaders - } - if modifiedBody != nil { - requestBody = modifiedBody - } - // Capturing plugin ctx values and putting them in the request context - for k, v := range pluginCtx.GetUserValues() { - ctx.SetUserValue(k, v) - } - } - - // Marshal the body back to JSON - updatedBody, err := json.Marshal(requestBody) - if err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("TransportInterceptor: Failed to marshal request body: %v", err)) - return - } - ctx.Request.SetBody(updatedBody) - - // Remove headers that were present originally but removed by plugins - for _, name := range originalHeaderNames { - if _, exists := headers[name]; !exists { - ctx.Request.Header.Del(name) - } - } - - // Set modified headers back on the request - for key, value := range headers { - ctx.Request.Header.Set(key, value) - } - } - - next(ctx) + lib.ChainMiddlewares(next, pluginsMiddlewareChain...)(ctx) } } } @@ -163,7 +85,7 @@ func validateSession(_ *fasthttp.RequestCtx, store configstore.ConfigStore, toke // This uses basic auth style username + password based authentication // No session tracking is used, so this is not suitable for production environments // These basicauth routes are only used for the dashboard and API routes -func AuthMiddleware(store configstore.ConfigStore) lib.BifrostHTTPMiddleware { +func AuthMiddleware(store configstore.ConfigStore) schemas.BifrostHTTPMiddleware { if store == nil { logger.Info("auth middleware is disabled because store is not present") return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { @@ -265,3 +187,143 @@ func AuthMiddleware(store configstore.ConfigStore) lib.BifrostHTTPMiddleware { } } } + +// TracingMiddleware creates distributed traces for requests and forwards completed traces +// to observability plugins after the response has been written. +// +// The middleware: +// 1. Extracts parent trace ID from incoming W3C traceparent header (if present) +// 2. Creates a new trace in the store (only the lightweight trace ID is stored in context) +// 3. Calls the next handler to process the request +// 4. After response is written, asynchronously completes the trace and forwards it to observability plugins +// +// This middleware should be placed early in the middleware chain to capture the full request lifecycle. +type TracingMiddleware struct { + tracer atomic.Pointer[tracing.Tracer] + obsPlugins atomic.Pointer[[]schemas.ObservabilityPlugin] +} + +// NewTracingMiddleware creates a new tracing middleware +func NewTracingMiddleware(tracer *tracing.Tracer, obsPlugins []schemas.ObservabilityPlugin) *TracingMiddleware { + tm := &TracingMiddleware{ + tracer: atomic.Pointer[tracing.Tracer]{}, + obsPlugins: atomic.Pointer[[]schemas.ObservabilityPlugin]{}, + } + tm.tracer.Store(tracer) + tm.obsPlugins.Store(&obsPlugins) + return tm +} + +// SetObservabilityPlugins sets the observability plugins for the tracing middleware +func (m *TracingMiddleware) SetObservabilityPlugins(obsPlugins []schemas.ObservabilityPlugin) { + m.obsPlugins.Store(&obsPlugins) +} + +// SetTracer sets the tracer for the tracing middleware +func (m *TracingMiddleware) SetTracer(tracer *tracing.Tracer) { + m.tracer.Store(tracer) +} + +// Middleware returns the middleware function that creates distributed traces for requests and forwards completed traces +func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // Skip if store is nil + if m.tracer.Load() == nil { + next(ctx) + return + } + // Extract parent trace ID from W3C headers (if present) + parentID := tracing.ExtractParentID(&ctx.Request.Header) + // Create trace in store - only ID returned (trace data stays in store) + traceID := m.tracer.Load().CreateTrace(parentID) + // Only trace ID goes into context (lightweight, no bloat) + ctx.SetUserValue(schemas.BifrostContextKeyTraceID, traceID) + // Store a trace completion callback for streaming handlers to use + ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func() { + m.completeAndFlushTrace(traceID) + }) + // Create root span for the HTTP request + spanCtx, rootSpan := m.tracer.Load().StartSpan(ctx, string(ctx.RequestURI()), schemas.SpanKindHTTPRequest) + if rootSpan != nil { + m.tracer.Load().SetAttribute(rootSpan, "http.method", string(ctx.Method())) + m.tracer.Load().SetAttribute(rootSpan, "http.url", string(ctx.RequestURI())) + m.tracer.Load().SetAttribute(rootSpan, "http.user_agent", string(ctx.Request.Header.UserAgent())) + // Set root span ID in context for child span creation + if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { + ctx.SetUserValue(schemas.BifrostContextKeySpanID, spanID) + } + } + defer func() { + // Record response status on the root span + if rootSpan != nil { + m.tracer.Load().SetAttribute(rootSpan, "http.status_code", ctx.Response.StatusCode()) + if ctx.Response.StatusCode() >= 400 { + m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusError, fmt.Sprintf("HTTP %d", ctx.Response.StatusCode())) + } else { + m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusOk, "") + } + } + // Check if trace completion is deferred (for streaming requests) + // If deferred, the streaming handler will complete the trace after stream ends + if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { + return + } + // After response written - async flush + m.completeAndFlushTrace(traceID) + }() + + next(ctx) + } + } +} + +// completeAndFlushTrace completes the trace and forwards it to observability plugins. +// This is called either by the middleware defer (for non-streaming) or by streaming handlers. +func (m *TracingMiddleware) completeAndFlushTrace(traceID string) { + go func() { + // Clean up the stream accumulator for this trace + // This must happen before EndTrace to ensure all accumulated data is available + m.tracer.Load().CleanupStreamAccumulator(traceID) + + // Get completed trace from store + completedTrace := m.tracer.Load().EndTrace(traceID) + if completedTrace == nil { + return + } + // Forward to all observability plugins + for _, plugin := range *m.obsPlugins.Load() { + if plugin == nil { + continue + } + // Call inject with a background context (request context is done) + if err := plugin.Inject(context.Background(), completedTrace); err != nil { + logger.Warn("observability plugin %s failed to inject trace: %v", plugin.GetName(), err) + } + } + // Return trace to pool for reuse + m.tracer.Load().ReleaseTrace(completedTrace) + }() +} + +// GetTracer returns the tracer instance for use by streaming handlers +func (m *TracingMiddleware) GetTracer() *tracing.Tracer { + return m.tracer.Load() +} + +// GetObservabilityPlugins filters and returns only observability plugins from a list of plugins. +// Uses Go type assertion to identify plugins implementing the ObservabilityPlugin interface. +func GetObservabilityPlugins(plugins []schemas.Plugin) []schemas.ObservabilityPlugin { + if len(plugins) == 0 { + return nil + } + + obsPlugins := make([]schemas.ObservabilityPlugin, 0) + for _, plugin := range plugins { + if obsPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok { + obsPlugins = append(obsPlugins, obsPlugin) + } + } + + return obsPlugins +} diff --git a/transports/bifrost-http/handlers/middlewares_test.go b/transports/bifrost-http/handlers/middlewares_test.go index 395beb2ba..c99d77656 100644 --- a/transports/bifrost-http/handlers/middlewares_test.go +++ b/transports/bifrost-http/handlers/middlewares_test.go @@ -305,7 +305,7 @@ func TestChainMiddlewares_SingleMiddleware(t *testing.T) { middlewareCalled := false handlerCalled := false - middleware := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { middlewareCalled = true next(ctx) @@ -332,21 +332,21 @@ func TestChainMiddlewares_MultipleMiddlewares(t *testing.T) { ctx := &fasthttp.RequestCtx{} executionOrder := []int{} - middleware1 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 1) next(ctx) } }) - middleware2 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 2) next(ctx) } }) - middleware3 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 3) next(ctx) @@ -378,7 +378,7 @@ func TestChainMiddlewares_MultipleMiddlewares(t *testing.T) { func TestChainMiddlewares_MiddlewareCanModifyContext(t *testing.T) { ctx := &fasthttp.RequestCtx{} - middleware := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { ctx.SetUserValue("test-key", "test-value") next(ctx) @@ -405,7 +405,7 @@ func TestChainMiddlewares_ShortCircuit(t *testing.T) { executionOrder := []int{} // First middleware - writes response and short-circuits by not calling next - middleware1 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 1) ctx.SetStatusCode(fasthttp.StatusUnauthorized) @@ -415,7 +415,7 @@ func TestChainMiddlewares_ShortCircuit(t *testing.T) { }) // Second middleware - should NOT execute when middleware1 short-circuits - middleware2 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 2) next(ctx) @@ -423,7 +423,7 @@ func TestChainMiddlewares_ShortCircuit(t *testing.T) { }) // Third middleware - should NOT execute when middleware1 short-circuits - middleware3 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 3) next(ctx) @@ -469,7 +469,7 @@ func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) { executionOrder := []int{} // First middleware - executes and calls next - middleware1 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 1) next(ctx) @@ -477,7 +477,7 @@ func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) { }) // Second middleware - writes response and short-circuits - middleware2 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 2) ctx.SetStatusCode(fasthttp.StatusUnauthorized) @@ -487,7 +487,7 @@ func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) { }) // Third middleware - should NOT execute - middleware3 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 3) next(ctx) diff --git a/transports/bifrost-http/handlers/plugins.go b/transports/bifrost-http/handlers/plugins.go index 45e988409..af76d3541 100644 --- a/transports/bifrost-http/handlers/plugins.go +++ b/transports/bifrost-http/handlers/plugins.go @@ -50,7 +50,7 @@ type UpdatePluginRequest struct { } // RegisterRoutes registers the routes for the PluginsHandler -func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/api/plugins", lib.ChainMiddlewares(h.getPlugins, middlewares...)) r.GET("/api/plugins/{name}", lib.ChainMiddlewares(h.getPlugin, middlewares...)) r.POST("/api/plugins", lib.ChainMiddlewares(h.createPlugin, middlewares...)) diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 0ebcfe957..637483a2b 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -77,7 +77,7 @@ type ErrorResponse struct { } // RegisterRoutes registers all provider management routes -func (h *ProviderHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *ProviderHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Provider CRUD operations r.GET("/api/providers", lib.ChainMiddlewares(h.listProviders, middlewares...)) r.GET("/api/providers/{provider}", lib.ChainMiddlewares(h.getProvider, middlewares...)) diff --git a/transports/bifrost-http/handlers/session.go b/transports/bifrost-http/handlers/session.go index 646c28fd3..e72ed52c6 100644 --- a/transports/bifrost-http/handlers/session.go +++ b/transports/bifrost-http/handlers/session.go @@ -8,6 +8,7 @@ import ( "github.com/fasthttp/router" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" @@ -28,7 +29,7 @@ func NewSessionHandler(configStore configstore.ConfigStore) *SessionHandler { } // RegisterRoutes registers the session-related routes -func (h *SessionHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *SessionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.POST("/api/session/login", lib.ChainMiddlewares(h.login, middlewares...)) r.POST("/api/session/logout", lib.ChainMiddlewares(h.logout, middlewares...)) r.GET("/api/session/is-auth-enabled", lib.ChainMiddlewares(h.isAuthEnabled, middlewares...)) diff --git a/transports/bifrost-http/handlers/ui.go b/transports/bifrost-http/handlers/ui.go index cd42ad7dc..e0872e249 100644 --- a/transports/bifrost-http/handlers/ui.go +++ b/transports/bifrost-http/handlers/ui.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -25,7 +26,7 @@ func NewUIHandler(uiContent embed.FS) *UIHandler { } // RegisterRoutes registers the UI routes with the provided router. -func (h *UIHandler) RegisterRoutes(router *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *UIHandler) RegisterRoutes(router *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { router.GET("/", lib.ChainMiddlewares(h.serveDashboard, middlewares...)) router.GET("/{filepath:*}", lib.ChainMiddlewares(h.serveDashboard, middlewares...)) } diff --git a/transports/bifrost-http/handlers/websocket.go b/transports/bifrost-http/handlers/websocket.go index eb4b05f5a..9fab07a00 100644 --- a/transports/bifrost-http/handlers/websocket.go +++ b/transports/bifrost-http/handlers/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/fasthttp/router" "github.com/fasthttp/websocket" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/transports/bifrost-http/lib" @@ -47,7 +48,7 @@ func NewWebSocketHandler(ctx context.Context, logManager logging.LogManager, all } // RegisterRoutes registers all WebSocket-related routes -func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/ws", lib.ChainMiddlewares(h.connectStream, middlewares...)) } diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 27cf1dca1..65cf3c0a9 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -69,7 +69,7 @@ import ( // ExtensionRouter defines the interface that all integration routers must implement // to register their routes with the main HTTP router. type ExtensionRouter interface { - RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) + RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) } // StreamingRequest interface for requests that support streaming @@ -322,7 +322,7 @@ func NewGenericRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, ro // RegisterRoutes registers all configured routes on the given fasthttp router. // This method implements the ExtensionRouter interface. -func (g *GenericRouter) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (g *GenericRouter) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { for _, route := range g.routes { // Validate route configuration at startup to fail fast method := strings.ToUpper(route.Method) @@ -510,14 +510,13 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // allows providers that check ctx.Done() to cancel early if needed. This is less critical than // streaming requests (where we actively detect write errors), but still provides a mechanism // for providers to respect cancellation. - requestCtx := *bifrostCtx - + var response interface{} var err error switch { case bifrostReq.ListModelsRequest != nil: - listModelsResponse, bifrostErr := g.client.ListModelsRequest(requestCtx, bifrostReq.ListModelsRequest) + listModelsResponse, bifrostErr := g.client.ListModelsRequest(ctx, bifrostReq.ListModelsRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -537,7 +536,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf response, err = config.ListModelsResponseConverter(bifrostCtx, listModelsResponse) case bifrostReq.TextCompletionRequest != nil: - textCompletionResponse, bifrostErr := g.client.TextCompletionRequest(requestCtx, bifrostReq.TextCompletionRequest) + textCompletionResponse, bifrostErr := g.client.TextCompletionRequest(ctx, bifrostReq.TextCompletionRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -560,7 +559,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.TextResponseConverter(bifrostCtx, textCompletionResponse) case bifrostReq.ChatRequest != nil: - chatResponse, bifrostErr := g.client.ChatCompletionRequest(requestCtx, bifrostReq.ChatRequest) + chatResponse, bifrostErr := g.client.ChatCompletionRequest(ctx, bifrostReq.ChatRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -583,7 +582,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.ChatResponseConverter(bifrostCtx, chatResponse) case bifrostReq.ResponsesRequest != nil: - responsesResponse, bifrostErr := g.client.ResponsesRequest(requestCtx, bifrostReq.ResponsesRequest) + responsesResponse, bifrostErr := g.client.ResponsesRequest(ctx, bifrostReq.ResponsesRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -606,7 +605,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.ResponsesResponseConverter(bifrostCtx, responsesResponse) case bifrostReq.EmbeddingRequest != nil: - embeddingResponse, bifrostErr := g.client.EmbeddingRequest(requestCtx, bifrostReq.EmbeddingRequest) + embeddingResponse, bifrostErr := g.client.EmbeddingRequest(ctx, bifrostReq.EmbeddingRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -629,7 +628,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.EmbeddingResponseConverter(bifrostCtx, embeddingResponse) case bifrostReq.SpeechRequest != nil: - speechResponse, bifrostErr := g.client.SpeechRequest(requestCtx, bifrostReq.SpeechRequest) + speechResponse, bifrostErr := g.client.SpeechRequest(ctx, bifrostReq.SpeechRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -663,7 +662,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf return } case bifrostReq.TranscriptionRequest != nil: - transcriptionResponse, bifrostErr := g.client.TranscriptionRequest(requestCtx, bifrostReq.TranscriptionRequest) + transcriptionResponse, bifrostErr := g.client.TranscriptionRequest(ctx, bifrostReq.TranscriptionRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -686,7 +685,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.TranscriptionResponseConverter(bifrostCtx, transcriptionResponse) case bifrostReq.CountTokensRequest != nil: - countTokensResponse, bifrostErr := g.client.CountTokensRequest(requestCtx, bifrostReq.CountTokensRequest) + countTokensResponse, bifrostErr := g.client.CountTokensRequest(ctx, bifrostReq.CountTokensRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -1135,9 +1134,23 @@ func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config // Bifrost handles cleanup internally for normal completion and errors, so we only cancel // upstream streams when write errors indicate the client has disconnected. func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, config RouteConfig, streamChan chan *schemas.BifrostStream, cancel context.CancelFunc) { + // Signal to tracing middleware that trace completion should be deferred + // The streaming callback will complete the trace after the stream ends + ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) + + // Get the trace completer function for use in the streaming callback + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + // Use streaming response writer ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { - defer w.Flush() + defer func() { + w.Flush() + // Complete the trace after streaming finishes + // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL + if traceCompleter != nil { + traceCompleter() + } + }() // Create encoder for AWS Event Stream if needed var eventStreamEncoder *eventstream.Encoder diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 3eaecd6e9..1e3fced4d 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -275,7 +275,7 @@ func (c *Config) initializeEncryption(configKey string) error { // - Case conversion for provider names (e.g., "OpenAI" -> "openai") // - In-memory storage for ultra-fast access during request processing // - Graceful handling of missing config files -func LoadConfig(ctx context.Context, configDirPath string, EnterpriseOverrides EnterpriseOverrides) (*Config, error) { +func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { // Initialize separate database connections for optimal performance at scale configFilePath := filepath.Join(configDirPath, "config.json") configDBPath := filepath.Join(configDirPath, "config.db") @@ -298,7 +298,7 @@ func LoadConfig(ctx context.Context, configDirPath string, EnterpriseOverrides E // If config file doesn't exist, we will directly use the config store (create one if it doesn't exist) if os.IsNotExist(err) { logger.Info("config file not found at path: %s, initializing with default values", absConfigFilePath) - return loadConfigFromDefaults(ctx, config, configDBPath, logsDBPath, EnterpriseOverrides) + return loadConfigFromDefaults(ctx, config, configDBPath, logsDBPath) } return nil, fmt.Errorf("failed to read config file: %w", err) } @@ -354,12 +354,12 @@ func LoadConfig(ctx context.Context, configDirPath string, EnterpriseOverrides E } // If config file exists, we will use it to bootstrap config tables logger.Info("loading configuration from: %s", absConfigFilePath) - return loadConfigFromFile(ctx, config, data, EnterpriseOverrides) + return loadConfigFromFile(ctx, config, data) } // loadConfigFromFile initializes configuration from a JSON config file. // It merges config file data with existing database config, with store taking priority. -func loadConfigFromFile(ctx context.Context, config *Config, data []byte, EnterpriseOverrides EnterpriseOverrides) (*Config, error) { +func loadConfigFromFile(ctx context.Context, config *Config, data []byte) (*Config, error) { var configData ConfigData if err := json.Unmarshal(data, &configData); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) @@ -400,7 +400,7 @@ func loadConfigFromFile(ctx context.Context, config *Config, data []byte, Enterp loadEnvKeysFromFile(ctx, config) // Initialize framework config and pricing manager - initFrameworkConfigFromFile(ctx, config, &configData, EnterpriseOverrides) + initFrameworkConfigFromFile(ctx, config, &configData) // Initialize encryption if err = initEncryptionFromFile(config, &configData); err != nil { @@ -1524,7 +1524,7 @@ func loadEnvKeysFromFile(ctx context.Context, config *Config) { } // initFrameworkConfigFromFile initializes framework config and pricing manager from file -func initFrameworkConfigFromFile(ctx context.Context, config *Config, configData *ConfigData, EnterpriseOverrides EnterpriseOverrides) { +func initFrameworkConfigFromFile(ctx context.Context, config *Config, configData *ConfigData) { pricingConfig := &modelcatalog.Config{} if config.ConfigStore != nil { frameworkConfig, err := config.ConfigStore.GetFrameworkConfig(ctx) @@ -1551,18 +1551,10 @@ func initFrameworkConfigFromFile(ctx context.Context, config *Config, configData var pricingManager *modelcatalog.ModelCatalog var err error - // Check if EnterpriseOverrides is provided, otherwise use default initialization - if EnterpriseOverrides != nil { - pricingManager, err = EnterpriseOverrides.LoadPricingManager(ctx, pricingConfig, config.ConfigStore) - if err != nil { - logger.Warn("failed to load pricing manager: %v", err) - } - } else { - // Use default modelcatalog initialization when no enterprise overrides are provided - pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) - if err != nil { - logger.Warn("failed to initialize pricing manager: %v", err) - } + // Use default modelcatalog initialization when no enterprise overrides are provided + pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) + if err != nil { + logger.Warn("failed to initialize pricing manager: %v", err) } config.PricingManager = pricingManager } @@ -1595,7 +1587,7 @@ func initEncryptionFromFile(config *Config, configData *ConfigData) error { // loadConfigFromDefaults initializes configuration when no config file exists. // It creates a default SQLite config store and loads/creates default configurations. -func loadConfigFromDefaults(ctx context.Context, config *Config, configDBPath, logsDBPath string, EnterpriseOverrides EnterpriseOverrides) (*Config, error) { +func loadConfigFromDefaults(ctx context.Context, config *Config, configDBPath, logsDBPath string) (*Config, error) { var err error // Initialize default config store @@ -1642,7 +1634,7 @@ func loadConfigFromDefaults(ctx context.Context, config *Config, configDBPath, l } // Initialize framework config and pricing manager - if err = initDefaultFrameworkConfig(ctx, config, EnterpriseOverrides); err != nil { + if err = initDefaultFrameworkConfig(ctx, config); err != nil { return nil, err } @@ -1890,7 +1882,7 @@ func loadDefaultEnvKeys(ctx context.Context, config *Config) error { } // initDefaultFrameworkConfig initializes framework configuration and pricing manager -func initDefaultFrameworkConfig(ctx context.Context, config *Config, EnterpriseOverrides EnterpriseOverrides) error { +func initDefaultFrameworkConfig(ctx context.Context, config *Config) error { frameworkConfig, err := config.ConfigStore.GetFrameworkConfig(ctx) if err != nil { logger.Warn("failed to get framework config from store: %v", err) @@ -1936,19 +1928,10 @@ func initDefaultFrameworkConfig(ctx context.Context, config *Config, EnterpriseO // Initialize pricing manager var pricingManager *modelcatalog.ModelCatalog - - // Check if EnterpriseOverrides is provided, otherwise use default initialization - if EnterpriseOverrides != nil { - pricingManager, err = EnterpriseOverrides.LoadPricingManager(ctx, pricingConfig, config.ConfigStore) - if err != nil { - logger.Warn("failed to initialize pricing manager: %v", err) - } - } else { - // Use default modelcatalog initialization when no enterprise overrides are provided - pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) - if err != nil { - logger.Warn("failed to initialize pricing manager: %v", err) - } + // Use default modelcatalog initialization when no enterprise overrides are provided + pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) + if err != nil { + logger.Warn("failed to initialize pricing manager: %v", err) } config.PricingManager = pricingManager return nil diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 36e7476ed..c323b57ce 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -6672,7 +6672,7 @@ func TestSQLite_Provider_NewProviderFromFile(t *testing.T) { // Load config - this should create the provider in the DB ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -6710,7 +6710,7 @@ func TestSQLite_Provider_HashMatch_DBPreserved(t *testing.T) { // First load - creates provider in DB ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6721,7 +6721,7 @@ func TestSQLite_Provider_HashMatch_DBPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - should preserve DB config - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6753,7 +6753,7 @@ func TestSQLite_Provider_HashMismatch_FileSync(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6771,7 +6771,7 @@ func TestSQLite_Provider_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load with modified config.json - should sync from file - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6805,7 +6805,7 @@ func TestSQLite_Provider_DBOnlyProvider_Preserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6837,7 +6837,7 @@ func TestSQLite_Provider_DBOnlyProvider_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (no Anthropic) - should preserve DB-added provider - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6870,7 +6870,7 @@ func TestSQLite_Provider_RoundTrip(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6890,7 +6890,7 @@ func TestSQLite_Provider_RoundTrip(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - should preserve DB changes since hash matches - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -6932,7 +6932,7 @@ func TestSQLite_Key_NewKeyFromFile(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -6967,7 +6967,7 @@ func TestSQLite_Key_HashMatch_DBKeyPreserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -6978,7 +6978,7 @@ func TestSQLite_Key_HashMatch_DBKeyPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7017,7 +7017,7 @@ func TestSQLite_Key_DashboardAddedKey_Preserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7039,7 +7039,7 @@ func TestSQLite_Key_DashboardAddedKey_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (still has only file-key) - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7084,7 +7084,7 @@ func TestSQLite_Key_KeyValueChange_Detected(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7109,7 +7109,7 @@ func TestSQLite_Key_KeyValueChange_Detected(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load with modified config - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7144,7 +7144,7 @@ func TestSQLite_Key_MultipleKeys_MergeLogic(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7171,7 +7171,7 @@ func TestSQLite_Key_MultipleKeys_MergeLogic(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (still has key-1 and key-2) - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7214,7 +7214,7 @@ func TestSQLite_VirtualKey_NewFromFile(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -7254,7 +7254,7 @@ func TestSQLite_VirtualKey_HashMatch_DBPreserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7265,7 +7265,7 @@ func TestSQLite_VirtualKey_HashMatch_DBPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7295,7 +7295,7 @@ func TestSQLite_VirtualKey_HashMismatch_FileSync(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7322,7 +7322,7 @@ func TestSQLite_VirtualKey_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load with modified config - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7357,7 +7357,7 @@ func TestSQLite_VirtualKey_DBOnlyVK_Preserved(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7379,7 +7379,7 @@ func TestSQLite_VirtualKey_DBOnlyVK_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json (only has vk-file) - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7432,7 +7432,7 @@ func TestSQLite_VirtualKey_WithProviderConfigs(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -7493,7 +7493,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigs(t *testing.T) { // First load - bootstrap path ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7524,7 +7524,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigs(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - merge path (this is where the bug is) - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7595,7 +7595,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigKeys(t *testing.T) { // First load - bootstrap path (creates provider with key in DB) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -7640,7 +7640,7 @@ func TestSQLite_VirtualKey_MergePath_WithProviderConfigKeys(t *testing.T) { // Second load - merge path // BEFORE FIX: This would fail because GORM tries to INSERT the key again // AFTER FIX: CreateVirtualKeyProviderConfig uses Append() to associate existing keys - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -7786,7 +7786,7 @@ func TestSQLite_VKProviderConfig_NewConfig(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -7858,7 +7858,7 @@ func TestSQLite_VKProviderConfig_KeyReference(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -8166,7 +8166,7 @@ func TestSQLite_FullLifecycle_InitialLoad(t *testing.T) { // Load config ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -8225,7 +8225,7 @@ func TestSQLite_FullLifecycle_SecondLoadNoChanges(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8238,7 +8238,7 @@ func TestSQLite_FullLifecycle_SecondLoadNoChanges(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with same config.json - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8280,7 +8280,7 @@ func TestSQLite_FullLifecycle_FileChange_Selective(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8313,7 +8313,7 @@ func TestSQLite_FullLifecycle_FileChange_Selective(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8369,7 +8369,7 @@ func TestSQLite_FullLifecycle_DashboardEdits_ThenFileUnchanged(t *testing.T) { // First load ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8407,7 +8407,7 @@ func TestSQLite_FullLifecycle_DashboardEdits_ThenFileUnchanged(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load with SAME config.json (unchanged) - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8626,7 +8626,7 @@ func TestSQLite_VirtualKey_WithMCPConfigs(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - creates VK - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8715,7 +8715,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8792,7 +8792,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - should trigger reconciliation - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -8894,7 +8894,7 @@ func TestSQLite_VirtualKey_DashboardProviderConfig_PreservedOnFileChange(t *test createConfigFile(t, tempDir, configData) // Step 2: First load - bootstrap path - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -8958,7 +8958,7 @@ func TestSQLite_VirtualKey_DashboardProviderConfig_PreservedOnFileChange(t *test createConfigFile(t, tempDir, configData2) // Step 5: Second load - merge path with hash mismatch - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9043,7 +9043,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_PreservedOnFileChange(t *testing.T createConfigFile(t, tempDir, configData) // Step 2: First load - bootstrap path - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9134,7 +9134,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_PreservedOnFileChange(t *testing.T createConfigFile(t, tempDir, configData2) // Step 5: Second load - merge path with hash mismatch - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9214,7 +9214,7 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9253,7 +9253,7 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - should add MCP configs - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9286,7 +9286,7 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { // Third load - mcpClient2 config should be PRESERVED (not deleted) // This protects dashboard-added configs from accidental deletion - config3, err := LoadConfig(ctx, tempDir, nil) + config3, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Third LoadConfig failed: %v", err) } @@ -9337,7 +9337,7 @@ func TestSQLite_VKMCPConfig_UpdateTools(t *testing.T) { createConfigFile(t, tempDir, configData) // First load - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9387,7 +9387,7 @@ func TestSQLite_VKMCPConfig_UpdateTools(t *testing.T) { createConfigFile(t, tempDir, configData2) // Second load - should update tools - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -9431,7 +9431,7 @@ func TestSQLite_VK_ProviderAndMCPConfigs_Combined(t *testing.T) { createConfigFile(t, tempDir, configData) // First load to set up DB - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -9469,7 +9469,7 @@ func TestSQLite_VK_ProviderAndMCPConfigs_Combined(t *testing.T) { createConfigFile(t, tempDir, configData2) // Load config - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10262,7 +10262,7 @@ func TestSQLite_Budget_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10305,7 +10305,7 @@ func TestSQLite_Budget_HashMatch_DBPreserved(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10316,7 +10316,7 @@ func TestSQLite_Budget_HashMatch_DBPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Second load - same config - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10344,7 +10344,7 @@ func TestSQLite_Budget_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10355,7 +10355,7 @@ func TestSQLite_Budget_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) // Second load - should sync from file - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10383,7 +10383,7 @@ func TestSQLite_Budget_DBOnly_Preserved(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10400,7 +10400,7 @@ func TestSQLite_Budget_DBOnly_Preserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Reload - dashboard budget should be preserved - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10534,7 +10534,7 @@ func TestSQLite_RateLimit_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10572,7 +10572,7 @@ func TestSQLite_RateLimit_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10584,7 +10584,7 @@ func TestSQLite_RateLimit_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) // Second load - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10677,7 +10677,7 @@ func TestSQLite_Customer_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10711,7 +10711,7 @@ func TestSQLite_Customer_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10721,7 +10721,7 @@ func TestSQLite_Customer_HashMismatch_FileSync(t *testing.T) { configData.Governance.Customers[0].Name = "Updated Customer" createConfigFile(t, tempDir, configData) - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -10852,7 +10852,7 @@ func TestSQLite_Team_NewFromFile(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config, err := LoadConfig(ctx, tempDir, nil) + config, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("LoadConfig failed: %v", err) } @@ -10887,7 +10887,7 @@ func TestSQLite_Team_HashMismatch_FileSync(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -10897,7 +10897,7 @@ func TestSQLite_Team_HashMismatch_FileSync(t *testing.T) { configData.Governance.Teams[0].Name = "Updated Team" createConfigFile(t, tempDir, configData) - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -11276,7 +11276,7 @@ func TestSQLite_Governance_FullReconciliation(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -11306,7 +11306,7 @@ func TestSQLite_Governance_FullReconciliation(t *testing.T) { createConfigFile(t, tempDir, configData) // Reload and verify all entities are updated - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } @@ -11341,7 +11341,7 @@ func TestSQLite_Governance_DBOnly_AllPreserved(t *testing.T) { createConfigFile(t, tempDir, configData) ctx := context.Background() - config1, err := LoadConfig(ctx, tempDir, nil) + config1, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("First LoadConfig failed: %v", err) } @@ -11377,7 +11377,7 @@ func TestSQLite_Governance_DBOnly_AllPreserved(t *testing.T) { config1.ConfigStore.Close(ctx) // Reload - all dashboard entities should be preserved - config2, err := LoadConfig(ctx, tempDir, nil) + config2, err := LoadConfig(ctx, tempDir) if err != nil { t.Fatalf("Second LoadConfig failed: %v", err) } diff --git a/transports/bifrost-http/lib/middleware.go b/transports/bifrost-http/lib/middleware.go index c1657c6aa..6ff034644 100644 --- a/transports/bifrost-http/lib/middleware.go +++ b/transports/bifrost-http/lib/middleware.go @@ -1,15 +1,14 @@ package lib -import "github.com/valyala/fasthttp" - -// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport -// It follows the standard pattern: receives the next handler and returns a new handler -type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler +import ( + "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) // ChainMiddlewares chains multiple middlewares together // Middlewares are applied in order: the first middleware wraps the second, etc. // This allows earlier middlewares to short-circuit by not calling next(ctx) -func ChainMiddlewares(handler fasthttp.RequestHandler, middlewares ...BifrostHTTPMiddleware) fasthttp.RequestHandler { +func ChainMiddlewares(handler fasthttp.RequestHandler, middlewares ...schemas.BifrostHTTPMiddleware) fasthttp.RequestHandler { // If no middlewares, return the original handler if len(middlewares) == 0 { return handler diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 014ac23ca..118d182cd 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -26,6 +26,7 @@ import ( "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/framework/modelcatalog" dynamicPlugins "github.com/maximhq/bifrost/framework/plugins" + "github.com/maximhq/bifrost/framework/tracing" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" @@ -109,6 +110,7 @@ type BifrostHTTPServer struct { WebSocketHandler *handlers.WebSocketHandler LogsCleaner *logstore.LogsCleaner MCPServerHandler *handlers.MCPServerHandler + devPprofHandler *handlers.DevPprofHandler } var logger schemas.Logger @@ -211,7 +213,7 @@ func (s *GovernanceInMemoryStore) GetConfiguredProviders() map[schemas.ModelProv } // LoadPlugin loads a plugin by name and returns it as type T. -func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string, pluginConfig any, bifrostConfig *lib.Config, EnterpriseOverrides lib.EnterpriseOverrides) (T, error) { +func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string, pluginConfig any, bifrostConfig *lib.Config) (T, error) { var zero T if path != nil { logger.Info("loading dynamic plugin %s from path %s", name, *path) @@ -262,7 +264,7 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string return p, nil } return zero, fmt.Errorf("logging plugin type mismatch") - case EnterpriseOverrides.GetGovernancePluginName(): + case governance.PluginName: governanceConfig, err := MarshalPluginConfig[governance.Config](pluginConfig) if err != nil { return zero, fmt.Errorf("failed to marshal governance plugin config: %v", err) @@ -323,12 +325,12 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string } // LoadPlugins loads the plugins for the server. -func LoadPlugins(ctx context.Context, config *lib.Config, EnterpriseOverrides lib.EnterpriseOverrides) ([]schemas.Plugin, []schemas.PluginStatus, error) { +func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, []schemas.PluginStatus, error) { var err error pluginStatus := []schemas.PluginStatus{} plugins := []schemas.Plugin{} // Initialize telemetry plugin - promPlugin, err := LoadPlugin[*telemetry.PrometheusPlugin](ctx, telemetry.PluginName, nil, nil, config, EnterpriseOverrides) + promPlugin, err := LoadPlugin[*telemetry.PrometheusPlugin](ctx, telemetry.PluginName, nil, nil, config) if err != nil { logger.Error("failed to initialize telemetry plugin: %v", err) pluginStatus = append(pluginStatus, schemas.PluginStatus{ @@ -350,7 +352,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config, EnterpriseOverrides li // Use dedicated logs database with high-scale optimizations loggingPlugin, err = LoadPlugin[*logging.LoggerPlugin](ctx, logging.PluginName, nil, &logging.Config{ DisableContentLogging: &config.ClientConfig.DisableContentLogging, - }, config, EnterpriseOverrides) + }, config) if err != nil { logger.Error("failed to initialize logging plugin: %v", err) pluginStatus = append(pluginStatus, schemas.PluginStatus{ @@ -376,32 +378,34 @@ func LoadPlugins(ctx context.Context, config *lib.Config, EnterpriseOverrides li // Initializing governance plugin if config.ClientConfig.EnableGovernance { // Initialize governance plugin - governancePlugin, err := EnterpriseOverrides.LoadGovernancePlugin(ctx, config) + governancePlugin, err := LoadPlugin[*governance.GovernancePlugin](ctx, governance.PluginName, nil, &governance.Config{ + IsVkMandatory: &config.ClientConfig.EnforceGovernanceHeader, + }, config) if err != nil { logger.Error("failed to initialize governance plugin: %s", err.Error()) pluginStatus = append(pluginStatus, schemas.PluginStatus{ - Name: EnterpriseOverrides.GetGovernancePluginName(), + Name: governance.PluginName, Status: schemas.PluginStatusError, Logs: []string{fmt.Sprintf("error initializing governance plugin %v", err)}, }) } else if governancePlugin != nil { plugins = append(plugins, governancePlugin) pluginStatus = append(pluginStatus, schemas.PluginStatus{ - Name: EnterpriseOverrides.GetGovernancePluginName(), + Name: governance.PluginName, Status: schemas.PluginStatusActive, Logs: []string{"governance plugin initialized successfully"}, }) } } else { pluginStatus = append(pluginStatus, schemas.PluginStatus{ - Name: EnterpriseOverrides.GetGovernancePluginName(), + Name: governance.PluginName, Status: schemas.PluginStatusDisabled, Logs: []string{"governance plugin disabled"}, }) } for _, plugin := range config.PluginConfigs { // Skip built-in plugins that are already handled above - if plugin.Name == telemetry.PluginName || plugin.Name == logging.PluginName || plugin.Name == EnterpriseOverrides.GetGovernancePluginName() { + if plugin.Name == telemetry.PluginName || plugin.Name == logging.PluginName || plugin.Name == governance.PluginName { continue } if !plugin.Enabled { @@ -412,7 +416,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config, EnterpriseOverrides li }) continue } - pluginInstance, err := LoadPlugin[schemas.Plugin](ctx, plugin.Name, plugin.Path, plugin.Config, config, EnterpriseOverrides) + pluginInstance, err := LoadPlugin[schemas.Plugin](ctx, plugin.Name, plugin.Path, plugin.Config, config) if err != nil { if slices.Contains(enterprisePlugins, plugin.Name) { continue @@ -507,7 +511,7 @@ func (s *BifrostHTTPServer) GetAvailableMCPTools(ctx context.Context) []schemas. // and returns the BaseGovernancePlugin implementation or an error. func (s *BifrostHTTPServer) getGovernancePlugin() (governance.BaseGovernancePlugin, error) { s.PluginsMutex.RLock() - plugin, err := FindPluginByName[schemas.Plugin](s.Plugins, s.GetGovernancePluginName()) + plugin, err := FindPluginByName[schemas.Plugin](s.Plugins, governance.PluginName) s.PluginsMutex.RUnlock() if err != nil { return nil, err @@ -654,7 +658,7 @@ func (s *BifrostHTTPServer) RemoveCustomer(ctx context.Context, id string) error // GetGovernanceData returns the governance data func (s *BifrostHTTPServer) GetGovernanceData() *governance.GovernanceData { s.PluginsMutex.RLock() - governancePlugin, err := FindPluginByName[schemas.Plugin](s.Plugins, s.GetGovernancePluginName()) + governancePlugin, err := FindPluginByName[schemas.Plugin](s.Plugins, governance.PluginName) s.PluginsMutex.RUnlock() if err != nil { return nil @@ -822,7 +826,7 @@ func (s *BifrostHTTPServer) SyncLoadedPlugin(ctx context.Context, configName str // Uses atomic CompareAndSwap with retry loop to handle concurrent updates safely. func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error { logger.Debug("reloading plugin %s", name) - newPlugin, err := LoadPlugin[schemas.Plugin](ctx, name, path, pluginConfig, s.Config, s) + newPlugin, err := LoadPlugin[schemas.Plugin](ctx, name, path, pluginConfig, s.Config) if err != nil { s.UpdatePluginStatus(name, schemas.PluginStatusError, []string{fmt.Sprintf("error loading plugin %s: %v", name, err)}) return err @@ -955,7 +959,7 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, name string) error } // RegisterInferenceRoutes initializes the routes for the inference handler -func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlewares ...lib.BifrostHTTPMiddleware) error { +func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlewares ...schemas.BifrostHTTPMiddleware) error { inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config) integrationHandler := handlers.NewIntegrationHandler(s.Client, s.Config) @@ -965,7 +969,7 @@ func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlew } // RegisterAPIRoutes initializes the routes for the Bifrost HTTP server. -func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks ServerCallbacks, EnterpriseOverrides lib.EnterpriseOverrides, middlewares ...lib.BifrostHTTPMiddleware) error { +func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks ServerCallbacks, middlewares ...schemas.BifrostHTTPMiddleware) error { var err error // Initializing plugin specific handlers var loggingHandler *handlers.LoggingHandler @@ -974,7 +978,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser loggingHandler = handlers.NewLoggingHandler(loggerPlugin.GetPluginLogManager(), s) } var governanceHandler *handlers.GovernanceHandler - governancePlugin, _ := FindPluginByName[schemas.Plugin](s.Plugins, EnterpriseOverrides.GetGovernancePluginName()) + governancePlugin, _ := FindPluginByName[schemas.Plugin](s.Plugins, governance.PluginName) if governancePlugin != nil { governanceHandler, err = handlers.NewGovernanceHandler(callbacks, s.Config.ConfigStore) if err != nil { @@ -1034,6 +1038,12 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser if s.WebSocketHandler != nil { s.WebSocketHandler.RegisterRoutes(s.Router, middlewares...) } + // Register dev pprof handler only in dev mode + if handlers.IsDevMode() { + logger.Info("dev mode enabled, registering pprof endpoints") + s.devPprofHandler = handlers.NewDevPprofHandler() + s.devPprofHandler.RegisterRoutes(s.Router, middlewares...) + } // Add Prometheus /metrics endpoint prometheusPlugin, err := FindPluginByName[*telemetry.PrometheusPlugin](s.Plugins, telemetry.PluginName) if err == nil && prometheusPlugin.GetRegistry() != nil { @@ -1051,7 +1061,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser } // RegisterUIRoutes registers the UI handler with the specified router -func (s *BifrostHTTPServer) RegisterUIRoutes(middlewares ...lib.BifrostHTTPMiddleware) { +func (s *BifrostHTTPServer) RegisterUIRoutes(middlewares ...schemas.BifrostHTTPMiddleware) { // WARNING: This UI handler needs to be registered after all the other handlers handlers.NewUIHandler(s.UIContent).RegisterRoutes(s.Router, middlewares...) } @@ -1082,26 +1092,6 @@ func (s *BifrostHTTPServer) GetAllRedactedVirtualKeys(ctx context.Context, ids [ return virtualKeys } -func (s *BifrostHTTPServer) GetGovernancePluginName() string { - if s.OSSToEnterprisePluginNameOverrides != nil { - if name, ok := s.OSSToEnterprisePluginNameOverrides[governance.PluginName]; ok && name != "" { - return name - } - } - return governance.PluginName -} - -func (s *BifrostHTTPServer) LoadGovernancePlugin(ctx context.Context, config *lib.Config) (schemas.Plugin, error) { - governancePlugin, err := LoadPlugin[*governance.GovernancePlugin](ctx, governance.PluginName, nil, &governance.Config{ - IsVkMandatory: &config.ClientConfig.EnforceGovernanceHeader, - }, config, s) - - if err != nil { - return nil, fmt.Errorf("failed to initialize governance plugin: %v", err) - } - return governancePlugin, nil -} - func (s *BifrostHTTPServer) LoadPricingManager(ctx context.Context, pricingConfig *modelcatalog.Config, configStore configstore.ConfigStore) (*modelcatalog.ModelCatalog, error) { pricingManager, err := modelcatalog.Init(ctx, pricingConfig, configStore, nil, logger) if err != nil { @@ -1111,8 +1101,8 @@ func (s *BifrostHTTPServer) LoadPricingManager(ctx context.Context, pricingConfi } // PrepareCommonMiddlewares gets the common middlewares for the Bifrost HTTP server -func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []lib.BifrostHTTPMiddleware { - commonMiddlewares := []lib.BifrostHTTPMiddleware{} +func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []schemas.BifrostHTTPMiddleware { + commonMiddlewares := []schemas.BifrostHTTPMiddleware{} // Preparing middlewares // Initializing prometheus plugin prometheusPlugin, err := FindPluginByName[*telemetry.PrometheusPlugin](s.Plugins, telemetry.PluginName) @@ -1147,7 +1137,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { return fmt.Errorf("failed to create app directory %s: %v", configDir, err) } // Initialize high-performance configuration store with dedicated database - s.Config, err = lib.LoadConfig(ctx, configDir, s) + s.Config, err = lib.LoadConfig(ctx, configDir) if err != nil { return fmt.Errorf("failed to load config %v", err) } @@ -1186,7 +1176,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { // Load plugins s.pluginStatusMutex.Lock() defer s.pluginStatusMutex.Unlock() - s.Plugins, s.pluginStatus, err = LoadPlugins(ctx, s.Config, s) + s.Plugins, s.pluginStatus, err = LoadPlugins(ctx, s.Config) if err != nil { return fmt.Errorf("failed to load plugins %v", err) } @@ -1246,7 +1236,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { apiMiddlewares = append(apiMiddlewares, handlers.AuthMiddleware(s.Config.ConfigStore)) } // Register routes - err = s.RegisterAPIRoutes(s.ctx, s, s, apiMiddlewares...) + err = s.RegisterAPIRoutes(s.ctx, s, apiMiddlewares...) if err != nil { return fmt.Errorf("failed to initialize routes: %v", err) } @@ -1255,7 +1245,31 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { inferenceMiddlewares = append(inferenceMiddlewares, handlers.AuthMiddleware(s.Config.ConfigStore)) } // Registering inference middlewares - inferenceMiddlewares = append([]lib.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config, s)}, inferenceMiddlewares...) + inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) + // Curating observability plugins + observabilityPlugins := []schemas.ObservabilityPlugin{} + for _, plugin := range s.Plugins { + if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok { + observabilityPlugins = append(observabilityPlugins, observabilityPlugin) + } + } + // Check if logging plugin is enabled + loggingPluginEnabled := false + if _, err := FindPluginByName[*logging.LoggerPlugin](s.Plugins, logging.PluginName); err == nil { + loggingPluginEnabled = true + } + // Initialize tracer when observability plugins OR logging plugin is enabled + // This enables the central streaming accumulator for both use cases + if len(observabilityPlugins) > 0 || loggingPluginEnabled { + // Initializing tracer with embedded streaming accumulator + traceStore := tracing.NewTraceStore(60*time.Minute, logger) + tracer := tracing.NewTracer(traceStore, s.Config.PricingManager, logger) + s.Client.SetTracer(tracer) + // Always add tracing middleware when tracer is enabled - it creates traces and sets traceID in context + // The observability plugins are optional (can be empty if only logging is enabled) + tracingMiddleware := handlers.NewTracingMiddleware(tracer, observabilityPlugins) + inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{tracingMiddleware.Middleware()}, inferenceMiddlewares...) + } err = s.RegisterInferenceRoutes(s.ctx, inferenceMiddlewares...) if err != nil { return fmt.Errorf("failed to initialize inference routes: %v", err) @@ -1274,6 +1288,12 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { // Start starts the HTTP server at the specified host and port // Also watches signals and errors func (s *BifrostHTTPServer) Start() error { + // Printing plugin status in a table + s.pluginStatusMutex.RLock() + for _, pluginStatus := range s.pluginStatus { + logger.Info("plugin status: %s - %s", pluginStatus.Name, pluginStatus.Status) + } + s.pluginStatusMutex.RUnlock() // Create channels for signal and error handling sigChan := make(chan os.Signal, 1) errChan := make(chan error, 1) @@ -1327,6 +1347,10 @@ func (s *BifrostHTTPServer) Start() error { logger.Info("stopping log retention cleaner...") s.LogsCleaner.StopCleanupRoutine() } + if s.devPprofHandler != nil { + logger.Info("stopping dev pprof handler...") + s.devPprofHandler.Cleanup() + } if s.Config != nil && s.Config.LogsStore != nil { s.Config.LogsStore.Close(shutdownCtx) } diff --git a/transports/changelog.md b/transports/changelog.md index 85c47faef..d7224bb2b 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1,4 +1,30 @@ -- feat: added code mode to mcp +- refactor: governance plugin refactored for extensibility and optimization +- feat: new MCP gateway (server including) along with code mode - feat: added health monitoring to mcp - feat: added responses format tool execution support to mcp -- refactor: governance plugin refactored for extensibility and optimization \ No newline at end of file +- feat: new e2e tracing +- fix: gemini thought signature handling in multi-turn conversations + +### BREAKING CHANGES + +- **Plugin Interface: TransportInterceptor removed, replaced with HTTPTransportMiddleware** + + The `TransportInterceptor` function has been removed from the plugin interface. Plugins using HTTP transport interception must migrate to `HTTPTransportMiddleware`. + + **Migration summary:** + ``` + // v1.3.x (removed) + TransportInterceptor(ctx *BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + + // v1.4.x+ (new) + HTTPTransportMiddleware() BifrostHTTPMiddleware + // where BifrostHTTPMiddleware = func(next fasthttp.RequestHandler) fasthttp.RequestHandler + ``` + + **Key API changes:** + - Function renamed: `TransportInterceptor` -> `HTTPTransportMiddleware` + - Signature changed: Now returns a middleware wrapper instead of accepting/returning header/body maps + - Added dependency: Requires `github.com/valyala/fasthttp` import + - Flow control: Must explicitly call `next(ctx)` to continue the chain + + See [Plugin Migration Guide](/docs/plugins/migration-guide) for complete migration instructions and code examples. \ No newline at end of file diff --git a/transports/go.mod b/transports/go.mod index 343927274..1e051c61d 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -7,9 +7,10 @@ require ( github.com/bytedance/sonic v1.14.2 github.com/fasthttp/router v1.5.4 github.com/fasthttp/websocket v1.5.12 + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f github.com/google/uuid v1.6.0 github.com/mark3labs/mcp-go v0.43.2 - github.com/maximhq/bifrost/core v1.2.42 + github.com/maximhq/bifrost/core v1.2.43 github.com/maximhq/bifrost/framework v1.1.50 github.com/maximhq/bifrost/plugins/governance v1.3.51 github.com/maximhq/bifrost/plugins/logging v1.3.51 diff --git a/transports/go.sum b/transports/go.sum index c9f0de5aa..c3692caf7 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -133,6 +133,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= @@ -178,8 +180,7 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.42 h1:0G5TD4sZWlT8CwteobFpXmnALGzFQ6lsrDAQl8tr7/k= -github.com/maximhq/bifrost/core v1.2.42/go.mod h1:1msCedjIgC8d9TNJyB1z7s+348vh2Bd0u66qAPgpZoA= +github.com/maximhq/bifrost/core v1.2.43 h1:NxtvzvLL0Isaf8mD1dlGb9JxT7/PZNPv5NBTnqHG100= github.com/maximhq/bifrost/framework v1.1.50 h1:dy4awTo1b51hXll8Uk2HdZTZESXGJJCYcsBUOmzTLqU= github.com/maximhq/bifrost/framework v1.1.50/go.mod h1:opfHztmWhGkg0QSQGvZsU1wXEpMYLpBl01J6LsKECng= github.com/maximhq/bifrost/plugins/governance v1.3.51 h1:ng1F+N9nKD95qvrpld+aQVjn0ZXzT2RvDOOKfKUmVKk= @@ -188,8 +189,7 @@ github.com/maximhq/bifrost/plugins/logging v1.3.51 h1:WC7E+xB54aBp1yHiw1ZhomqPn2 github.com/maximhq/bifrost/plugins/logging v1.3.51/go.mod h1:CaPFNbzK+1HEVERtxEY+BnOYRbkov3Gv+YDtrSL4HWQ= github.com/maximhq/bifrost/plugins/maxim v1.4.51 h1:B3E+TvV4dFjVToH2HQrsMw3tMzCX2+DFvpMjfeeRm8M= github.com/maximhq/bifrost/plugins/maxim v1.4.51/go.mod h1:JtPxw49OzA8XqEmSlfuRc86cdE/+QlebrQ3ps7izozU= -github.com/maximhq/bifrost/plugins/mocker v1.3.52 h1:DgujQBv7IilNEC3K70tu/7unaE89CzNpadwaLoFiCRE= -github.com/maximhq/bifrost/plugins/mocker v1.3.52/go.mod h1:wQlAj8+dbiPB4cQHdsQSL4tuVEVN4Iwz5mw0oPY7E/M= +github.com/maximhq/bifrost/plugins/mocker v1.3.53 h1:5OgTB878Q4dKK+v1DSZcw331qtiJxr/muDLdqJfY8T4= github.com/maximhq/bifrost/plugins/otel v1.0.50 h1:kd2mD4sNbQVZAd2lCaYHRo7WMlH5Heswuw+jUyM+dos= github.com/maximhq/bifrost/plugins/otel v1.0.50/go.mod h1:TR8nrgglI3FZXMM+aedOGGNP1dwS/BaWce6rr0S9OKo= github.com/maximhq/bifrost/plugins/semanticcache v1.3.50 h1:szdnvWtFXZ2e9M9ODts5KRCc5DsOhzIkp81o2JBJtig= diff --git a/transports/version b/transports/version index ff81b39cf..7cd1b30a2 100644 --- a/transports/version +++ b/transports/version @@ -1 +1 @@ -1.3.54 \ No newline at end of file +1.4.0-prerelease1 \ No newline at end of file diff --git a/ui/app/clientLayout.tsx b/ui/app/clientLayout.tsx index 363e16d65..eb26338d1 100644 --- a/ui/app/clientLayout.tsx +++ b/ui/app/clientLayout.tsx @@ -1,5 +1,6 @@ "use client"; +import { DevProfiler } from "@/components/devProfiler"; import FullPageLoader from "@/components/fullPageLoader"; import NotAvailableBanner from "@/components/notAvailableBanner"; import ProgressProvider from "@/components/progressBar"; @@ -12,7 +13,7 @@ import { BifrostConfig } from "@/lib/types/config"; import { RbacProvider } from "@enterprise/lib/contexts/rbacContext"; import { usePathname } from "next/navigation"; import { NuqsAdapter } from "nuqs/adapters/next/app"; -import { Suspense, useEffect } from "react"; +import { useEffect } from "react"; import { toast, Toaster } from "sonner"; function AppContent({ children }: { children: React.ReactNode }) { @@ -28,7 +29,7 @@ function AppContent({ children }: { children: React.ReactNode }) { -
+
{isLoading ? : {children}}
@@ -58,6 +59,7 @@ export function ClientLayout({ children }: { children: React.ReactNode }) { {children} + diff --git a/ui/app/workspace/logs/views/logDetailsSheet.tsx b/ui/app/workspace/logs/views/logDetailsSheet.tsx index e8efb07ca..d3204263b 100644 --- a/ui/app/workspace/logs/views/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/views/logDetailsSheet.tsx @@ -13,13 +13,13 @@ import { } from "@/components/ui/alertDialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; +import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuSeparator, DropdownMenuTrigger } from "@/components/ui/dropdownMenu"; import { DottedSeparator } from "@/components/ui/separator"; import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; import { RequestTypeColors, RequestTypeLabels, Status, StatusColors } from "@/lib/constants/logs"; import { LogEntry } from "@/lib/types/logs"; -import { Clipboard, DollarSign, FileText, Timer, Trash2 } from "lucide-react"; +import { Clipboard, DollarSign, FileText, MoreVertical, Timer, Trash2 } from "lucide-react"; import moment from "moment"; import { toast } from "sonner"; import { CodeEditor } from "./codeEditor"; @@ -168,7 +168,7 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet return ( - +
@@ -178,44 +178,45 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet
-
- - - - - - -

Copy request body JSON

-
-
-
- - - - - - - Are you sure you want to delete this log? - This action cannot be undone. This will permanently delete the log entry. - - - Cancel - { - handleDelete(log); - onOpenChange(false); - }} - > - Delete - - - - -
+ + + + + Copy request body + + + + + + Delete log + + + + + + + Are you sure you want to delete this log? + This action cannot be undone. This will permanently delete the log entry. + + + Cancel + { + handleDelete(log); + onOpenChange(false); + }} + > + Delete + + + +
diff --git a/ui/components/devProfiler.tsx b/ui/components/devProfiler.tsx new file mode 100644 index 000000000..e0c53fcab --- /dev/null +++ b/ui/components/devProfiler.tsx @@ -0,0 +1,422 @@ +'use client' + +import { useGetDevPprofQuery } from '@/lib/store' +import { isDevelopmentMode } from '@/lib/utils/port' +import { Activity, ChevronDown, ChevronUp, Cpu, HardDrive, X } from 'lucide-react' +import React, { useCallback, useMemo, useState } from 'react' +import { + Area, + AreaChart, + CartesianGrid, + ResponsiveContainer, + Tooltip, + XAxis, + YAxis, +} from 'recharts' + +// Format bytes to human-readable string +function formatBytes (bytes: number): string { + if (bytes === 0) return '0 B' + const k = 1024 + const sizes = ['B', 'KB', 'MB', 'GB'] + const i = Math.floor(Math.log(bytes) / Math.log(k)) + return `${(bytes / Math.pow(k, i)).toFixed(1)} ${sizes[i]}` +} + +// Format nanoseconds to human-readable string +function formatNs (ns: number): string { + if (ns < 1000) return `${ns}ns` + if (ns < 1000000) return `${(ns / 1000).toFixed(1)}µs` + if (ns < 1000000000) return `${(ns / 1000000).toFixed(1)}ms` + return `${(ns / 1000000000).toFixed(2)}s` +} + +// Format timestamp to HH:MM:SS +function formatTime (timestamp: string): string { + const date = new Date(timestamp) + return date.toLocaleTimeString('en-US', { + hour12: false, + hour: '2-digit', + minute: '2-digit', + second: '2-digit', + }) +} + +// Truncate function name for display +function truncateFunction (fn: string): string { + const parts = fn.split('/') + const last = parts[parts.length - 1] + if (last.length > 40) { + return '...' + last.slice(-37) + } + return last +} + +export function DevProfiler (): React.ReactNode { + const [isVisible, setIsVisible] = useState(true) + const [isExpanded, setIsExpanded] = useState(true) + const [isDismissed, setIsDismissed] = useState(false) + + // Only fetch in development mode and when not dismissed + const shouldFetch = isDevelopmentMode() && !isDismissed + + const { data, isLoading, error } = useGetDevPprofQuery(undefined, { + pollingInterval: shouldFetch ? 10000 : 0, // Poll every 10 seconds + skip: !shouldFetch, + }) + + // Memoize chart data transformation + const memoryChartData = useMemo(() => { + if (!data?.history) return [] + return data.history.map((point) => ({ + time: formatTime(point.timestamp), + alloc: point.alloc / (1024 * 1024), // Convert to MB + heapInuse: point.heap_inuse / (1024 * 1024), + })) + }, [data?.history]) + + const cpuChartData = useMemo(() => { + if (!data?.history) return [] + return data.history.map((point) => ({ + time: formatTime(point.timestamp), + cpuPercent: point.cpu_percent, + goroutines: point.goroutines, + })) + }, [data?.history]) + + const handleDismiss = useCallback(() => { + setIsDismissed(true) + }, []) + + const handleToggleExpand = useCallback(() => { + setIsExpanded((prev) => !prev) + }, []) + + const handleToggleVisible = useCallback(() => { + setIsVisible((prev) => !prev) + }, []) + + // Don't render in production mode or if dismissed + if (!isDevelopmentMode() || isDismissed) { + return null + } + + // Minimized state - just show a small button + if (!isVisible) { + return ( + + ) + } + + return ( +
+ {/* Header */} +
+
+ Dev Profiler + {isLoading && ( + + )} +
+
+ + + +
+
+ + {Boolean(error) && ( +
+ Failed to load profiling data +
+ )} + + {isExpanded && data && ( +
+ {/* Current Stats */} +
+
+ CPU Usage + + {data.cpu.usage_percent.toFixed(1)}% + +
+
+ Heap Alloc + + {formatBytes(data.memory.alloc)} + +
+
+ Heap In-Use + + {formatBytes(data.memory.heap_inuse)} + +
+
+ System + + {formatBytes(data.memory.sys)} + +
+
+ Goroutines + + {data.runtime.num_goroutine} + +
+
+ GC Pause + + {formatNs(data.runtime.gc_pause_ns)} + +
+
+ + {/* CPU Chart */} +
+
+ + CPU Usage (last 5 min) +
+
+ + + + + + + + + + + + + + + `${Number(v).toFixed(0)}%`} + width={35} + domain={[0, 'auto']} + /> + + + + + + +
+
+ + + CPU % + + + + Goroutines + +
+
+ + {/* Memory Chart */} +
+
+ + Memory (last 5 min) +
+
+ + + + + + + + + + + + + + + `${Number(v).toFixed(0)}MB`} + width={45} + /> + + + + + +
+
+ + + Alloc + + + + Heap In-Use + +
+
+ + {/* Top Allocations */} +
+
+ + Top Allocations +
+
+ {(data.top_allocations ?? []).map((alloc, i) => ( +
+
+ + {truncateFunction(alloc.function)} + + + {alloc.file}:{alloc.line} + +
+
+ + {formatBytes(alloc.bytes)} + + + {alloc.count.toLocaleString()} allocs + +
+
+ ))} +
+
+ + {/* Footer with info */} +
+ CPUs: {data.runtime.num_cpu} | GOMAXPROCS: {data.runtime.gomaxprocs} | + GC: {data.runtime.num_gc} | Objects: {data.memory.heap_objects.toLocaleString()} +
+
+ )} + + {/* Collapsed state */} + {!isExpanded && data && ( +
+ + CPU: {data.cpu.usage_percent.toFixed(1)}% + + + Heap: {formatBytes(data.memory.heap_inuse)} + + + Goroutines: {data.runtime.num_goroutine} + +
+ )} +
+ ) +} diff --git a/ui/components/sidebar.tsx b/ui/components/sidebar.tsx index af8717413..367f2f7e1 100644 --- a/ui/components/sidebar.tsx +++ b/ui/components/sidebar.tsx @@ -205,7 +205,7 @@ const SidebarItemView = ({ return ( , SwitchProps>( - ({ className, size = "default", ...props }, ref) => ( + ({ className, size = "md", ...props }, ref) => ( , > ({ + // Get dev pprof data - polls every 10 seconds + getDevPprof: builder.query({ + query: () => ({ + url: '/dev/pprof', + }), + }), + }), +}) + +export const { + useGetDevPprofQuery, + useLazyGetDevPprofQuery, +} = devApi + diff --git a/ui/lib/store/apis/index.ts b/ui/lib/store/apis/index.ts index 99fc3eb56..3946f3318 100644 --- a/ui/lib/store/apis/index.ts +++ b/ui/lib/store/apis/index.ts @@ -3,6 +3,7 @@ export { baseApi, clearAuthStorage, getErrorMessage, setAuthToken } from "./base // API slices and hooks export * from "./configApi"; +export * from "./devApi"; export * from "./governanceApi"; export * from "./logsApi"; export * from "./mcpApi";