diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index 53ca03c58..f6daeb356 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -479,7 +479,7 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema case schemas.Gemini: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 120, + DefaultRequestTimeoutInSeconds: 300, MaxRetries: 10, // Gemini can be variable RetryBackoffInitial: 750 * time.Millisecond, RetryBackoffMax: 12 * time.Second, diff --git a/core/internal/testutil/transcription.go b/core/internal/testutil/transcription.go index 0e5c1d18b..3792efc1a 100644 --- a/core/internal/testutil/transcription.go +++ b/core/internal/testutil/transcription.go @@ -34,21 +34,21 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con name: "RoundTrip_Basic_MP3", text: TTSTestTextBasic, voiceType: "primary", - format: "mp3", + format: "wav", responseFormat: bifrost.Ptr("json"), }, { name: "RoundTrip_Medium_MP3", text: TTSTestTextMedium, voiceType: "secondary", - format: "mp3", + format: "wav", responseFormat: bifrost.Ptr("json"), }, { name: "RoundTrip_Technical_MP3", text: TTSTestTextTechnical, voiceType: "tertiary", - format: "mp3", + format: "wav", responseFormat: bifrost.Ptr("json"), }, } @@ -61,6 +61,8 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con // Step 1: Generate TTS audio voice := GetProviderVoice(testConfig.Provider, tc.voiceType) + responseFormat := GetProviderResponseFormat(testConfig.Provider, tc.format) + ttsRequest := &schemas.BifrostSpeechRequest{ Provider: testConfig.Provider, Model: testConfig.SpeechSynthesisModel, @@ -71,7 +73,7 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con VoiceConfig: &schemas.SpeechVoiceInput{ Voice: &voice, }, - ResponseFormat: tc.format, + ResponseFormat: responseFormat, }, Fallbacks: testConfig.TranscriptionFallbacks, } diff --git a/core/internal/testutil/transcription_stream.go b/core/internal/testutil/transcription_stream.go index 25d4e31dd..80f55ee91 100644 --- a/core/internal/testutil/transcription_stream.go +++ b/core/internal/testutil/transcription_stream.go @@ -9,7 +9,6 @@ import ( "testing" "time" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -65,6 +64,7 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte // Step 1: Generate TTS audio voice := GetProviderVoice(testConfig.Provider, tc.voiceType) + responseFormat := GetProviderResponseFormat(testConfig.Provider, tc.format) ttsRequest := &schemas.BifrostSpeechRequest{ Provider: testConfig.Provider, Model: testConfig.SpeechSynthesisModel, @@ -75,7 +75,7 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte VoiceConfig: &schemas.SpeechVoiceInput{ Voice: &voice, }, - ResponseFormat: tc.format, + ResponseFormat: responseFormat, }, Fallbacks: testConfig.TranscriptionFallbacks, } diff --git a/core/internal/testutil/utils.go b/core/internal/testutil/utils.go index 50d0826de..0713713fa 100644 --- a/core/internal/testutil/utils.go +++ b/core/internal/testutil/utils.go @@ -85,6 +85,19 @@ func GetProviderVoice(provider schemas.ModelProvider, voiceType string) string { } } +// GetProviderResponseFormat returns the appropriate response format for speech synthesis based on the provider +// For Gemini, only "wav" format is supported, so we always return "wav" regardless of the requested format +func GetProviderResponseFormat(provider schemas.ModelProvider, requestedFormat string) string { + switch provider { + case schemas.Gemini: + // Gemini only supports wav format for speech synthesis + return "wav" + default: + // Other providers support the requested format + return requestedFormat + } +} + type SampleToolType string const ( @@ -539,6 +552,9 @@ func GenerateTTSAudioForTest(ctx context.Context, t *testing.T, client *bifrost. format = "mp3" } + // Get the appropriate response format for the provider + responseFormat := GetProviderResponseFormat(provider, format) + req := &schemas.BifrostSpeechRequest{ Provider: provider, Model: ttsModel, @@ -547,7 +563,7 @@ func GenerateTTSAudioForTest(ctx context.Context, t *testing.T, client *bifrost. VoiceConfig: &schemas.SpeechVoiceInput{ Voice: &voice, }, - ResponseFormat: format, + ResponseFormat: responseFormat, }, } diff --git a/core/providers/gemini/gemini_test.go b/core/providers/gemini/gemini_test.go index 0ecbe2204..bc40ce8a5 100644 --- a/core/providers/gemini/gemini_test.go +++ b/core/providers/gemini/gemini_test.go @@ -48,8 +48,8 @@ func TestGemini(t *testing.T) { MultipleImages: false, CompleteEnd2End: true, Embedding: true, - Transcription: false, - TranscriptionStream: false, + Transcription: true, + TranscriptionStream: true, SpeechSynthesis: true, SpeechSynthesisStream: true, Reasoning: true, diff --git a/core/providers/gemini/transcription.go b/core/providers/gemini/transcription.go index 399ac4df4..c82a086ba 100644 --- a/core/providers/gemini/transcription.go +++ b/core/providers/gemini/transcription.go @@ -13,6 +13,9 @@ func (request *GeminiGenerationRequest) ToBifrostTranscriptionRequest() *schemas bifrostReq := &schemas.BifrostTranscriptionRequest{ Provider: provider, Model: model, + Params: &schemas.TranscriptionParameters{ + ExtraParams: make(map[string]interface{}), + }, } // Extract audio data and prompt from contents @@ -60,11 +63,6 @@ func (request *GeminiGenerationRequest) ToBifrostTranscriptionRequest() *schemas File: audioData, } - // Set parameters - if bifrostReq.Params == nil { - bifrostReq.Params = &schemas.TranscriptionParameters{} - } - // Set prompt if provided if promptText != "" { bifrostReq.Params.Prompt = &promptText @@ -72,25 +70,16 @@ func (request *GeminiGenerationRequest) ToBifrostTranscriptionRequest() *schemas // Handle safety settings from request if len(request.SafetySettings) > 0 { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - } bifrostReq.Params.ExtraParams["safety_settings"] = request.SafetySettings } // Handle cached content if request.CachedContent != "" { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - } bifrostReq.Params.ExtraParams["cached_content"] = request.CachedContent } // Handle labels if len(request.Labels) > 0 { - if bifrostReq.Params.ExtraParams == nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - } bifrostReq.Params.ExtraParams["labels"] = request.Labels }