Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/changelog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
feat: send back raw request in extra fields
feat: added support for reasoning in chat completions
feat: enhanced reasoning support in responses api
enhancement: improved internal inter provider conversions for integrations
43 changes: 38 additions & 5 deletions core/providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type AnthropicProvider struct {
client *fasthttp.Client // HTTP client for API requests
apiVersion string // API version for the provider
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
customProviderConfig *schemas.CustomProviderConfig // Custom provider config
}
Expand Down Expand Up @@ -103,6 +104,7 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger)
client: client,
apiVersion: "2023-06-01",
networkConfig: config.NetworkConfig,
sendBackRawRequest: config.SendBackRawRequest,
sendBackRawResponse: config.SendBackRawResponse,
customProviderConfig: config.CustomProviderConfig,
}
Expand Down Expand Up @@ -199,7 +201,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx context.Context, key sche

// Parse Anthropic's response
var anthropicResponse AnthropicListModelsResponse
rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &anthropicResponse, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &anthropicResponse, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -208,6 +210,11 @@ func (provider *AnthropicProvider) listModelsByKey(ctx context.Context, key sche
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models)
response.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -264,7 +271,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem
response := acquireAnthropicTextResponse()
defer releaseAnthropicTextResponse(response)

rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -277,6 +284,11 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem
bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
bifrostResponse.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
bifrostResponse.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -320,11 +332,10 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
response := AcquireAnthropicMessageResponse()
defer ReleaseAnthropicMessageResponse(response)

rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}

// Create final response
bifrostResponse := response.ToBifrostChatResponse()

Expand All @@ -334,6 +345,11 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem
bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
bifrostResponse.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
bifrostResponse.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -386,6 +402,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, pos
jsonData,
headers,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
Expand All @@ -403,6 +420,7 @@ func HandleAnthropicChatCompletionStreaming(
jsonBody []byte,
headers map[string]string,
extraHeaders map[string]string,
sendBackRawRequest bool,
sendBackRawResponse bool,
providerName schemas.ModelProvider,
postHookRunner schemas.PostHookRunner,
Expand Down Expand Up @@ -635,6 +653,10 @@ func HandleAnthropicChatCompletionStreaming(
return
}
}
// Set raw request if enabled
if sendBackRawRequest {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
}
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan)
Expand Down Expand Up @@ -672,7 +694,7 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
response := AcquireAnthropicMessageResponse()
defer ReleaseAnthropicMessageResponse(response)

rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -686,6 +708,11 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke
bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
bifrostResponse.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
bifrostResponse.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -735,6 +762,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHook
jsonBody,
headers,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
Expand All @@ -752,6 +780,7 @@ func HandleAnthropicResponsesStream(
jsonBody []byte,
headers map[string]string,
extraHeaders map[string]string,
sendBackRawRequest bool,
sendBackRawResponse bool,
providerName schemas.ModelProvider,
postHookRunner schemas.PostHookRunner,
Expand Down Expand Up @@ -959,6 +988,10 @@ func HandleAnthropicResponsesStream(
response.Response = &schemas.BifrostResponsesResponse{}
}
response.Response.Usage = usage
// Set raw request if enabled
if sendBackRawRequest {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
}
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan)
Expand Down
50 changes: 43 additions & 7 deletions core/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type AzureProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}

Expand All @@ -47,6 +48,7 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A
logger: logger,
client: client,
networkConfig: config.NetworkConfig,
sendBackRawRequest: config.SendBackRawRequest,
sendBackRawResponse: config.SendBackRawResponse,
}, nil
}
Expand Down Expand Up @@ -192,7 +194,7 @@ func (provider *AzureProvider) listModelsByKey(ctx context.Context, key schemas.

// Parse Azure-specific response
azureResponse := &AzureListModelsResponse{}
rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, azureResponse, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, azureResponse, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -205,6 +207,12 @@ func (provider *AzureProvider) listModelsByKey(ctx context.Context, key schemas.

response.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
}
Expand Down Expand Up @@ -263,7 +271,7 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, key schemas.K

response := &schemas.BifrostTextCompletionResponse{}

rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -274,6 +282,11 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, key schemas.K
response.ExtraFields.RequestType = schemas.TextCompletionRequest
response.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -324,6 +337,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHoo
request,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
Expand Down Expand Up @@ -388,18 +402,19 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.K
}

response := &schemas.BifrostChatResponse{}
var rawRequest interface{}
var rawResponse interface{}

if schemas.IsAnthropicModel(deployment) {
anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
response = anthropicResponse.ToBifrostChatResponse()
} else {
rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, response, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -411,6 +426,11 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.K
response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.RequestType = schemas.ChatCompletionRequest

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -472,6 +492,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
jsonData,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
Expand Down Expand Up @@ -499,6 +520,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
request,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
Expand Down Expand Up @@ -570,18 +592,19 @@ func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, r
}

response := &schemas.BifrostResponsesResponse{}
var rawRequest interface{}
var rawResponse interface{}

if schemas.IsAnthropicModel(deployment) {
anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
response = anthropicResponse.ToBifrostResponsesResponse()
} else {
rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, response, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -593,6 +616,11 @@ func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, r
response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.RequestType = schemas.ResponsesRequest

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -651,6 +679,7 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn
jsonData,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
Expand Down Expand Up @@ -679,6 +708,7 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn
request,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
Expand Down Expand Up @@ -728,7 +758,7 @@ func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, r
response := &schemas.BifrostEmbeddingResponse{}

// Use enhanced response handler with pre-allocated response
rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, provider.sendBackRawResponse)
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -739,6 +769,12 @@ func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, r
response.ExtraFields.ModelDeployment = deployment
response.ExtraFields.RequestType = schemas.EmbeddingRequest

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
}
Expand Down
Loading