diff --git a/core/changelog.md b/core/changelog.md index 8b06ca4fa..4c2e2ebe1 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -1,3 +1,5 @@ +- feat: add handling for HTML and empty responses from providers +- feat: add audio token pricing support for models - feat: adds new parameter for each provider key config `use_for_batch_apis`. This helps users to select which APIs or accounts to be used for Batch APIs. - feat: adds s3 bucket config support for Bedrock provider. - feat: prompt caching support for anthropic and bedrock(claude and nova models) diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index 3443e5132..128b1f4c9 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -522,6 +522,17 @@ func (provider *ElevenlabsProvider) Transcription(ctx context.Context, key schem return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) } + // Check for empty response + trimmed := strings.TrimSpace(string(responseBody)) + if len(trimmed) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderResponseEmpty, + }, + } + } + chunks, err := parseTranscriptionResponse(responseBody) if err != nil { return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName) diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index eece5e224..e65445a92 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -299,12 +299,36 @@ func (provider *MistralProvider) Transcription(ctx context.Context, key schemas. return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionRequest, providerName, request.Model) } - // Copy response body before releasing - responseBody := append([]byte(nil), resp.Body()...) + responseBody, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + } + + // Check for empty response + trimmed := strings.TrimSpace(string(responseBody)) + if len(trimmed) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderResponseEmpty, + }, + } + } + + copiedResponseBody := append([]byte(nil), responseBody...) // Parse Mistral's transcription response var mistralResponse MistralTranscriptionResponse - if err := sonic.Unmarshal(responseBody, &mistralResponse); err != nil { + if err := sonic.Unmarshal(copiedResponseBody, &mistralResponse); err != nil { + if providerUtils.IsHTMLResponse(resp, copiedResponseBody) { + errorMessage := providerUtils.ExtractHTMLErrorMessage(copiedResponseBody) + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: errorMessage, + }, + } + } return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } @@ -323,7 +347,7 @@ func (provider *MistralProvider) Transcription(ctx context.Context, key schemas. // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponse interface{} - if err := sonic.Unmarshal(responseBody, &rawResponse); err == nil { + if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err == nil { response.ExtraFields.RawResponse = rawResponse } } @@ -441,7 +465,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHo // Process accumulated event if we have both event and data if currentEvent != "" && currentData != "" { chunkIndex++ - provider.processStreamEvent(ctx, postHookRunner, currentEvent, currentData, request.Model, providerName, chunkIndex, startTime, &lastChunkTime, responseChan) + provider.processTranscriptionStreamEvent(ctx, postHookRunner, currentEvent, currentData, request.Model, providerName, chunkIndex, startTime, &lastChunkTime, responseChan) } // Reset for next event currentEvent = "" @@ -460,7 +484,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHo // Process any remaining event if currentEvent != "" && currentData != "" { chunkIndex++ - provider.processStreamEvent(ctx, postHookRunner, currentEvent, currentData, request.Model, providerName, chunkIndex, startTime, &lastChunkTime, responseChan) + provider.processTranscriptionStreamEvent(ctx, postHookRunner, currentEvent, currentData, request.Model, providerName, chunkIndex, startTime, &lastChunkTime, responseChan) } // Handle scanner errors @@ -473,8 +497,8 @@ func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHo return responseChan, nil } -// processStreamEvent processes a single SSE event and sends it to the response channel. -func (provider *MistralProvider) processStreamEvent( +// processTranscriptionStreamEvent processes a single SSE event and sends it to the response channel. +func (provider *MistralProvider) processTranscriptionStreamEvent( ctx context.Context, postHookRunner schemas.PostHookRunner, eventType string, diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 4273d65bf..d5f10a85f 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -1890,17 +1890,40 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.K return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) } + // Check for empty response + trimmed := strings.TrimSpace(string(responseBody)) + if len(trimmed) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderResponseEmpty, + }, + } + } + + copiedResponseBody := append([]byte(nil), responseBody...) + // Parse OpenAI's transcription response directly into BifrostTranscribe response := &schemas.BifrostTranscriptionResponse{} - if err := sonic.Unmarshal(responseBody, response); err != nil { + if err := sonic.Unmarshal(copiedResponseBody, response); err != nil { + // Check if it's an HTML response + if providerUtils.IsHTMLResponse(resp, copiedResponseBody) { + errorMessage := providerUtils.ExtractHTMLErrorMessage(copiedResponseBody) + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: errorMessage, + }, + } + } return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } // Parse raw response for RawResponse field var rawResponse interface{} if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { - if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil { + if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err != nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName) } } diff --git a/core/providers/utils/html_response_handler_test.go b/core/providers/utils/html_response_handler_test.go new file mode 100644 index 000000000..a9b31d8a9 --- /dev/null +++ b/core/providers/utils/html_response_handler_test.go @@ -0,0 +1,302 @@ +package utils + +import ( + "strings" + "testing" + + "github.com/valyala/fasthttp" +) + +func TestIsHTMLResponse(t *testing.T) { + tests := []struct { + name string + contentType string + body []byte + expectedIsHTML bool + description string + }{ + { + name: "HTML with Content-Type header", + contentType: "text/html; charset=utf-8", + body: []byte("
Error"), + expectedIsHTML: true, + description: "Should detect HTML from Content-Type header", + }, + { + name: "HTML without Content-Type", + contentType: "application/octet-stream", + body: []byte("The page was not found
+ + `), + expectMsg: "404 Not Found", + description: "Should extract title from title tag", + }, + { + name: "Extract from h1 tag", + htmlBody: []byte(` + + + +The service is currently unavailable
+ + + `), + expectMsg: "Service Unavailable", + description: "Should extract from h1 tag when title is missing", + }, + { + name: "Extract from h2 tag", + htmlBody: []byte(` + + + +Please check your credentials
+ + + `), + expectMsg: "Authentication Failed", + description: "Should extract from h2 tag with attributes", + }, + { + name: "Extract visible text when no headers", + htmlBody: []byte(` + + +Access denied
+ + + `), + description: "Should detect and extract message from HTML on parse failure", + expectedInMessage: "Forbidden", + }, + { + name: "Invalid JSON with HTML fallback", + statusCode: 400, + contentType: "application/json", + body: []byte(`not valid json`), + description: "Should fall back to raw string when not HTML", + expectedInMessage: "provider API error", + }, + { + name: "Valid JSON error response", + statusCode: 400, + contentType: "application/json", + body: []byte(`{"error": {"message": "Invalid request"}, "code": "invalid_request"}`), + description: "Should handle valid JSON without HTML detection", + expectedInMessage: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &fasthttp.Response{} + resp.SetStatusCode(tt.statusCode) + resp.Header.Set("Content-Type", tt.contentType) + resp.SetBody(tt.body) + + var errorResp map[string]interface{} + bifrostErr := HandleProviderAPIError(resp, &errorResp) + + if bifrostErr == nil { + t.Errorf("HandleProviderAPIError() returned nil error") + return + } + + if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != tt.statusCode { + t.Errorf("HandleProviderAPIError() status code = %v, want %v", bifrostErr.StatusCode, tt.statusCode) + } + + if bifrostErr.Error == nil { + t.Errorf("HandleProviderAPIError() error field is nil") + return + } + + // Check if expected message is in the response + if tt.expectedInMessage != "" && !strings.Contains(bifrostErr.Error.Message, tt.expectedInMessage) { + t.Errorf("Expected message to contain %q, got %q", tt.expectedInMessage, bifrostErr.Error.Message) + } + + t.Logf("Handled %s: status=%d, message=%q", tt.name, *bifrostErr.StatusCode, bifrostErr.Error.Message) + }) + } +} + +func BenchmarkIsHTMLResponse(b *testing.B) { + resp := &fasthttp.Response{} + resp.Header.Set("Content-Type", "text/html; charset=utf-8") + body := []byte(`This is a detailed error message
`) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ExtractHTMLErrorMessage(body) + } +} diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index c1a67647e..64e2bf86e 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -11,6 +11,7 @@ import ( "net/http" "net/textproto" "net/url" + "regexp" "slices" "sort" "strings" @@ -314,12 +315,13 @@ func SetExtraHeadersHTTP(ctx context.Context, req *http.Request, extraHeaders ma // HandleProviderAPIError processes error responses from provider APIs. // It attempts to unmarshal the error response and returns a BifrostError // with the appropriate status code and error information. -// errorResp must be a pointer to the target struct for unmarshaling. +// HTML detection only runs if JSON parsing fails to avoid expensive regex operations +// on responses that are almost certainly valid JSON. errorResp must be a pointer to +// the target struct for unmarshaling. func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.BifrostError { statusCode := resp.StatusCode() - body := append([]byte(nil), resp.Body()...) - // decode body + // Decode body decodedBody, err := CheckAndDecodeBody(resp) if err != nil { return &schemas.BifrostError{ @@ -331,24 +333,48 @@ func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif } } - body = decodedBody + // Check for empty response + trimmed := strings.TrimSpace(string(decodedBody)) + if len(trimmed) == 0 { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderResponseEmpty, + }, + } + } + + // Try JSON parsing first + if err := sonic.Unmarshal(decodedBody, errorResp); err == nil { + // JSON parsing succeeded, return success + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: &schemas.ErrorField{}, + } + } - if err := sonic.Unmarshal(body, errorResp); err != nil { - rawResponse := body - message := fmt.Sprintf("provider API error: %s", string(rawResponse)) + // JSON parsing failed - now check if it's an HTML response (expensive operation) + if IsHTMLResponse(resp, decodedBody) { + errorMessage := ExtractHTMLErrorMessage(decodedBody) return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, Error: &schemas.ErrorField{ - Message: message, + Message: errorMessage, }, } } + // Not HTML either - return raw response as error message + message := fmt.Sprintf("provider API error: %s", string(decodedBody)) return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, - Error: &schemas.ErrorField{}, + Error: &schemas.ErrorField{ + Message: message, + }, } } @@ -356,7 +382,20 @@ func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif // It attempts to parse the response body into the provided response type // and returns either the parsed response or a BifrostError if parsing fails. // If sendBackRawResponse is true, it returns the raw response interface, otherwise nil. +// HTML detection only runs if JSON parsing fails to avoid expensive regex operations +// on responses that are almost certainly valid JSON. func HandleProviderResponse[T any](responseBody []byte, response *T, requestBody []byte, sendBackRawRequest bool, sendBackRawResponse bool) (rawRequest interface{}, rawResponse interface{}, bifrostErr *schemas.BifrostError) { + // Check for empty response + trimmed := strings.TrimSpace(string(responseBody)) + if len(trimmed) == 0 { + return nil, nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderResponseEmpty, + }, + } + } + var wg sync.WaitGroup var structuredErr, rawRequestErr, rawResponseErr error @@ -394,6 +433,18 @@ func HandleProviderResponse[T any](responseBody []byte, response *T, requestBody wg.Wait() if structuredErr != nil { + // JSON parsing failed - check if it's an HTML response (expensive operation) + if IsHTMLResponse(nil, responseBody) { + errorMessage := ExtractHTMLErrorMessage(responseBody) + return nil, nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderResponseHTML, + Error: errors.New(errorMessage), + }, + } + } + return nil, nil, &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ @@ -441,6 +492,7 @@ func HandleProviderResponse[T any](responseBody []byte, response *T, requestBody return nil, nil, nil } +// ParseAndSetRawRequest parses the raw request body and sets it in the extra fields. func ParseAndSetRawRequest(extraFields *schemas.BifrostResponseExtraFields, jsonBody []byte) { var rawRequest interface{} if err := sonic.Unmarshal(jsonBody, &rawRequest); err != nil { @@ -495,6 +547,123 @@ func CheckAndDecodeBody(resp *fasthttp.Response) ([]byte, error) { } } +// IsHTMLResponse checks if the response is HTML by examining the Content-Type header +// and/or the response body for HTML indicators. +func IsHTMLResponse(resp *fasthttp.Response, body []byte) bool { + // Check Content-Type header first (most reliable indicator) + if resp != nil { + contentType := strings.ToLower(string(resp.Header.Peek("Content-Type"))) + if strings.Contains(contentType, "text/html") { + return true + } + } + + // If body is small, it's unlikely to be HTML + if len(body) < 20 { + return false + } + + // Check for HTML indicators in body + bodyLower := strings.ToLower(string(body)) + + // Look for common HTML tags or DOCTYPE + htmlIndicators := []string{ + "", + "", + "