diff --git a/core/bifrost.go b/core/bifrost.go index 79ca31272..1e1e445c1 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -867,6 +867,73 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *sch return bifrost.handleStreamRequest(ctx, bifrostReq) } +// ImageGenerationRequest sends a image generation request to the specified provider. +func (bifrost *Bifrost) ImageGenerationRequest(ctx context.Context, + req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + if req == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "image generation request is nil", + }, + } + } + if req.Input == nil || req.Input.Prompt == "" { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "prompt not provided for image generation request", + }, + } + } + + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.RequestType = schemas.ImageGenerationRequest + bifrostReq.ImageGenerationRequest = req + + response, err := bifrost.handleRequest(ctx, bifrostReq) + if err != nil { + return nil, err + } + if response == nil || response.ImageGenerationResponse == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "received nil response from provider", + }, + } + } + + return response.ImageGenerationResponse, nil +} + +// ImageGenerationStreamRequest sends a image generation stream request to the specified provider. 
+func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx context.Context, + req *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "image generation stream request is nil", + }, + } + } + if req.Input == nil || req.Input.Prompt == "" { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "prompt not provided for image generation stream request", + }, + } + } + + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.RequestType = schemas.ImageGenerationStreamRequest + bifrostReq.ImageGenerationRequest = req + + return bifrost.handleStreamRequest(ctx, bifrostReq) +} + // RemovePlugin removes a plugin from the server. func (bifrost *Bifrost) RemovePlugin(name string) error { @@ -1688,6 +1755,12 @@ func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fall tmp.Model = fallback.Model fallbackReq.TranscriptionRequest = &tmp } + if req.ImageGenerationRequest != nil { + tmp := *req.ImageGenerationRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.ImageGenerationRequest = &tmp + } return &fallbackReq } @@ -2424,6 +2497,12 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch return nil, bifrostError } response.TranscriptionResponse = transcriptionResponse + case schemas.ImageGenerationRequest: + imageResponse, bifrostError := provider.ImageGeneration(req.Context, key, req.BifrostRequest.ImageGenerationRequest) + if bifrostError != nil { + return nil, bifrostError + } + response.ImageGenerationResponse = imageResponse default: _, model, _ := req.BifrostRequest.GetRequestFields() return nil, &schemas.BifrostError{ @@ -2454,6 +2533,8 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r return provider.SpeechStream(req.Context, postHookRunner, key, 
req.BifrostRequest.SpeechRequest) case schemas.TranscriptionStreamRequest: return provider.TranscriptionStream(req.Context, postHookRunner, key, req.BifrostRequest.TranscriptionRequest) + case schemas.ImageGenerationStreamRequest: + return provider.ImageGenerationStream(req.Context, postHookRunner, key, req.BifrostRequest.ImageGenerationRequest) default: _, model, _ := req.BifrostRequest.GetRequestFields() return nil, &schemas.BifrostError{ @@ -2630,6 +2711,7 @@ func resetBifrostRequest(req *schemas.BifrostRequest) { req.EmbeddingRequest = nil req.SpeechRequest = nil req.TranscriptionRequest = nil + req.ImageGenerationRequest = nil } // getBifrostRequest gets a BifrostRequest from the pool diff --git a/core/changelog.md b/core/changelog.md index e69de29bb..eaede1ce1 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1 @@ +feat: added image generation request and response support \ No newline at end of file diff --git a/core/images_test.go b/core/images_test.go new file mode 100644 index 000000000..eea608437 --- /dev/null +++ b/core/images_test.go @@ -0,0 +1,539 @@ +package bifrost + +import ( + "encoding/json" + "fmt" + "sort" + "strings" + "testing" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" +) + +// TestImageGenerationRequestSerialization tests Bifrost request JSON serialization +func TestImageGenerationRequestSerialization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + request *schemas.BifrostImageGenerationRequest + check func(t *testing.T, jsonBytes []byte) + }{ + { + name: "full request serializes correctly", + request: &schemas.BifrostImageGenerationRequest{ + Provider: schemas.OpenAI, + Model: "dall-e-3", + Input: &schemas.ImageGenerationInput{ + Prompt: "a cute cat", + }, + Params: &schemas.ImageGenerationParameters{ + N: schemas.Ptr(2), + Size: schemas.Ptr("1024x1024"), + Quality: schemas.Ptr("hd"), + Style: schemas.Ptr("vivid"), + ResponseFormat: schemas.Ptr("b64_json"), + 
User: schemas.Ptr("test-user"), + }, + }, + check: func(t *testing.T, jsonBytes []byte) { + var data map[string]interface{} + if err := json.Unmarshal(jsonBytes, &data); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if data["provider"] != "openai" { + t.Errorf("Expected provider 'openai', got %v", data["provider"]) + } + if data["model"] != "dall-e-3" { + t.Errorf("Expected model 'dall-e-3', got %v", data["model"]) + } + input, ok := data["input"].(map[string]interface{}) + if !ok { + t.Fatalf("Failed to assert 'input' as map[string]interface{}, got %T", data["input"]) + } + if input["prompt"] != "a cute cat" { + t.Errorf("Expected prompt 'a cute cat', got %v", input["prompt"]) + } + params, ok := data["params"].(map[string]interface{}) + if !ok { + t.Fatalf("Failed to assert 'params' as map[string]interface{}, got %T", data["params"]) + } + if params["size"] != "1024x1024" { + t.Errorf("Expected size '1024x1024', got %v", params["size"]) + } + }, + }, + { + name: "minimal request omits nil fields", + request: &schemas.BifrostImageGenerationRequest{ + Provider: schemas.OpenAI, + Model: "dall-e-2", + Input: &schemas.ImageGenerationInput{ + Prompt: "test", + }, + }, + check: func(t *testing.T, jsonBytes []byte) { + var data map[string]interface{} + if err := json.Unmarshal(jsonBytes, &data); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if _, exists := data["params"]; exists { + t.Errorf("params should be omitted when nil") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + jsonBytes, err := sonic.Marshal(tt.request) + if err != nil { + t.Fatalf("Serialization failed: %v", err) + } + tt.check(t, jsonBytes) + }) + } +} + +// TestImageGenerationRequestDeserialization tests JSON to Bifrost request deserialization +func TestImageGenerationRequestDeserialization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + jsonInput string + validate func(t *testing.T, req 
*schemas.BifrostImageGenerationRequest) + wantErr bool + }{ + { + name: "full JSON deserializes correctly", + jsonInput: `{ + "provider": "openai", + "model": "dall-e-3", + "input": {"prompt": "a beautiful sunset"}, + "params": {"size": "1024x1024", "quality": "hd", "n": 2} + }`, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if req.Provider != schemas.OpenAI { + t.Errorf("Expected provider OpenAI, got %s", req.Provider) + } + if req.Input.Prompt != "a beautiful sunset" { + t.Errorf("Expected prompt 'a beautiful sunset', got '%s'", req.Input.Prompt) + } + if req.Params == nil || *req.Params.N != 2 { + t.Errorf("Expected n=2") + } + }, + wantErr: false, + }, + { + name: "invalid JSON returns error", + jsonInput: `{invalid}`, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) {}, + wantErr: true, + }, + { + name: "missing optional fields succeeds", + jsonInput: `{ + "provider": "openai", + "model": "dall-e-2", + "input": {"prompt": "test"} + }`, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if req.Params != nil { + t.Errorf("Expected params to be nil") + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var req schemas.BifrostImageGenerationRequest + err := sonic.Unmarshal([]byte(tt.jsonInput), &req) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + tt.validate(t, &req) + } + }) + } +} + +// TestImageGenerationResponseSerialization tests response serialization +func TestImageGenerationResponseSerialization(t *testing.T) { + t.Parallel() + + resp := &schemas.BifrostImageGenerationResponse{ + ID: "img-123", + Created: 1234567890, + Model: "dall-e-3", + Data: []schemas.ImageData{ + {URL: "https://example.com/image.png", Index: 0, RevisedPrompt: "a cat revised"}, + {B64JSON: "iVBORw0KGgo=", Index: 1}, + }, + Usage: 
&schemas.ImageUsage{PromptTokens: 10, TotalTokens: 20}, + } + + jsonBytes, err := sonic.Marshal(resp) + if err != nil { + t.Fatalf("Serialization failed: %v", err) + } + + var data map[string]interface{} + if err := json.Unmarshal(jsonBytes, &data); err != nil { + t.Fatalf("Failed to parse JSON: %v", err) + } + + if data["id"] != "img-123" { + t.Errorf("Expected id 'img-123', got %v", data["id"]) + } + + dataArr, ok := data["data"].([]interface{}) + if !ok { + t.Fatalf("Failed to assert 'data' as []interface{}, got %T", data["data"]) + } + if len(dataArr) != 2 { + t.Errorf("Expected 2 images, got %d", len(dataArr)) + } +} + +// TestImageStreamResponseSerialization tests streaming response serialization +func TestImageStreamResponseSerialization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + chunk *schemas.BifrostImageGenerationStreamResponse + verify func(t *testing.T, data map[string]interface{}) + }{ + { + name: "partial chunk", + chunk: &schemas.BifrostImageGenerationStreamResponse{ + ID: "img-stream-1", + Type: "image_generation.partial_image", + Index: 0, + ChunkIndex: 3, + PartialB64: "dGVzdGRhdGE=", + }, + verify: func(t *testing.T, data map[string]interface{}) { + if data["type"] != "image_generation.partial_image" { + t.Errorf("Expected type 'image_generation.partial_image'") + } + chunkIndex, ok := data["chunk_index"].(float64) + if !ok { + t.Fatalf("Failed to assert 'chunk_index' as float64, got %T", data["chunk_index"]) + } + if int(chunkIndex) != 3 { + t.Errorf("Expected chunk_index 3, got %d", int(chunkIndex)) + } + }, + }, + { + name: "completed chunk with usage", + chunk: &schemas.BifrostImageGenerationStreamResponse{ + ID: "img-stream-1", + Type: "image_generation.completed", + Index: 0, + ChunkIndex: 10, + Usage: &schemas.ImageUsage{PromptTokens: 5, TotalTokens: 15}, + }, + verify: func(t *testing.T, data map[string]interface{}) { + if data["type"] != "image_generation.completed" { + t.Errorf("Expected type 
'image_generation.completed'") + } + usage, ok := data["usage"].(map[string]interface{}) + if !ok { + t.Fatalf("Failed to assert 'usage' as map[string]interface{}, got %T", data["usage"]) + } + totalTokens, ok := usage["total_tokens"].(float64) + if !ok { + t.Fatalf("Failed to assert 'total_tokens' as float64, got %T", usage["total_tokens"]) + } + if int(totalTokens) != 15 { + t.Errorf("Expected total_tokens 15, got %d", int(totalTokens)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + jsonBytes, err := sonic.Marshal(tt.chunk) + if err != nil { + t.Fatalf("Serialization failed: %v", err) + } + var data map[string]interface{} + if err := json.Unmarshal(jsonBytes, &data); err != nil { + t.Fatalf("Failed to unmarshal JSON: %v", err) + } + tt.verify(t, data) + }) + } +} + +// TestStreamChunkAccumulation tests that chunks can be accumulated and reconstructed +func TestStreamChunkAccumulation(t *testing.T) { + t.Parallel() + + t.Run("single image multiple chunks", func(t *testing.T) { + t.Parallel() + + originalB64 := strings.Repeat("ABCD", 50000) // ~200KB + chunkSize := 64 * 1024 + + // Simulate streaming chunks + var chunks []*schemas.BifrostImageGenerationStreamResponse + chunkIdx := 0 + for offset := 0; offset < len(originalB64); offset += chunkSize { + end := offset + chunkSize + if end > len(originalB64) { + end = len(originalB64) + } + + isLast := end >= len(originalB64) + chunkType := "image_generation.partial_image" + if isLast { + chunkType = "image_generation.completed" + } + + chunks = append(chunks, &schemas.BifrostImageGenerationStreamResponse{ + ID: "img-123", + Type: chunkType, + Index: 0, + ChunkIndex: chunkIdx, + PartialB64: originalB64[offset:end], + }) + chunkIdx++ + } + + // Reconstruct from chunks + var accumulated strings.Builder + for _, chunk := range chunks { + accumulated.WriteString(chunk.PartialB64) + } + + if accumulated.String() != originalB64 { + t.Errorf("Reconstructed data doesn't 
match original (len %d vs %d)", + len(accumulated.String()), len(originalB64)) + } + }) + + t.Run("multiple images parallel chunks", func(t *testing.T) { + t.Parallel() + + image0Data := strings.Repeat("A", 100000) + image1Data := strings.Repeat("B", 100000) + chunkSize := 64 * 1024 + + // Interleaved chunks from 2 images + var allChunks []*schemas.BifrostImageGenerationStreamResponse + + // Image 0 chunks + for i, offset := 0, 0; offset < len(image0Data); i, offset = i+1, offset+chunkSize { + end := offset + chunkSize + if end > len(image0Data) { + end = len(image0Data) + } + allChunks = append(allChunks, &schemas.BifrostImageGenerationStreamResponse{ + Index: 0, + ChunkIndex: i, + PartialB64: image0Data[offset:end], + }) + } + + // Image 1 chunks + for i, offset := 0, 0; offset < len(image1Data); i, offset = i+1, offset+chunkSize { + end := offset + chunkSize + if end > len(image1Data) { + end = len(image1Data) + } + allChunks = append(allChunks, &schemas.BifrostImageGenerationStreamResponse{ + Index: 1, + ChunkIndex: i, + PartialB64: image1Data[offset:end], + }) + } + + // Group by image index and sort by chunk index + imageChunks := make(map[int][]*schemas.BifrostImageGenerationStreamResponse) + for _, chunk := range allChunks { + imageChunks[chunk.Index] = append(imageChunks[chunk.Index], chunk) + } + + for imgIdx := range imageChunks { + sort.Slice(imageChunks[imgIdx], func(i, j int) bool { + return imageChunks[imgIdx][i].ChunkIndex < imageChunks[imgIdx][j].ChunkIndex + }) + } + + // Reconstruct each image + var image0Reconstructed, image1Reconstructed strings.Builder + for _, chunk := range imageChunks[0] { + image0Reconstructed.WriteString(chunk.PartialB64) + } + for _, chunk := range imageChunks[1] { + image1Reconstructed.WriteString(chunk.PartialB64) + } + + if image0Reconstructed.String() != image0Data { + t.Errorf("Image 0 reconstruction failed") + } + if image1Reconstructed.String() != image1Data { + t.Errorf("Image 1 reconstruction failed") + } + }) + 
+ t.Run("out of order chunks sorted correctly", func(t *testing.T) { + t.Parallel() + + chunks := []*schemas.BifrostImageGenerationStreamResponse{ + {ChunkIndex: 3, PartialB64: "D"}, + {ChunkIndex: 0, PartialB64: "A"}, + {ChunkIndex: 2, PartialB64: "C"}, + {ChunkIndex: 1, PartialB64: "B"}, + } + + sort.Slice(chunks, func(i, j int) bool { + return chunks[i].ChunkIndex < chunks[j].ChunkIndex + }) + + var result strings.Builder + for _, c := range chunks { + result.WriteString(c.PartialB64) + } + + if result.String() != "ABCD" { + t.Errorf("Expected 'ABCD', got '%s'", result.String()) + } + }) +} + +// TestStreamChunkUsageOnFinal tests that usage is only on final chunk +func TestStreamChunkUsageOnFinal(t *testing.T) { + t.Parallel() + + chunks := []*schemas.BifrostImageGenerationStreamResponse{ + {ChunkIndex: 0, Type: "image_generation.partial_image", Usage: nil}, + {ChunkIndex: 1, Type: "image_generation.partial_image", Usage: nil}, + {ChunkIndex: 2, Type: "image_generation.completed", Usage: &schemas.ImageUsage{ + PromptTokens: 10, TotalTokens: 100, + }}, + } + + var finalUsage *schemas.ImageUsage + for _, chunk := range chunks { + if chunk.Type == "image_generation.completed" && chunk.Usage != nil { + finalUsage = chunk.Usage + } + } + + if finalUsage == nil { + t.Fatal("Expected usage on final chunk") + } + if finalUsage.TotalTokens != 100 { + t.Errorf("Expected TotalTokens 100, got %d", finalUsage.TotalTokens) + } +} + +// TestImageCacheKeyComponents tests what components should go into cache key +func TestImageCacheKeyComponents(t *testing.T) { + t.Parallel() + + // Cache key should be deterministic based on: prompt + params + req1 := &schemas.BifrostImageGenerationRequest{ + Input: &schemas.ImageGenerationInput{Prompt: "a cat"}, + Params: &schemas.ImageGenerationParameters{Size: schemas.Ptr("1024x1024")}, + } + req2 := &schemas.BifrostImageGenerationRequest{ + Input: &schemas.ImageGenerationInput{Prompt: "a cat"}, + Params: 
&schemas.ImageGenerationParameters{Size: schemas.Ptr("1024x1024")}, + } + req3 := &schemas.BifrostImageGenerationRequest{ + Input: &schemas.ImageGenerationInput{Prompt: "a dog"}, + Params: &schemas.ImageGenerationParameters{Size: schemas.Ptr("1024x1024")}, + } + + // Same request should produce same cache components + key1 := generateTestCacheKey(req1) + key2 := generateTestCacheKey(req2) + key3 := generateTestCacheKey(req3) + + if key1 != key2 { + t.Errorf("Identical requests should have same cache key") + } + if key1 == key3 { + t.Errorf("Different prompts should have different cache keys") + } +} + +// generateTestCacheKey simulates cache key generation (actual impl in semanticcache) +func generateTestCacheKey(req *schemas.BifrostImageGenerationRequest) string { + if req == nil || req.Input == nil { + return "" + } + + var sb strings.Builder + sb.WriteString(req.Input.Prompt) + + if req.Params != nil { + if req.Params.Size != nil { + sb.WriteString(*req.Params.Size) + } + if req.Params.Quality != nil { + sb.WriteString(*req.Params.Quality) + } + if req.Params.Style != nil { + sb.WriteString(*req.Params.Style) + } + if req.Params.N != nil { + sb.WriteString(fmt.Sprintf("%d", *req.Params.N)) + } + } + + return sb.String() +} + +// TestCacheKeyDifferentParams tests that different params produce different keys +func TestCacheKeyDifferentParams(t *testing.T) { + t.Parallel() + + baseReq := func() *schemas.BifrostImageGenerationRequest { + return &schemas.BifrostImageGenerationRequest{ + Input: &schemas.ImageGenerationInput{Prompt: "a cat"}, + Params: &schemas.ImageGenerationParameters{}, + } + } + + tests := []struct { + name string + modify func(r *schemas.BifrostImageGenerationRequest) + }{ + {"different size", func(r *schemas.BifrostImageGenerationRequest) { r.Params.Size = schemas.Ptr("512x512") }}, + {"different quality", func(r *schemas.BifrostImageGenerationRequest) { r.Params.Quality = schemas.Ptr("hd") }}, + {"different style", func(r 
*schemas.BifrostImageGenerationRequest) { r.Params.Style = schemas.Ptr("vivid") }}, + {"different n", func(r *schemas.BifrostImageGenerationRequest) { r.Params.N = schemas.Ptr(2) }}, + } + + baseKey := generateTestCacheKey(baseReq()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := baseReq() + tt.modify(req) + modifiedKey := generateTestCacheKey(req) + + if modifiedKey == baseKey { + t.Errorf("Param change '%s' should produce different cache key", tt.name) + } + }) + } +} diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index 53ca03c58..2b982bac2 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -41,6 +41,8 @@ type TestScenarios struct { Embedding bool // Embedding functionality Reasoning bool // Reasoning/thinking functionality via Responses API ListModels bool // List available models functionality + ImageGeneration bool // Image generation functionality + ImageGenerationStream bool // Streaming image generation functionality } // ComprehensiveTestConfig extends TestConfig with additional scenarios @@ -61,6 +63,8 @@ type ComprehensiveTestConfig struct { SpeechSynthesisFallbacks []schemas.Fallback // for speech synthesis tests EmbeddingFallbacks []schemas.Fallback // for embedding tests SkipReason string // Reason to skip certain tests + ImageGenerationModel string // Model for image generation + ImageGenerationFallbacks []schemas.Fallback // Fallbacks for image generation } // ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing. 
@@ -517,6 +521,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ PromptCachingModel: "gpt-4.1", TranscriptionModel: "whisper-1", SpeechSynthesisModel: "tts-1", + ImageGenerationModel: "dall-e-2", Scenarios: TestScenarios{ TextCompletion: false, // Not supported TextCompletionStream: false, // Not supported diff --git a/core/internal/testutil/image_generation.go b/core/internal/testutil/image_generation.go new file mode 100644 index 000000000..66db0ec4a --- /dev/null +++ b/core/internal/testutil/image_generation.go @@ -0,0 +1,126 @@ +package testutil + +import ( + "context" + "os" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunImageGenerationTest executes the end-to-end image generation test (non-streaming) +func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if testConfig.ImageGenerationModel == "" { + t.Logf("Image generation not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGeneration", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + retryConfig := GetTestRetryConfigForScenario("ImageGeneration", testConfig) + retryContext := TestRetryContext{ + ScenarioName: "ImageGeneration", + ExpectedBehavior: map[string]interface{}{}, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ImageGenerationModel, + }, + } + + expectations := GetExpectationsForScenario("ImageGeneration", testConfig, map[string]interface{}{ + "min_images": 1, + "expected_size": "1024x1024", + }) + + imageGenerationRetryConfig := ImageGenerationRetryConfig{ + MaxAttempts: retryConfig.MaxAttempts, + BaseDelay: retryConfig.BaseDelay, + MaxDelay: retryConfig.MaxDelay, + Conditions: []ImageGenerationRetryCondition{}, + OnRetry: retryConfig.OnRetry, + OnFinalFail: retryConfig.OnFinalFail, + } + // Test basic image generation + 
imageGenerationOperation := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A serene Japanese garden with cherry blossoms in spring", + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + Quality: bifrost.Ptr("standard"), + ResponseFormat: bifrost.Ptr("b64_json"), + N: bifrost.Ptr(1), + }, + Fallbacks: testConfig.ImageGenerationFallbacks, + } + + response, err := client.ImageGenerationRequest(ctx, request) + if err != nil { + return nil, err + } + if response != nil { + return response, nil + } + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: "No image generation response returned", + }, + } + } + + imageGenerationResponse, imageGenerationError := WithImageGenerationRetry(t, imageGenerationRetryConfig, retryContext, expectations, "ImageGeneration", imageGenerationOperation) + + if imageGenerationError != nil { + t.Fatalf("❌ Image generation failed: %v", GetErrorMessage(imageGenerationError)) + } + + // Validate response + if imageGenerationResponse == nil { + t.Fatal("❌ Image generation returned nil response") + } + + if len(imageGenerationResponse.Data) == 0 { + t.Fatal("❌ Image generation returned no image data") + } + + // Validate first image + imageData := imageGenerationResponse.Data[0] + if imageData.B64JSON == "" && imageData.URL == "" { + t.Fatal("❌ Image data missing both b64_json and URL") + } + + // Validate base64 if present + if imageData.B64JSON != "" { + if len(imageData.B64JSON) < 50*1000 { + t.Errorf("❌ Base64 image data too short: %d bytes (expected minimum: 50 KB for 1024x1024 image)", len(imageData.B64JSON)) + } + } + + // Validate usage if present + if imageGenerationResponse.Usage != nil { + if imageGenerationResponse.Usage.TotalTokens == 0 { + t.Logf("⚠️ 
Usage total_tokens is 0 (may be provider-specific)") + } + } + + // Validate extra fields + if imageGenerationResponse.ExtraFields.Provider == "" { + t.Error("❌ ExtraFields.Provider is empty") + } + + if imageGenerationResponse.ExtraFields.ModelRequested == "" { + t.Error("❌ ExtraFields.ModelRequested is empty") + } + + t.Logf("✅ Image generation successful: ID=%s, Provider=%s, Model=%s, Images=%d", + imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.ModelRequested, len(imageGenerationResponse.Data)) + }) +} diff --git a/core/internal/testutil/image_generation_cache.go b/core/internal/testutil/image_generation_cache.go new file mode 100644 index 000000000..e9548e1c2 --- /dev/null +++ b/core/internal/testutil/image_generation_cache.go @@ -0,0 +1,140 @@ +package testutil + +import ( + "context" + "os" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunImageGenerationCacheTest tests cache hit/miss scenarios +func RunImageGenerationCacheTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if testConfig.ImageGenerationModel == "" { + t.Logf("Image generation cache test skipped: not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGenerationCache", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Use a unique prompt for cache testing + cacheTestPrompt := "A unique test image for cache validation - " + time.Now().Format("20060102150405") + + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: cacheTestPrompt, + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + ResponseFormat: bifrost.Ptr("b64_json"), + }, + } + + // First request - should be a cache miss + 
start1 := time.Now() + response1, err1 := client.ImageGenerationRequest(ctx, request) + duration1 := time.Since(start1) + + if err1 != nil { + t.Fatalf("❌ First image generation request failed: %v", GetErrorMessage(err1)) + } + + if response1 == nil || len(response1.Data) == 0 { + t.Fatal("❌ First request returned no image data") + } + + // Check cache debug info if available + cacheHit1 := false + if response1.ExtraFields.CacheDebug != nil { + cacheHit1 = response1.ExtraFields.CacheDebug.CacheHit + } + + if cacheHit1 { + t.Logf("⚠️ First request was a cache hit (unexpected, but may be valid)") + } else { + t.Logf("✅ First request was a cache miss (expected)") + } + + // Second request with same prompt - should be a cache hit + start2 := time.Now() + response2, err2 := client.ImageGenerationRequest(ctx, request) + duration2 := time.Since(start2) + + if err2 != nil { + t.Fatalf("❌ Second image generation request failed: %v", GetErrorMessage(err2)) + } + + if response2 == nil || len(response2.Data) == 0 { + t.Fatal("❌ Second request returned no image data") + } + + // Check cache debug info + cacheHit2 := false + if response2.ExtraFields.CacheDebug != nil { + cacheHit2 = response2.ExtraFields.CacheDebug.CacheHit + } + + if cacheHit2 { + t.Logf("✅ Second request was a cache hit (expected)") + + // Cache hit should be faster + if duration2 < duration1 { + t.Logf("✅ Cache hit was faster: %v vs %v", duration2, duration1) + } else { + t.Logf("⚠️ Cache hit was not faster: %v vs %v (may be due to network variance)", duration2, duration1) + } + + // Validate cached response matches original + if len(response1.Data) == len(response2.Data) { + // Compare image data (should be identical for cache hit) + if response1.Data[0].B64JSON != "" && response2.Data[0].B64JSON != "" { + if response1.Data[0].B64JSON == response2.Data[0].B64JSON { + t.Logf("✅ Cached image data matches original") + } else { + t.Errorf("❌ Cached image data does not match original") + } + } + } + } else { + 
t.Logf("⚠️ Second request was a cache miss (cache may not be enabled or TTL expired)") + } + + // Test with different prompt - should be cache miss + request2 := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A different prompt for cache miss test", + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + ResponseFormat: bifrost.Ptr("b64_json"), + }, + } + + response3, err3 := client.ImageGenerationRequest(ctx, request2) + if err3 != nil { + t.Fatalf("❌ Third image generation request failed: %v", GetErrorMessage(err3)) + } + + cacheHit3 := false + if response3.ExtraFields.CacheDebug != nil { + cacheHit3 = response3.ExtraFields.CacheDebug.CacheHit + } + + if cacheHit3 { + t.Logf("⚠️ Different prompt was a cache hit (unexpected)") + } else { + t.Logf("✅ Different prompt was a cache miss (expected)") + } + + t.Logf("✅ Cache test completed: First=%v, Second=%v, Different=%v", cacheHit1, cacheHit2, cacheHit3) + }) +} diff --git a/core/internal/testutil/image_generation_errors.go b/core/internal/testutil/image_generation_errors.go new file mode 100644 index 000000000..51f9ddcd4 --- /dev/null +++ b/core/internal/testutil/image_generation_errors.go @@ -0,0 +1,154 @@ +package testutil + +import ( + "context" + "os" + "strings" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunImageGenerationErrorTest tests error handling scenarios +func RunImageGenerationErrorTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if testConfig.ImageGenerationModel == "" { + t.Logf("Image generation error test skipped: not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGenerationErrors", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Test 1: Empty 
prompt (should fail) + t.Run("EmptyPrompt", func(t *testing.T) { + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "", + }, + } + + response, bifrostErr := client.ImageGenerationRequest(ctx, request) + if bifrostErr == nil { + t.Error("❌ Empty prompt should return an error") + } else { + errorMsg := GetErrorMessage(bifrostErr) + if strings.Contains(strings.ToLower(errorMsg), "prompt") || + strings.Contains(strings.ToLower(errorMsg), "required") || + strings.Contains(strings.ToLower(errorMsg), "empty") { + t.Logf("✅ Empty prompt correctly rejected: %s", errorMsg) + } else { + t.Logf("⚠️ Empty prompt rejected but with unexpected error: %s", errorMsg) + } + } + if response != nil { + t.Error("❌ Empty prompt should not return a response") + } + }) + + // Test 2: Invalid size (should fail or use default) + t.Run("InvalidSize", func(t *testing.T) { + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A test image", + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("9999x9999"), // Invalid size + }, + } + + response, bifrostErr := client.ImageGenerationRequest(ctx, request) + if bifrostErr != nil { + errorMsg := GetErrorMessage(bifrostErr) + if strings.Contains(strings.ToLower(errorMsg), "size") || + strings.Contains(strings.ToLower(errorMsg), "invalid") { + t.Logf("✅ Invalid size correctly rejected: %s", errorMsg) + } else { + t.Logf("⚠️ Invalid size rejected but with unexpected error: %s", errorMsg) + } + } else { + // Some providers may accept and use default size + t.Logf("⚠️ Invalid size was accepted (provider may use default)") + if response != nil && len(response.Data) > 0 { + t.Logf("✅ Request succeeded with default size") + } + } + }) + + // Test 3: Invalid n parameter (too many images) + 
t.Run("InvalidN", func(t *testing.T) { + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A test image", + }, + Params: &schemas.ImageGenerationParameters{ + N: bifrost.Ptr(20), // Too many (max is usually 10) + }, + } + + response, bifrostErr := client.ImageGenerationRequest(ctx, request) + if bifrostErr != nil { + errorMsg := GetErrorMessage(bifrostErr) + if strings.Contains(strings.ToLower(errorMsg), "n") || + strings.Contains(strings.ToLower(errorMsg), "invalid") || + strings.Contains(strings.ToLower(errorMsg), "maximum") { + t.Logf("✅ Invalid n parameter correctly rejected: %s", errorMsg) + } else { + t.Logf("⚠️ Invalid n rejected but with unexpected error: %s", errorMsg) + } + } else { + // Some providers may cap it + t.Logf("⚠️ Invalid n was accepted (provider may cap to max)") + if response != nil { + actualN := len(response.Data) + if actualN <= 10 { + t.Logf("✅ Provider capped n to %d (expected)", actualN) + } + } + } + }) + + // Test 4: Very long prompt (may hit rate limits or content policy) + t.Run("VeryLongPrompt", func(t *testing.T) { + longPrompt := strings.Repeat("A beautiful landscape with mountains and rivers. 
", 100) + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: longPrompt, + }, + } + + response, bifrostErr := client.ImageGenerationRequest(ctx, request) + if bifrostErr != nil { + errorMsg := GetErrorMessage(bifrostErr) + if strings.Contains(strings.ToLower(errorMsg), "length") || + strings.Contains(strings.ToLower(errorMsg), "too long") || + strings.Contains(strings.ToLower(errorMsg), "limit") { + t.Logf("✅ Very long prompt correctly rejected: %s", errorMsg) + } else if strings.Contains(strings.ToLower(errorMsg), "rate limit") || + strings.Contains(strings.ToLower(errorMsg), "quota") { + t.Logf("⚠️ Rate limit hit (expected for long prompts): %s", errorMsg) + } else { + t.Logf("⚠️ Long prompt rejected with unexpected error: %s", errorMsg) + } + } else { + // Some providers may accept long prompts + t.Logf("✅ Very long prompt was accepted") + if response != nil && len(response.Data) > 0 { + t.Logf("✅ Request succeeded with long prompt") + } + } + }) + + t.Logf("✅ Error handling tests completed") + }) +} diff --git a/core/internal/testutil/image_generation_fallback.go b/core/internal/testutil/image_generation_fallback.go new file mode 100644 index 000000000..4888b6f99 --- /dev/null +++ b/core/internal/testutil/image_generation_fallback.go @@ -0,0 +1,80 @@ +package testutil + +import ( + "context" + "os" + "strings" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunImageGenerationFallbackTest tests fallback to secondary provider +func RunImageGenerationFallbackTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if len(testConfig.ImageGenerationFallbacks) == 0 { + t.Logf("Image generation fallback test skipped: no fallbacks configured for provider %s", testConfig.Provider) + return + } + + if testConfig.ImageGenerationModel == "" 
{ + t.Logf("Image generation not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGenerationFallback", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Create request with primary provider that should fail, triggering fallback + // Note: This test assumes the primary provider will fail (e.g., invalid key) + // In practice, you might need to configure a failing provider for this test + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A beautiful mountain landscape at dawn", + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + ResponseFormat: bifrost.Ptr("b64_json"), + }, + Fallbacks: testConfig.ImageGenerationFallbacks, + } + + response, bifrostErr := client.ImageGenerationRequest(ctx, request) + + // If primary provider fails, fallback should be used + // This test validates that fallback mechanism works + if bifrostErr != nil { + // Check if error indicates fallback was attempted + errorMsg := GetErrorMessage(bifrostErr) + if strings.Contains(strings.ToLower(errorMsg), "fallback") { + t.Logf("✅ Fallback mechanism triggered (expected behavior)") + } else { + // If we have fallbacks configured, the request should succeed via fallback + // If it still fails, log it but don't fail the test (provider-specific) + t.Logf("⚠️ Request failed even with fallbacks: %v", errorMsg) + } + return + } + + // If we get here, request succeeded (either primary or fallback) + if response == nil { + t.Fatal("❌ Image generation returned nil response") + } + + // Validate that we got a response from either primary or fallback provider + if response.ExtraFields.Provider == "" { + t.Error("❌ Response missing provider information") + } + + // Log which provider was used + t.Logf("✅ Image generation succeeded via provider: %s (may be fallback)", 
response.ExtraFields.Provider) + + if len(response.Data) > 0 { + t.Logf("✅ Received %d image(s) from provider %s", len(response.Data), response.ExtraFields.Provider) + } + }) +} diff --git a/core/internal/testutil/image_generation_load.go b/core/internal/testutil/image_generation_load.go new file mode 100644 index 000000000..05c6df18a --- /dev/null +++ b/core/internal/testutil/image_generation_load.go @@ -0,0 +1,329 @@ +package testutil + +import ( + "context" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/providers/openai" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunImageGenerationLoadTest tests concurrent image generation requests +func RunImageGenerationLoadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if testConfig.ImageGenerationModel == "" { + t.Logf("Image generation load test skipped: not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGenerationLoad", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + const numConcurrentRequests = 10 + const numRequestsPerGoroutine = 2 + totalRequests := numConcurrentRequests * numRequestsPerGoroutine + + var successCount int64 + var errorCount int64 + var totalDuration time.Duration + var mu sync.Mutex + + start := time.Now() + var wg sync.WaitGroup + + // Launch concurrent requests + for i := 0; i < numConcurrentRequests; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + for j := 0; j < numRequestsPerGoroutine; j++ { + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A test image for load testing - " + time.Now().Format("20060102150405"), + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), // Smaller size for faster 
generation + ResponseFormat: bifrost.Ptr("b64_json"), + N: bifrost.Ptr(1), + }, + } + + reqStart := time.Now() + response, bifrostErr := client.ImageGenerationRequest(ctx, request) + reqDuration := time.Since(reqStart) + + mu.Lock() + totalDuration += reqDuration + mu.Unlock() + + if bifrostErr != nil { + atomic.AddInt64(&errorCount, 1) + t.Logf("⚠️ Request %d-%d failed: %v", goroutineID, j, GetErrorMessage(bifrostErr)) + } else if response != nil && len(response.Data) > 0 { + atomic.AddInt64(&successCount, 1) + } else { + atomic.AddInt64(&errorCount, 1) + } + } + }(i) + } + + wg.Wait() + totalTime := time.Since(start) + + // Calculate statistics + avgDuration := totalDuration / time.Duration(totalRequests) + successRate := float64(successCount) / float64(totalRequests) * 100 + + t.Logf("✅ Load test completed:") + t.Logf(" Total requests: %d", totalRequests) + t.Logf(" Successful: %d (%.2f%%)", successCount, successRate) + t.Logf(" Failed: %d", errorCount) + t.Logf(" Total time: %v", totalTime) + t.Logf(" Average request duration: %v", avgDuration) + t.Logf(" Requests per second: %.2f", float64(totalRequests)/totalTime.Seconds()) + + // Validate results + if successRate < 80.0 { + t.Errorf("❌ Success rate too low: %.2f%% (expected >= 80%%)", successRate) + } else { + t.Logf("✅ Success rate acceptable: %.2f%%", successRate) + } + }) +} + +// RunImageGenerationStreamLoadTest tests stream memory usage under load +func RunImageGenerationStreamLoadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if testConfig.ImageGenerationModel == "" { + t.Logf("Image generation stream load test skipped: not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGenerationStreamLoad", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + const numConcurrentStreams = 5 + const streamsPerGoroutine = 1 + + var successCount int64 + var errorCount int64 + var totalChunks 
int64 + var mu sync.Mutex + + start := time.Now() + var wg sync.WaitGroup + + // Launch concurrent streams + for i := 0; i < numConcurrentStreams; i++ { + wg.Add(1) + go func(streamID int) { + defer wg.Done() + + for j := 0; j < streamsPerGoroutine; j++ { + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A streaming test image for load testing", + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + ResponseFormat: bifrost.Ptr("b64_json"), + N: bifrost.Ptr(1), + }, + } + + // Create derived context for this stream + streamCtx, cancel := context.WithCancel(ctx) + + streamChan, bifrostErr := client.ImageGenerationStreamRequest(streamCtx, request) + if bifrostErr != nil { + cancel() + atomic.AddInt64(&errorCount, 1) + t.Logf("⚠️ Stream %d-%d failed to start: %v", streamID, j, GetErrorMessage(bifrostErr)) + continue + } + + if streamChan == nil { + cancel() + atomic.AddInt64(&errorCount, 1) + t.Logf("⚠️ Stream %d-%d returned nil channel", streamID, j) + continue + } + + // Collect chunks + chunkCount := int64(0) + completed := false + + // Process stream until completion or error + for stream := range streamChan { + if stream.BifrostError != nil { + t.Logf("⚠️ Stream %d-%d error: %v", streamID, j, GetErrorMessage(stream.BifrostError)) + cancel() + continue + } + + if stream.BifrostImageGenerationStreamResponse != nil { + chunkCount++ + if stream.BifrostImageGenerationStreamResponse.Type == string(openai.ImageGenerationCompleted) { + completed = true + cancel() + continue + } + } + } + + cancel() + + mu.Lock() + totalChunks += chunkCount + mu.Unlock() + + if completed { + atomic.AddInt64(&successCount, 1) + } else { + atomic.AddInt64(&errorCount, 1) + } + } + }(i) + } + + wg.Wait() + totalTime := time.Since(start) + + avgChunksPerStream := float64(totalChunks) / float64(numConcurrentStreams*streamsPerGoroutine) 
+ successRate := float64(successCount) / float64(numConcurrentStreams*streamsPerGoroutine) * 100 + + t.Logf("✅ Stream load test completed:") + t.Logf(" Total streams: %d", numConcurrentStreams*streamsPerGoroutine) + t.Logf(" Successful: %d (%.2f%%)", successCount, successRate) + t.Logf(" Failed: %d", errorCount) + t.Logf(" Total chunks: %d", totalChunks) + t.Logf(" Average chunks per stream: %.2f", avgChunksPerStream) + t.Logf(" Total time: %v", totalTime) + + if successRate < 80.0 { + t.Errorf("❌ Stream success rate too low: %.2f%% (expected >= 80%%)", successRate) + } else { + t.Logf("✅ Stream success rate acceptable: %.2f%%", successRate) + } + }) +} + +// RunImageGenerationCacheLoadTest tests cache performance at scale +func RunImageGenerationCacheLoadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if testConfig.ImageGenerationModel == "" { + t.Logf("Image generation cache load test skipped: not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGenerationCacheLoad", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Generate a unique prompt that will be cached + cachePrompt := "Cache load test image - " + time.Now().Format("20060102150405") + + // First request to populate cache + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: cachePrompt, + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + ResponseFormat: bifrost.Ptr("b64_json"), + }, + } + + _, err := client.ImageGenerationRequest(ctx, request) + if err != nil { + t.Fatalf("❌ Failed to populate cache: %v", GetErrorMessage(err)) + } + + // Wait a bit for cache to be written + time.Sleep(1 * time.Second) + + // Now test concurrent cache hits + const numConcurrentRequests = 20 + var cacheHitCount int64 + var 
cacheMissCount int64 + var totalDuration time.Duration + var mu sync.Mutex + + start := time.Now() + var wg sync.WaitGroup + + for i := 0; i < numConcurrentRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + // Create new request for each goroutine to avoid data race + localRequest := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: cachePrompt, + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + ResponseFormat: bifrost.Ptr("b64_json"), + }, + } + reqStart := time.Now() + response, bifrostErr := client.ImageGenerationRequest(ctx, localRequest) + reqDuration := time.Since(reqStart) + + mu.Lock() + totalDuration += reqDuration + mu.Unlock() + + if bifrostErr != nil { + t.Logf("⚠️ Cache load test request failed: %v", GetErrorMessage(bifrostErr)) + return + } + + if response != nil && response.ExtraFields.CacheDebug != nil { + if response.ExtraFields.CacheDebug.CacheHit { + atomic.AddInt64(&cacheHitCount, 1) + } else { + atomic.AddInt64(&cacheMissCount, 1) + } + } + }() + } + + wg.Wait() + totalTime := time.Since(start) + + avgDuration := totalDuration / time.Duration(numConcurrentRequests) + cacheHitRate := float64(cacheHitCount) / float64(numConcurrentRequests) * 100 + + t.Logf("✅ Cache load test completed:") + t.Logf(" Total requests: %d", numConcurrentRequests) + t.Logf(" Cache hits: %d (%.2f%%)", cacheHitCount, cacheHitRate) + t.Logf(" Cache misses: %d", cacheMissCount) + t.Logf(" Total time: %v", totalTime) + t.Logf(" Average request duration: %v", avgDuration) + t.Logf(" Requests per second: %.2f", float64(numConcurrentRequests)/totalTime.Seconds()) + + if cacheHitRate > 50.0 { + t.Logf("✅ Cache hit rate acceptable: %.2f%%", cacheHitRate) + } else { + t.Logf("⚠️ Cache hit rate lower than expected: %.2f%% (cache may not be enabled)", cacheHitRate) + } + }) +} diff --git 
a/core/internal/testutil/image_generation_stream.go b/core/internal/testutil/image_generation_stream.go new file mode 100644 index 000000000..76b28c00f --- /dev/null +++ b/core/internal/testutil/image_generation_stream.go @@ -0,0 +1,133 @@ +package testutil + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/providers/openai" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunImageGenerationStreamTest executes the end-to-end streaming image generation test +func RunImageGenerationStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if testConfig.ImageGenerationModel == "" { + t.Logf("Image generation streaming not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ImageGenerationStream", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + retryConfig := GetTestRetryConfigForScenario("ImageGenerationStream", testConfig) + retryContext := TestRetryContext{ + ScenarioName: "ImageGenerationStream", + ExpectedBehavior: map[string]interface{}{ + "should_generate_images": true, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ImageGenerationModel, + }, + } + + request := &schemas.BifrostImageGenerationRequest{ + Provider: testConfig.Provider, + Model: testConfig.ImageGenerationModel, + Input: &schemas.ImageGenerationInput{ + Prompt: "A futuristic cityscape at sunset with flying cars", + }, + Params: &schemas.ImageGenerationParameters{ + Size: bifrost.Ptr("1024x1024"), + Quality: bifrost.Ptr("hd"), + ResponseFormat: bifrost.Ptr("b64_json"), + N: bifrost.Ptr(1), + }, + Fallbacks: testConfig.ImageGenerationFallbacks, + } + + validationResult := WithImageGenerationStreamRetry( + t, + retryConfig, + retryContext, + func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return 
client.ImageGenerationStreamRequest(ctx, request) + }, + func(responseChannel chan *schemas.BifrostStream) ImageGenerationStreamValidationResult { + // Validate stream content + var receivedData bool + var streamErrors []string + var validationErrors []string + hasCompleted := false + + streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + + for { + select { + case response, ok := <-responseChannel: + if !ok { + goto streamComplete + } + + if response == nil { + streamErrors = append(streamErrors, "Received nil stream response") + continue + } + + if response.BifrostError != nil { + streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError))) + continue + } + + if response.BifrostImageGenerationStreamResponse != nil { + receivedData = true + imgResp := response.BifrostImageGenerationStreamResponse + + if imgResp.Type == string(openai.ImageGenerationCompleted) { + hasCompleted = true + } + } + case <-streamCtx.Done(): + validationErrors = append(validationErrors, "Stream validation timed out") + goto streamComplete + } + } + streamComplete: + + passed := receivedData && hasCompleted && len(validationErrors) == 0 + if !receivedData { + validationErrors = append(validationErrors, "No stream data received") + } + if !hasCompleted { + validationErrors = append(validationErrors, "No completion chunk received") + } + + return ImageGenerationStreamValidationResult{ + Passed: passed, + Errors: validationErrors, + ReceivedData: receivedData, + StreamErrors: streamErrors, + } + }, + ) + + if !validationResult.Passed { + allErrors := append(validationResult.Errors, validationResult.StreamErrors...) 
+ t.Fatalf("❌ Image generation stream validation failed: %s", strings.Join(allErrors, "; ")) + } + + if !validationResult.ReceivedData { + t.Fatal("❌ No stream data received") + } + + t.Logf("✅ Image generation stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d", + validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors)) + }) +} diff --git a/core/internal/testutil/response_validation.go b/core/internal/testutil/response_validation.go index b5d71b14f..d1492fa5e 100644 --- a/core/internal/testutil/response_validation.go +++ b/core/internal/testutil/response_validation.go @@ -233,6 +233,44 @@ func ValidateSpeechResponse(t *testing.T, response *schemas.BifrostSpeechRespons // Log results logValidationResults(t, result, scenarioName) + return result + +} + +// ValidateImageGenerationResponse performs comprehensive validation for image generation responses +func ValidateImageGenerationResponse(t *testing.T, response *schemas.BifrostImageGenerationResponse, err *schemas.BifrostError, expectations ResponseExpectations, scenarioName string) ValidationResult { + result := ValidationResult{ + Passed: true, + Errors: make([]string, 0), + Warnings: make([]string, 0), + MetricsCollected: make(map[string]interface{}), + } + + // If there's an error when we expected success, that's a failure + if err != nil { + result.Passed = false + parsed := ParseBifrostError(err) + result.Errors = append(result.Errors, fmt.Sprintf("Got error when expecting success: %s", FormatErrorConcise(parsed))) + LogError(t, err, scenarioName) + return result + } + + // If response is nil when we expected success, that's a failure + if response == nil { + result.Passed = false + result.Errors = append(result.Errors, "Response is nil") + return result + } + + // Validate image generation specific fields + validateImageGenerationFields(t, response, expectations, &result) + + // Collect metrics + collectImageGenerationResponseMetrics(response, &result) + + 
// Log results + logValidationResults(t, result, scenarioName) + return result } @@ -1019,6 +1057,81 @@ func collectTranscriptionResponseMetrics(response *schemas.BifrostTranscriptionR result.MetricsCollected["has_duration"] = response.Duration != nil } +// ============================================================================= +// VALIDATION HELPER FUNCTIONS - IMAGE GENERATION RESPONSE +// ============================================================================= + +func validateImageGenerationFields(t *testing.T, response *schemas.BifrostImageGenerationResponse, expectations ResponseExpectations, result *ValidationResult) { + // Check if response has image data + if len(response.Data) == 0 { + result.Passed = false + result.Errors = append(result.Errors, "Image generation response missing image data") + return + } + + // Check each image has either B64JSON or URL + for i, img := range response.Data { + if img.B64JSON == "" && img.URL == "" { + result.Passed = false + result.Errors = append(result.Errors, fmt.Sprintf("Image %d has no B64JSON or URL", i)) + } + } + + // Check minimum number of images if specified + if minImages, ok := expectations.ProviderSpecific["min_images"].(int); ok { + actualCount := len(response.Data) + if actualCount < minImages { + result.Passed = false + result.Errors = append(result.Errors, + fmt.Sprintf("Too few images: got %d, expected at least %d", actualCount, minImages)) + } else { + result.MetricsCollected["image_count"] = actualCount + } + } + + // Validate image size if specified + if expectedSize, ok := expectations.ProviderSpecific["expected_size"].(string); ok { + result.MetricsCollected["expected_size"] = expectedSize + // Note: Actual size validation would require downloading/decoding images + } + + // Check latency field + if expectations.ShouldHaveLatency { + if response.ExtraFields.Latency <= 0 { + result.Passed = false + result.Errors = append(result.Errors, "Expected latency information but not present or 
invalid") + } else { + result.MetricsCollected["latency_ms"] = response.ExtraFields.Latency + } + } + + result.MetricsCollected["image_generation_validation"] = "completed" +} + +func collectImageGenerationResponseMetrics(response *schemas.BifrostImageGenerationResponse, result *ValidationResult) { + result.MetricsCollected["image_count"] = len(response.Data) + result.MetricsCollected["has_images"] = len(response.Data) > 0 + + // Count images with URLs vs B64JSON + urlCount := 0 + b64Count := 0 + for _, img := range response.Data { + if img.URL != "" { + urlCount++ + } + if img.B64JSON != "" { + b64Count++ + } + } + result.MetricsCollected["images_with_url"] = urlCount + result.MetricsCollected["images_with_b64"] = b64Count + + if response.Usage != nil { + result.MetricsCollected["prompt_tokens"] = response.Usage.PromptTokens + result.MetricsCollected["total_tokens"] = response.Usage.TotalTokens + } +} + // ============================================================================= // VALIDATION HELPER FUNCTIONS - EMBEDDING RESPONSE // ============================================================================= diff --git a/core/internal/testutil/test_retry_conditions.go b/core/internal/testutil/test_retry_conditions.go index 908bdc025..b18673a6d 100644 --- a/core/internal/testutil/test_retry_conditions.go +++ b/core/internal/testutil/test_retry_conditions.go @@ -842,3 +842,40 @@ func (c *InvalidEmbeddingDimensionCondition) ShouldRetry(response *schemas.Bifro func (c *InvalidEmbeddingDimensionCondition) GetConditionName() string { return "InvalidEmbeddingDimension" } + +// ============================================================================= +// IMAGE CONDITIONS +// ============================================================================= + +// EmptyImageGenerationCondition checks for missing or invalid image image +type EmptyImageGenerationCondition struct{} + +func (c *EmptyImageGenerationCondition) ShouldRetry(response *schemas.BifrostResponse, 
err *schemas.BifrostError, context TestRetryContext) (bool, string) { + // If there's an error, let other conditions handle it + if err != nil { + return false, "" + } + + // No response at all + if response == nil { + return true, "response is nil" + } + + // Check if image response exists and is not empty + if response.ImageGenerationResponse == nil || len(response.ImageGenerationResponse.Data) == 0 { + return true, "response has no image data" + } + + // Check each image has either B64JSON or URL + for i, img := range response.ImageGenerationResponse.Data { + if img.B64JSON == "" && img.URL == "" { + return true, fmt.Sprintf("image %d has no B64JSON or URL", i) + } + } + + return false, "" +} + +func (c *EmptyImageGenerationCondition) GetConditionName() string { + return "EmptyImage" +} diff --git a/core/internal/testutil/test_retry_framework.go b/core/internal/testutil/test_retry_framework.go index 2d4d5750b..3570ac443 100644 --- a/core/internal/testutil/test_retry_framework.go +++ b/core/internal/testutil/test_retry_framework.go @@ -158,6 +158,12 @@ type EmbeddingRetryCondition interface { GetConditionName() string } + +// ImageGenerationRetryCondition defines an interface for checking if an image generation test operation should be retried +type ImageGenerationRetryCondition interface { + ShouldRetry(response *schemas.BifrostImageGenerationResponse, err *schemas.BifrostError, context TestRetryContext) (bool, string) + GetConditionName() string +} + // ListModelsRetryCondition defines an interface for checking if a list models test operation should be retried type ListModelsRetryCondition interface { ShouldRetry(response *schemas.BifrostListModelsResponse, err *schemas.BifrostError, context TestRetryContext) (bool, string) @@ -232,6 +238,16 @@ type TranscriptionRetryConfig struct { OnFinalFail func(attempts int, finalErr error, t *testing.T) // Called on final failure } + +// ImageGenerationRetryConfig configures retry behavior for image generation test scenarios
+type ImageGenerationRetryConfig struct { + MaxAttempts int // Maximum retry attempts (including initial attempt) + BaseDelay time.Duration // Base delay between retries + MaxDelay time.Duration // Maximum delay between retries + Conditions []ImageGenerationRetryCondition // Conditions that trigger retries + OnRetry func(attempt int, reason string, t *testing.T) // Called before each retry + OnFinalFail func(attempts int, finalErr error, t *testing.T) // Called on final failure +} + // EmbeddingRetryConfig configures retry behavior for embedding test scenarios type EmbeddingRetryConfig struct { MaxAttempts int // Maximum retry attempts (including initial attempt) @@ -907,6 +923,22 @@ func DefaultTranscriptionRetryConfig() TestRetryConfig { } } +// DefaultImageGenerationRetryConfig creates a retry config for image tests +func DefaultImageGenerationRetryConfig() TestRetryConfig { + return TestRetryConfig{ + MaxAttempts: 10, + BaseDelay: 2000 * time.Millisecond, + MaxDelay: 10 * time.Second, + Conditions: []TestRetryCondition{ + &EmptyImageGenerationCondition{}, // Check for missing image generation data + &GenericResponseCondition{}, // Catch generic error responses + }, + OnRetry: func(attempt int, reason string, t *testing.T) { + t.Logf("🔄 Retrying image generation test (attempt %d): %s", attempt, reason) + }, + } +} + // ReasoningRetryConfig creates a retry config for reasoning tests func ReasoningRetryConfig() TestRetryConfig { return TestRetryConfig{ @@ -1141,6 +1173,8 @@ func GetTestRetryConfigForScenario(scenarioName string, testConfig Comprehensive return ReasoningRetryConfig() case "ListModels", "ListModelsPagination": return DefaultListModelsRetryConfig() + case "ImageGeneration", "ImageGenerationStream": + return DefaultImageGenerationRetryConfig() default: // For basic scenarios like SimpleChat, TextCompletion return DefaultTestRetryConfig() @@ -1838,6 +1872,182 @@ func checkTranscriptionRetryConditions(response *schemas.BifrostTranscriptionRes return 
false, "" } +func WithImageGenerationRetry( + t *testing.T, + config ImageGenerationRetryConfig, + context TestRetryContext, + expectations ResponseExpectations, + scenarioName string, + operation func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError), +) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + + var lastResponse *schemas.BifrostImageGenerationResponse + var lastError *schemas.BifrostError + + for attempt := 1; attempt <= config.MaxAttempts; attempt++ { + context.AttemptNumber = attempt + + // Execute the operation + response, err := operation() + lastResponse = response + lastError = err + + // If we have a response, validate it FIRST + if response != nil { + validationResult := ValidateImageGenerationResponse(t, response, err, expectations, scenarioName) + + // If validation passes, we're done! + if validationResult.Passed { + return response, err + } + + // Validation failed - ALWAYS retry validation failures for functionality checks + // Network errors are handled by bifrost core, so these are content/functionality validation errors + if attempt < config.MaxAttempts { + // ALWAYS retry on timeout errors - this takes precedence over all other conditions + if err != nil && isTimeoutError(err) { + retryReason := fmt.Sprintf("❌ timeout error detected: %s", GetErrorMessage(err)) + if config.OnRetry != nil { + config.OnRetry(attempt, retryReason, t) + } + + // Calculate delay with exponential backoff + delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + time.Sleep(delay) + continue + } + + // Check other retry conditions first (for logging/debugging) + shouldRetryFromConditions, conditionReason := checkImageGenerationRetryConditions(response, err, context, config.Conditions) + + // ALWAYS retry on validation failures - this is the primary purpose of these tests + // Content validation errors indicate functionality issues that should be retried + shouldRetry := len(validationResult.Errors) > 0 + 
var retryReason string + + if shouldRetry { + // Validation failures are the primary retry reason - ALWAYS prefix with ❌ + retryReason = fmt.Sprintf("❌ validation failure (content/functionality check): %s", strings.Join(validationResult.Errors, "; ")) + // Append condition-based reason if present for additional context + if shouldRetryFromConditions && conditionReason != "" { + retryReason += fmt.Sprintf(" | also: %s", conditionReason) + } + } else if shouldRetryFromConditions { + // Fallback to condition-based retry if no validation errors (edge case) + // Ensure ❌ prefix for consistency with error logging + shouldRetry = true + if !strings.Contains(conditionReason, "❌") { + retryReason = fmt.Sprintf("❌ %s", conditionReason) + } else { + retryReason = conditionReason + } + } + + if shouldRetry { + if config.OnRetry != nil { + config.OnRetry(attempt, retryReason, t) + } + + // Calculate delay with exponential backoff + delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + time.Sleep(delay) + continue + } + } + + // All retries failed validation - create a BifrostError to force test failure + validationErrors := strings.Join(validationResult.Errors, "; ") + + if config.OnFinalFail != nil { + finalErr := fmt.Errorf("❌ validation failed after %d attempts: %s", attempt, validationErrors) + config.OnFinalFail(attempt, finalErr, t) + } + + // Return nil response + BifrostError so calling test fails + statusCode := 400 + testFailureError := &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Message: fmt.Sprintf("❌ Validation failed after %d attempts: %s", attempt, validationErrors), + }, + } + return nil, testFailureError + } + + // If we have an error without a response, check if we should retry + if err != nil && attempt < config.MaxAttempts { + // ALWAYS retry on timeout errors - this takes precedence over other conditions + if isTimeoutError(err) { + retryReason := fmt.Sprintf("❌ timeout 
error detected: %s", GetErrorMessage(err)) + if config.OnRetry != nil { + config.OnRetry(attempt, retryReason, t) + } + + // Calculate delay with exponential backoff + delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + time.Sleep(delay) + continue + } + + shouldRetry, retryReason := checkImageGenerationRetryConditions(response, err, context, config.Conditions) + + // ALWAYS retry on non-structural errors (network errors are handled by bifrost core) + // If no condition matches, still retry on any error as it's likely transient + if !shouldRetry { + shouldRetry = true + errorMsg := GetErrorMessage(err) + if !strings.Contains(errorMsg, "❌") { + errorMsg = fmt.Sprintf("❌ %s", errorMsg) + } + retryReason = fmt.Sprintf("❌ non-structural error (will retry): %s", errorMsg) + } else if !strings.Contains(retryReason, "❌") { + retryReason = fmt.Sprintf("❌ %s", retryReason) + } + + if shouldRetry { + if config.OnRetry != nil { + config.OnRetry(attempt, retryReason, t) + } + + // Calculate delay with exponential backoff + delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + time.Sleep(delay) + continue + } + } + + // If we get here, either we got a final error or no more retries + break + } + + // Final failure callback + if config.OnFinalFail != nil && lastError != nil { + errorMsg := "unknown error" + if lastError.Error != nil { + errorMsg = lastError.Error.Message + } + // Ensure error message has ❌ prefix if not already present + if !strings.Contains(errorMsg, "❌") { + errorMsg = fmt.Sprintf("❌ %s", errorMsg) + } + config.OnFinalFail(config.MaxAttempts, fmt.Errorf("❌ final error: %s", errorMsg), t) + } + + return lastResponse, lastError +} + +// checkImageGenerationRetryConditions checks if any image generation retry conditions are met +func checkImageGenerationRetryConditions(response *schemas.BifrostImageGenerationResponse, err *schemas.BifrostError, context TestRetryContext, conditions []ImageGenerationRetryCondition) 
(bool, string) { + for _, condition := range conditions { + if shouldRetry, reason := condition.ShouldRetry(response, err, context); shouldRetry { + return true, fmt.Sprintf("%s: %s", condition.GetConditionName(), reason) + } + } + + return false, "" +} + // WithListModelsTestRetry wraps a list models test operation with retry logic // IMPORTANT: ALWAYS retries on ANY failure condition (errors, nil response, empty data, validation failures) // This ensures maximum resilience for list models tests @@ -2483,3 +2693,153 @@ func WithChatStreamValidationRetry( return lastResult } + +type ImageGenerationStreamValidationResult struct { + Passed bool + Errors []string + ReceivedData bool + StreamErrors []string + LastLatency int64 +} + +// WithImageGenerationStreamRetry wraps an image generation streaming operation with retry logic that includes stream content validation +// This function wraps the entire operation (request + stream reading + validation) and retries on validation failures +func WithImageGenerationStreamRetry( + t *testing.T, + config TestRetryConfig, + context TestRetryContext, + operation func() (chan *schemas.BifrostStream, *schemas.BifrostError), + validateStream func(chan *schemas.BifrostStream) ImageGenerationStreamValidationResult) ImageGenerationStreamValidationResult { + + var lastResult ImageGenerationStreamValidationResult + + for attempt := 1; attempt <= config.MaxAttempts; attempt++ { + context.AttemptNumber = attempt + + // Execute the operation to get the stream + responseChannel, err := operation() + + // If we have an error getting the stream, check if we should retry + if err != nil { + // Log error with ❌ prefix for first attempt + if attempt == 1 { + errorMsg := GetErrorMessage(err) + if !strings.Contains(errorMsg, "❌") { + errorMsg = fmt.Sprintf("❌ %s", errorMsg) + } + t.Logf("❌ Image generation stream request failed (attempt %d/%d) for %s: %s", attempt, config.MaxAttempts, context.ScenarioName, errorMsg) + } + + // Check if 
we should retry + if attempt < config.MaxAttempts { + var shouldRetry bool + var retryReason string + + // ALWAYS retry on timeout errors + if isTimeoutError(err) { + shouldRetry = true + retryReason = fmt.Sprintf("❌ timeout error detected: %s", GetErrorMessage(err)) + } else { + // Check retry conditions + shouldRetryFromConditions, conditionReason := checkStreamRetryConditions(err, context, config.Conditions) + if shouldRetryFromConditions { + shouldRetry = true + if !strings.Contains(conditionReason, "❌") { + retryReason = fmt.Sprintf("❌ %s", conditionReason) + } else { + retryReason = conditionReason + } + } else { + // Retry on any error for streaming + shouldRetry = true + errorMsg := GetErrorMessage(err) + if !strings.Contains(errorMsg, "❌") { + errorMsg = fmt.Sprintf("❌ %s", errorMsg) + } + retryReason = fmt.Sprintf("❌ streaming error (will retry): %s", errorMsg) + } + } + + if shouldRetry { + if config.OnRetry != nil { + config.OnRetry(attempt, retryReason, t) + } else { + t.Logf("🔄 Retrying image generation stream request (attempt %d/%d) for %s: %s", attempt+1, config.MaxAttempts, context.ScenarioName, retryReason) + } + + delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + time.Sleep(delay) + continue + } + } + + // All retries exhausted + if config.OnFinalFail != nil { + errorMsg := GetErrorMessage(err) + if !strings.Contains(errorMsg, "❌") { + errorMsg = fmt.Sprintf("❌ %s", errorMsg) + } + config.OnFinalFail(attempt, fmt.Errorf("❌ image generation stream request failed after %d attempts: %s", attempt, errorMsg), t) + } + return ImageGenerationStreamValidationResult{ + Passed: false, + Errors: []string{fmt.Sprintf("❌ stream request failed: %s", GetErrorMessage(err))}, + } + } + + if responseChannel == nil { + if attempt < config.MaxAttempts { + retryReason := "❌ response channel is nil" + if config.OnRetry != nil { + config.OnRetry(attempt, retryReason, t) + } + delay := calculateRetryDelay(attempt-1, config.BaseDelay, 
config.MaxDelay) + time.Sleep(delay) + continue + } + return ImageGenerationStreamValidationResult{ + Passed: false, + Errors: []string{"❌ response channel is nil"}, + } + } + + // Validate the stream content + validationResult := validateStream(responseChannel) + lastResult = validationResult + + // If validation passes, we're done! + if validationResult.Passed { + return validationResult + } + + // Validation failed - ALWAYS retry validation failures + if attempt < config.MaxAttempts { + retryReason := fmt.Sprintf("❌ validation failure (content/functionality check): %s", strings.Join(validationResult.Errors, "; ")) + if len(validationResult.StreamErrors) > 0 { + retryReason += fmt.Sprintf(" | stream errors: %s", strings.Join(validationResult.StreamErrors, "; ")) + } + + if config.OnRetry != nil { + config.OnRetry(attempt, retryReason, t) + } else { + t.Logf("🔄 Retrying image generation stream validation (attempt %d/%d) for %s: %s", attempt+1, config.MaxAttempts, context.ScenarioName, retryReason) + } + + delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + time.Sleep(delay) + continue + } + } + + // All retries exhausted - log final failure + if config.OnFinalFail != nil { + allErrors := append(lastResult.Errors, lastResult.StreamErrors...) 
+ errorMsg := strings.Join(allErrors, "; ") + if !strings.Contains(errorMsg, "❌") { + errorMsg = fmt.Sprintf("❌ %s", errorMsg) + } + config.OnFinalFail(config.MaxAttempts, fmt.Errorf("❌ image generation stream validation failed after %d attempts: %s", config.MaxAttempts, errorMsg), t) + } + + return lastResult +} diff --git a/core/internal/testutil/tests.go b/core/internal/testutil/tests.go index e6ab78fe2..641131058 100644 --- a/core/internal/testutil/tests.go +++ b/core/internal/testutil/tests.go @@ -51,6 +51,14 @@ func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context RunListModelsTest, RunListModelsPaginationTest, RunPromptCachingTest, + RunImageGenerationTest, + RunImageGenerationStreamTest, + RunImageGenerationFallbackTest, + RunImageGenerationCacheTest, + RunImageGenerationErrorTest, + RunImageGenerationLoadTest, + RunImageGenerationStreamLoadTest, + RunImageGenerationCacheLoadTest, } // Execute all test scenarios diff --git a/core/internal/testutil/validation_presets.go b/core/internal/testutil/validation_presets.go index c096be72f..2f928210b 100644 --- a/core/internal/testutil/validation_presets.go +++ b/core/internal/testutil/validation_presets.go @@ -3,7 +3,6 @@ package testutil import ( "regexp" - "github.com/maximhq/bifrost/core/schemas" ) @@ -186,6 +185,23 @@ func TranscriptionExpectations(minTextLength int) ResponseExpectations { } } +// ImageGenerationExpectations returns validation expectations for image generation scenarios +func ImageGenerationExpectations(minImages int, expectedSize string) ResponseExpectations { + return ResponseExpectations{ + ShouldHaveContent: false, // Image responses don't have text content + ExpectedChoiceCount: 0, // Image responses don't have choices + ShouldHaveUsageStats: true, + ShouldHaveTimestamps: true, + ShouldHaveModel: true, + ShouldHaveLatency: true, // Global expectation: latency should always be present + ProviderSpecific: map[string]interface{}{ + "min_images": minImages, + "expected_size": expectedSize, + "response_type": 
"image_generation", + }, + } +} + // ReasoningExpectations returns validation expectations for reasoning scenarios func ReasoningExpectations() ResponseExpectations { return ResponseExpectations{ @@ -286,6 +302,14 @@ func GetExpectationsForScenario(scenarioName string, testConfig ComprehensiveTes expectations.ShouldContainKeywords = []string{"unique", "specific", "capability"} return expectations + case "ImageGeneration": + if minImages, ok := customParams["min_images"].(int); ok { + if expectedSize, ok := customParams["expected_size"].(string); ok { + return ImageGenerationExpectations(minImages, expectedSize) + } + } + return ImageGenerationExpectations(1, "1024x1024") + default: // Default to basic chat expectations return BasicChatExpectations() diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index aa1dfa1fd..25900e3a2 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -629,7 +629,7 @@ func HandleAnthropicChatCompletionStreaming( response.ExtraFields.RawResponse = eventData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) } if isLastChunk { @@ -659,7 +659,7 @@ func HandleAnthropicChatCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) } }() @@ -978,10 +978,10 @@ func 
HandleAnthropicResponsesStream( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) } } @@ -1023,3 +1023,13 @@ func (provider *AnthropicProvider) Transcription(ctx context.Context, key schema func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Anthropic provider. +func (provider *AnthropicProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Anthropic provider. 
+func (provider *AnthropicProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 74bb7d72f..7a1cf3429 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/providers/openai" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -783,6 +784,136 @@ func (provider *AzureProvider) TranscriptionStream(ctx context.Context, postHook return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } +// ImageGeneration performs an Image Generation request to Azure's API. +// It formats the request, sends it to Azure, and processes the response. +// Returns a BifrostResponse containing the bifrost response or an error if the request fails. +func (provider *AzureProvider) ImageGeneration(ctx context.Context, key schemas.Key, + request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + if err := provider.validateKeyConfig(key); err != nil { + return nil, err + } + + // Convert bifrost request to Azure format. 
+ azureRequest := ToAzureImageRequest(request) + if azureRequest == nil { + return nil, providerUtils.NewBifrostOperationError("invalid request: input is required", nil, provider.GetProviderKey()) + } + + // Make request + providerName := provider.GetProviderKey() + + jsonData, err := sonic.Marshal(azureRequest) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("could not serialize azure image generation request", err, providerName) + } + + deployment, bifrostErr := provider.getModelDeployment(key, azureRequest.Model) + if bifrostErr != nil { + return nil, bifrostErr + } + + response, _, latency, bifrostErr := provider.completeRequest( + ctx, + jsonData, + fmt.Sprintf("openai/deployments/%s/images/generations", deployment), + key, + deployment, + request.Model, + schemas.ImageGenerationRequest, + ) + // Handle error response + if bifrostErr != nil { + return nil, bifrostErr + } + azureResponse := &AzureImageResponse{} + if err := sonic.Unmarshal(response, azureResponse); err != nil { + return nil, providerUtils.NewBifrostOperationError("failed to unmarshal Azure image response", err, schemas.Azure) + } + + // Convert Azure response to Bifrost format. 
+ bifrostResp := ToBifrostImageResponse(azureResponse, request.Model, latency) + bifrostResp.ExtraFields.Provider = provider.GetProviderKey() + bifrostResp.ExtraFields.ModelRequested = request.Model + bifrostResp.ExtraFields.ModelDeployment = deployment + bifrostResp.ExtraFields.RequestType = schemas.ImageGenerationRequest + + // Set raw request if enabled + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequest(&bifrostResp.ExtraFields, jsonData) + } + + // Set raw response if enabled + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostResp.ExtraFields.RawResponse = string(response) + } + + return bifrostResp, nil +} + +// ImageGenerationStream performs a streaming image generation request to Azure's API. +// It formats the request, sends it to Azure, and processes the response. +// Returns a channel of BifrostStream objects or an error if the request fails. +func (provider *AzureProvider) ImageGenerationStream( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + key schemas.Key, + request *schemas.BifrostImageGenerationRequest, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { + + // Validate api key configs + if err := provider.validateKeyConfig(key); err != nil { + return nil, err + } + + // + deployment := key.AzureKeyConfig.Deployments[request.Model] + if deployment == "" { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) + } + + apiVersion := key.AzureKeyConfig.APIVersion + if apiVersion == nil { + apiVersion = schemas.Ptr(AzureAPIVersionDefault) + } + + url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", key.AzureKeyConfig.Endpoint, deployment, *apiVersion) + + // Prepare Azure-specific headers + authHeader := make(map[string]string) + + // Set Azure authentication - either Bearer token or api-key + if authToken, ok 
:= ctx.Value(AzureAuthorizationTokenKey).(string); ok { + authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken) + } else { + authHeader["api-key"] = key.Value + } + + if !openai.StreamingEnabledImageModels[request.Model] { + return nil, providerUtils.NewBifrostOperationError( + fmt.Sprintf("%s is not supported for streaming image generation", request.Model), + nil, + provider.GetProviderKey()) + } + + // Azure is OpenAI-compatible + return openai.HandleOpenAIImageGenerationStreaming( + ctx, + provider.client, + url, + request, + authHeader, + provider.networkConfig.ExtraHeaders, + providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + provider.GetProviderKey(), + postHookRunner, + nil, + nil, nil, + provider.logger, + ) + +} + // validateKeyConfig validates the key configuration. // It checks if the key config is set, the endpoint is set, and the deployments are set. // Returns an error if any of the checks fail. 
diff --git a/core/providers/azure/images.go b/core/providers/azure/images.go new file mode 100644 index 000000000..0185eb298 --- /dev/null +++ b/core/providers/azure/images.go @@ -0,0 +1,73 @@ +package azure + +import ( + "time" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" +) + +// ToAzureImageRequest converts a Bifrost Image Request to Azure format +func ToAzureImageRequest(bifrostReq *schemas.BifrostImageGenerationRequest) *AzureImageRequest { + if bifrostReq == nil || bifrostReq.Input == nil { + return nil + } + + req := &AzureImageRequest{ + Model: bifrostReq.Model, + Prompt: bifrostReq.Input.Prompt, + } + + mapImageParams(bifrostReq.Params, req) + return req +} + +// This function maps Image generation parameters from a Bifrost Request to Azure format +func mapImageParams(p *schemas.ImageGenerationParameters, req *AzureImageRequest) { + if p == nil { + return + } + req.N = p.N + req.Size = p.Size + req.Quality = p.Quality + req.Style = p.Style + req.ResponseFormat = p.ResponseFormat + req.User = p.User +} + +// ToBifrostImageResponse converts an Azure Image Response to Bifrost format +func ToBifrostImageResponse(azureResponse *AzureImageResponse, requestModel string, latency time.Duration) *schemas.BifrostImageGenerationResponse { + if azureResponse == nil { + return nil + } + + data := make([]schemas.ImageData, len(azureResponse.Data)) + for i, img := range azureResponse.Data { + data[i] = schemas.ImageData{ + URL: img.URL, + B64JSON: img.B64JSON, + RevisedPrompt: img.RevisedPrompt, + Index: i, + } + } + + var usage *schemas.ImageUsage + if azureResponse.Usage != nil { + usage = &schemas.ImageUsage{ + PromptTokens: azureResponse.Usage.InputTokens, + TotalTokens: azureResponse.Usage.TotalTokens, + } + } + + return &schemas.BifrostImageGenerationResponse{ + ID: uuid.NewString(), + Created: azureResponse.Created, + Model: requestModel, + Data: data, + Usage: usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: 
schemas.Azure, + Latency: latency.Milliseconds(), + }, + } +} diff --git a/core/providers/azure/types.go b/core/providers/azure/types.go index af113507b..1e3712b2d 100644 --- a/core/providers/azure/types.go +++ b/core/providers/azure/types.go @@ -33,3 +33,37 @@ type AzureListModelsResponse struct { Object string `json:"object"` Data []AzureModel `json:"data"` } + +// AzureImageRequest is the struct for Image Generation requests by Azure. +type AzureImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N *int `json:"n,omitempty"` + Size *string `json:"size,omitempty"` + Quality *string `json:"quality,omitempty"` + Style *string `json:"style,omitempty"` + ResponseFormat *string `json:"response_format,omitempty"` + User *string `json:"user,omitempty"` +} + +// AzureImageResponse is the struct for Image Generation responses by Azure. +type AzureImageResponse struct { + Created int64 `json:"created"` + Data []struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` + } `json:"data"` + Usage *AzureImageGenerationUsage `json:"usage"` +} + +type AzureImageGenerationUsage struct { + TotalTokens int `json:"total_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + + InputTokensDetails *struct { + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` + } `json:"input_tokens_details,omitempty"` +} diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 9dbcb8a12..9204a7cc8 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -664,7 +664,7 @@ func (provider *BedrockProvider) TextCompletionStream(ctx context.Context, postH }, } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(textResponse, nil, nil, nil, nil), 
responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(textResponse, nil, nil, nil, nil, nil), responseChan) } } }() @@ -891,7 +891,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH response.ExtraFields.RawResponse = string(message.Payload) } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) } } } @@ -905,7 +905,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) }() return responseChan, nil @@ -1063,7 +1063,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil, nil), responseChan) } break } @@ -1157,7 +1157,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu response.ExtraFields.RawResponse = string(message.Payload) } - providerUtils.ProcessAndSendResponse(ctx, 
postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) } } } @@ -1283,6 +1283,16 @@ func (provider *BedrockProvider) TranscriptionStream(ctx context.Context, postHo return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, schemas.Bedrock) } +// ImageGeneration is not supported by the Bedrock provider. +func (provider *BedrockProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Bedrock provider. +func (provider *BedrockProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} + func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) (string, string) { // Format the path with proper model identifier for streaming path := fmt.Sprintf("%s/%s", model, basePath) diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go index 68cb9fa45..00bb066de 100644 --- a/core/providers/cerebras/cerebras.go +++ b/core/providers/cerebras/cerebras.go @@ -210,3 +210,13 @@ func (provider *CerebrasProvider) Transcription(ctx context.Context, key schemas func (provider *CerebrasProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request 
*schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Cerebras provider. +func (provider *CerebrasProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Cerebras provider. +func (provider *CerebrasProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index f5171971a..f2594eaa8 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -483,10 +483,10 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) break } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, 
postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) } } } @@ -725,10 +725,10 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) } } @@ -817,3 +817,13 @@ func (provider *CohereProvider) Transcription(ctx context.Context, key schemas.K func (provider *CohereProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Cohere provider. +func (provider *CohereProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Cohere provider. 
+func (provider *CohereProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index 50e475042..8922d746b 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -416,7 +416,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx context.Context, postHookRu response.ExtraFields.RawResponse = audioChunk } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan) } } @@ -438,7 +438,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx context.Context, postHookRu providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody) } ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, finalResponse, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, finalResponse, nil, nil), responseChan) }() return responseChan, nil @@ -665,6 +665,16 @@ func (provider *ElevenlabsProvider) TranscriptionStream(ctx context.Context, pos return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } +// ImageGeneration is not supported by the Elevenlabs provider. 
+func (provider *ElevenlabsProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Elevenlabs provider. +func (provider *ElevenlabsProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} + // buildSpeechRequestURL constructs the full request URL using the provider's configuration for speech. func (provider *ElevenlabsProvider) buildBaseSpeechRequestURL(ctx context.Context, defaultPath string, requestType schemas.RequestType, request *schemas.BifrostSpeechRequest) string { baseURL := provider.networkConfig.BaseURL diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index fb293a9e8..f33e7ea2e 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -491,12 +491,12 @@ func HandleGeminiChatCompletionStream( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) break } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, 
response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) } } @@ -799,7 +799,7 @@ func HandleGeminiResponsesStream( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) return } @@ -809,7 +809,7 @@ func HandleGeminiResponsesStream( } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) } } } @@ -851,7 +851,7 @@ func HandleGeminiResponsesStream( } ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil, nil), responseChan) } } }() @@ -1112,7 +1112,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, 
postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan) } } @@ -1138,7 +1138,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan) } }() @@ -1382,7 +1382,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo } // Process response through post-hooks and send to channel - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) } } @@ -1414,13 +1414,23 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) } }() return responseChan, nil } +// ImageGeneration is not supported by the Gemini provider. 
+func (provider *GeminiProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Gemini provider. +func (provider *GeminiProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} + // processGeminiStreamChunk processes a single chunk from Gemini streaming response func processGeminiStreamChunk(jsonData string) (*GenerateContentResponse, error) { // First, check if this is an error response diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index c28386ff2..5246ef625 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -248,3 +248,13 @@ func (provider *GroqProvider) Transcription(ctx context.Context, key schemas.Key func (provider *GroqProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Groq provider. 
+func (provider *GroqProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Groq provider. +func (provider *GroqProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 6e093f710..83a8315e4 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -256,3 +256,13 @@ func (provider *MistralProvider) Transcription(ctx context.Context, key schemas. func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Mistral provider. +func (provider *MistralProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Mistral provider. 
+func (provider *MistralProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index 5fdd00406..d3c427521 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -222,3 +222,13 @@ func (provider *OllamaProvider) Transcription(ctx context.Context, key schemas.K func (provider *OllamaProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Ollama provider. +func (provider *OllamaProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Ollama provider. 
+func (provider *OllamaProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/openai/images.go b/core/providers/openai/images.go new file mode 100644 index 000000000..fd93ef57c --- /dev/null +++ b/core/providers/openai/images.go @@ -0,0 +1,91 @@ +package openai + +import ( + "time" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" +) + +const ( + ImageGenerationPartial ImageGenerationEventType = "image_generation.partial_image" + ImageGenerationCompleted ImageGenerationEventType = "image_generation.completed" + + // ImageGenerationChunkSize is the size of base64 chunks when splitting large image data + ImageGenerationChunkSize = 128 * 1024 +) + +var StreamingEnabledImageModels = map[string]bool{ + "gpt-image-1": true, + "dall-e-2": false, + "dall-e-3": false, +} + +// ToOpenAIImageGenerationRequest converts a Bifrost Image Request to OpenAI format +func ToOpenAIImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRequest) *OpenAIImageGenerationRequest { + if bifrostReq == nil || bifrostReq.Input == nil { + return nil + } + + req := &OpenAIImageGenerationRequest{ + Model: bifrostReq.Model, + Prompt: bifrostReq.Input.Prompt, + } + + if bifrostReq.Params != nil { + req.ImageGenerationParameters = *bifrostReq.Params + } + return req +} + +// ToBifrostImageResponse converts an OpenAI Image Response to Bifrost format +func ToBifrostImageResponse(openaiResponse *OpenAIImageGenerationResponse, requestModel string, latency time.Duration) *schemas.BifrostImageGenerationResponse { + if openaiResponse == nil { + return nil + } + + data := make([]schemas.ImageData, len(openaiResponse.Data)) + for i, img := range openaiResponse.Data { + 
data[i] = schemas.ImageData{ + URL: img.URL, + B64JSON: img.B64JSON, + RevisedPrompt: img.RevisedPrompt, + Index: i, + } + } + + var usage *schemas.ImageUsage + if openaiResponse.Usage != nil { + usage = &schemas.ImageUsage{ + PromptTokens: openaiResponse.Usage.InputTokens, + TotalTokens: openaiResponse.Usage.TotalTokens, + } + } + + return &schemas.BifrostImageGenerationResponse{ + ID: uuid.NewString(), + Created: openaiResponse.Created, + Model: requestModel, + Data: data, + Usage: usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + Latency: latency.Milliseconds(), + }, + } +} + +// ToBifrostImageGenerationRequest converts an OpenAI image generation request to Bifrost format +func (request *OpenAIImageGenerationRequest) ToBifrostImageGenerationRequest() *schemas.BifrostImageGenerationRequest { + provider, model := schemas.ParseModelString(request.Model, schemas.OpenAI) + + return &schemas.BifrostImageGenerationRequest{ + Provider: provider, + Model: model, + Input: &schemas.ImageGenerationInput{ + Prompt: request.Prompt, + }, + Params: &request.ImageGenerationParameters, + Fallbacks: schemas.ParseFallbacks(request.Fallbacks), + } +} diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 5e40db546..4fc00cb20 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -14,6 +14,7 @@ import ( "time" "github.com/bytedance/sonic" + "github.com/google/uuid" providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" @@ -575,7 +576,7 @@ func HandleOpenAITextCompletionStreaming( response.ExtraFields.RawResponse = jsonData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(&response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(&response, nil, nil, nil, nil, nil), 
responseChan) } // For providers that don't send [DONE] marker break on finish_reason @@ -599,7 +600,7 @@ func HandleOpenAITextCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil, nil), responseChan) } }() @@ -980,14 +981,14 @@ func HandleOpenAIChatCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) return } response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan) } } else { if postResponseConverter != nil { @@ -1061,7 +1062,7 @@ func HandleOpenAIChatCompletionStreaming( response.ExtraFields.RawResponse = jsonData } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, &response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, 
&response, nil, nil, nil, nil), responseChan) } // For providers that don't send [DONE] marker break on finish_reason @@ -1086,7 +1087,7 @@ func HandleOpenAIChatCompletionStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan) } }() @@ -1431,14 +1432,14 @@ func HandleOpenAIResponsesStreaming( } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan) return } response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan) } // Handle scanner errors first if err := scanner.Err(); err != nil { @@ -1814,11 +1815,11 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner if response.Usage != nil { response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - 
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan) } // Handle scanner errors @@ -2086,11 +2087,11 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo if response.Usage != nil { response.ExtraFields.Latency = time.Since(startTime).Milliseconds() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan) return } - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response), responseChan) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response, nil), responseChan) } // Handle scanner errors @@ -2151,3 +2152,399 @@ func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiR return nil } + +// ImageGeneration performs an Image Generation request to OpenAI's API. +// It formats the request, sends it to OpenAI, and processes the response. +// Returns a BifrostResponse containing the bifrost response or an error if the request fails. 
+func (provider *OpenAIProvider) ImageGeneration(ctx context.Context, key schemas.Key, + req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil { + return nil, err // Handle error + } + providerKey := provider.GetProviderKey() + openaiRequest := ToOpenAIImageGenerationRequest(req) + if openaiRequest == nil { + return nil, providerUtils.NewBifrostOperationError("invalid request: input is required", nil, providerKey) + } + providerName := provider.GetProviderKey() + + // Create request + httpReq := fasthttp.AcquireRequest() + httpResp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(httpReq) + defer fasthttp.ReleaseResponse(httpResp) + + // Set any extra headers from network config + providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil) + + httpReq.SetRequestURI(provider.buildRequestURL(ctx, "/v1/images/generations", schemas.ImageGenerationRequest)) + httpReq.Header.SetMethod(http.MethodPost) + httpReq.Header.SetContentType("application/json") + if key.Value != "" { + httpReq.Header.Set("Authorization", "Bearer "+key.Value) + } + + // Serialize the request payload + jsonData, err := sonic.Marshal(openaiRequest) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("Error marshalling json for openai image generation", err, schemas.OpenAI) + } + + httpReq.SetBody(jsonData) + + // Make request + latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if httpResp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(httpResp.Body()))) + return nil, ParseOpenAIError(httpResp, schemas.ImageGenerationRequest, providerName, 
openaiRequest.Model) + } + + // Create final response with the image data + openaiResponse := &OpenAIImageGenerationResponse{} + if bifrostErr := sonic.Unmarshal(httpResp.Body(), openaiResponse); bifrostErr != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, bifrostErr, providerName) + } + + bifrostResp := ToBifrostImageResponse(openaiResponse, openaiRequest.Model, latency) + + bifrostResp.ExtraFields.Provider = provider.GetProviderKey() + bifrostResp.ExtraFields.ModelRequested = openaiRequest.Model + bifrostResp.ExtraFields.RequestType = schemas.ImageGenerationRequest + + // Set raw request if enabled + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequest(&bifrostResp.ExtraFields, jsonData) + } + + // Set raw response if enabled + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostResp.ExtraFields.RawResponse = string(httpResp.Body()) + } + + return bifrostResp, nil +} + +// ImageGenerationStream handles streaming for image generation. +// It formats the request body, creates HTTP request, and uses shared streaming logic. +// Returns a channel for streaming responses and any error that occurred. 
+func (provider *OpenAIProvider) ImageGenerationStream( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + key schemas.Key, + request *schemas.BifrostImageGenerationRequest, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { + + if request == nil { + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, provider.GetProviderKey()) + } + + // Check if image generation stream is allowed for this provider + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageGenerationStreamRequest); err != nil { + return nil, err + } + + var authHeader map[string]string + if key.Value != "" { + authHeader = map[string]string{"Authorization": "Bearer " + key.Value} + } + if !StreamingEnabledImageModels[request.Model] { + return nil, providerUtils.NewBifrostOperationError( + fmt.Sprintf("%s is not supported for streaming image generation", request.Model), + nil, + provider.GetProviderKey()) + } + + // Use shared streaming logic + return HandleOpenAIImageGenerationStreaming( + ctx, + provider.client, + provider.buildRequestURL(ctx, "/v1/images/generations", schemas.ImageGenerationStreamRequest), + request, + authHeader, + provider.networkConfig.ExtraHeaders, + providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + provider.GetProviderKey(), + postHookRunner, + nil, + nil, + nil, + provider.logger, + ) +} +func HandleOpenAIImageGenerationStreaming( + ctx context.Context, + client *fasthttp.Client, + url string, + request *schemas.BifrostImageGenerationRequest, + authHeader map[string]string, + extraHeaders map[string]string, + sendBackRawRequest bool, + sendBackRawResponse bool, + providerName schemas.ModelProvider, + postHookRunner schemas.PostHookRunner, + customRequestConverter func(*schemas.BifrostImageGenerationRequest) (any, error), + postRequestConverter 
func(*OpenAIImageGenerationRequest) *OpenAIImageGenerationRequest, + postResponseConverter func(*schemas.BifrostImageGenerationResponse) *schemas.BifrostImageGenerationResponse, + logger schemas.Logger, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { + + // Set headers + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + if authHeader != nil { + // Copy auth header to headers + maps.Copy(headers, authHeader) + } + + jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (any, error) { + if customRequestConverter != nil { + return customRequestConverter(request) + } + reqBody := ToOpenAIImageGenerationRequest(request) + if reqBody != nil { + reqBody.Stream = schemas.Ptr(true) + if postRequestConverter != nil { + reqBody = postRequestConverter(reqBody) + } + } + return reqBody, nil + }, + providerName) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create HTTP request for streaming + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + resp.StreamBody = true + defer fasthttp.ReleaseRequest(req) + + // Updating request + req.Header.SetMethod(http.MethodPost) + req.SetRequestURI(url) + req.Header.SetContentType("application/json") + + // Set any extra headers from network config + providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil) + + // Set headers + for key, value := range headers { + req.Header.Set(key, value) + } + + req.SetBody(jsonBody) + + // Make the request + err := client.Do(req, resp) + if err != nil { + defer providerUtils.ReleaseStreamingResponse(resp) + if errors.Is(err, context.Canceled) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Type: schemas.Ptr(schemas.RequestCancelled), + Message: schemas.ErrRequestCancelled, + Error: err, + }, + } + } + if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { 
+ return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName) + } + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + } + + // Check for HTTP errors + if resp.StatusCode() != fasthttp.StatusOK { + defer providerUtils.ReleaseStreamingResponse(resp) + return nil, ParseOpenAIError(resp, schemas.ImageGenerationStreamRequest, providerName, request.Model) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer providerUtils.ReleaseStreamingResponse(resp) + + scanner := bufio.NewScanner(resp.BodyStream()) + buf := make([]byte, 0, 1024*1024) + scanner.Buffer(buf, 10*1024*1024) + + chunkIndex := -1 + + startTime := time.Now() + lastChunkTime := startTime + + for scanner.Scan() { + // Check if context is done before processing + select { + case <-ctx.Done(): + return + default: + } + + line := scanner.Text() + + // Check for end of stream + if line == "" { + continue + } + + var jsonData string + + // Parse SSE data + if after, ok := strings.CutPrefix(line, "data: "); ok { + jsonData = after + } else { + // Handle raw JSON errors (without "data: " prefix) + jsonData = line + } + + // Skip empty data + if strings.TrimSpace(jsonData) == "" { + continue + } + + // First, check if this is an error response + var bifrostErr schemas.BifrostError + if err := sonic.Unmarshal([]byte(jsonData), &bifrostErr); err == nil { + if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + Provider: providerName, + ModelRequested: request.Model, + RequestType: schemas.ImageGenerationStreamRequest, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, logger) + 
return + } + } + + // Parse into bifrost response + var response *OpenAIImageStreamResponse + if err := sonic.Unmarshal([]byte(jsonData), &response); err != nil { + logger.Warn(fmt.Sprintf("Failed to parse stream response: %v", err)) + continue + } + + // TODO: Track Usage Correctly + // Handle final chunks (when stream_options include_usage is true) + if response.Type == ImageGenerationCompleted && response.Usage != nil { + // Collect usage information and send at the end of the stream + // Usage is contained within the completion message + } + + var chunkType string + switch response.Type { + case ImageGenerationCompleted: + chunkType = string(ImageGenerationCompleted) + case ImageGenerationPartial: + chunkType = string(ImageGenerationPartial) + } + + // Handle image data chunks + if response.B64JSON != nil { + b64Data := *response.B64JSON + chunkSize := ImageGenerationChunkSize + for offset := 0; offset < len(b64Data); offset += chunkSize { + end := offset + chunkSize + if end > len(b64Data) { + end = len(b64Data) + } + + chunkIndex++ + chunk := b64Data[offset:end] + + bifrostChunk := &schemas.BifrostImageGenerationStreamResponse{ + ID: uuid.NewString(), + Index: response.PartialImageIndex, + ChunkIndex: chunkIndex, + PartialB64: chunk, + Type: chunkType, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ImageGenerationStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + }, + } + + // Set raw response if enabled + if sendBackRawResponse { + bifrostChunk.ExtraFields.RawResponse = jsonData + } + + lastChunkTime = time.Now() + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, bifrostChunk), responseChan) + } + } + + // Handle completion chunk + if response.Type == ImageGenerationCompleted { + // Ensure completion chunk has a strictly increasing index. 
+ chunkIndex++ + completionChunk := &schemas.BifrostImageGenerationStreamResponse{ + ID: uuid.NewString(), + Index: response.PartialImageIndex, + ChunkIndex: chunkIndex, + Type: string(ImageGenerationCompleted), + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ImageGenerationStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), + }, + } + if response.Usage != nil { + completionChunk.Usage = &schemas.ImageUsage{ + PromptTokens: response.Usage.InputTokens, + TotalTokens: response.Usage.TotalTokens, + } + } + + // Set raw request if enabled (only on last chunk) + if sendBackRawRequest { + providerUtils.ParseAndSetRawRequest(&completionChunk.ExtraFields, jsonBody) + } + + // Set raw response if enabled + if sendBackRawResponse { + completionChunk.ExtraFields.RawResponse = jsonData + } + + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, + providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, completionChunk), + responseChan) + return // Stop processing after completion + } + } + + // Handle scanner errors + if err := scanner.Err(); err != nil { + logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger) + } + }() + + return responseChan, nil +} diff --git a/core/providers/openai/openai_image_test.go b/core/providers/openai/openai_image_test.go new file mode 100644 index 000000000..e38a32dd9 --- /dev/null +++ b/core/providers/openai/openai_image_test.go @@ -0,0 +1,420 @@ +package openai_test + +import ( + "strings" + "testing" + "time" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/providers/openai" + "github.com/maximhq/bifrost/core/schemas" +) + +// 
TestImageGenerationStreamingRequestConversion +func TestImageGenerationStreamingRequestConversion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + request *schemas.BifrostImageGenerationRequest + wantNil bool + validate func(t *testing.T, req *openai.OpenAIImageGenerationRequest) + }{ + { + name: "all parameters mapped", + request: &schemas.BifrostImageGenerationRequest{ + Provider: schemas.OpenAI, + Model: "dall-e-3", + Input: &schemas.ImageGenerationInput{Prompt: "test prompt"}, + Params: &schemas.ImageGenerationParameters{ + N: schemas.Ptr(3), + Size: schemas.Ptr("1024x1792"), + Quality: schemas.Ptr("hd"), + Style: schemas.Ptr("natural"), + ResponseFormat: schemas.Ptr("b64_json"), + User: schemas.Ptr("user-123"), + }, + }, + wantNil: false, + validate: func(t *testing.T, req *openai.OpenAIImageGenerationRequest) { + if req.Model != "dall-e-3" { + t.Errorf("Model mismatch") + } + if req.Prompt != "test prompt" { + t.Errorf("Prompt mismatch") + } + if *req.N != 3 { + t.Errorf("N mismatch: expected 3, got %d", *req.N) + } + if *req.Size != "1024x1792" { + t.Errorf("Size mismatch") + } + if *req.Quality != "hd" { + t.Errorf("Quality mismatch") + } + if *req.Style != "natural" { + t.Errorf("Style mismatch") + } + if *req.ResponseFormat != "b64_json" { + t.Errorf("ResponseFormat mismatch") + } + if *req.User != "user-123" { + t.Errorf("User mismatch") + } + }, + }, + { + name: "nil request returns nil", + request: nil, + wantNil: true, + }, + { + name: "nil input returns nil", + request: &schemas.BifrostImageGenerationRequest{ + Model: "dall-e-3", + Input: nil, + }, + wantNil: true, + }, + { + name: "nil params still works", + request: &schemas.BifrostImageGenerationRequest{ + Model: "dall-e-2", + Input: &schemas.ImageGenerationInput{Prompt: "minimal"}, + Params: nil, + }, + wantNil: false, + validate: func(t *testing.T, req *openai.OpenAIImageGenerationRequest) { + if req.N != nil || req.Size != nil { + t.Errorf("Optional params should be nil") 
+ } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := openai.ToOpenAIImageGenerationRequest(tt.request) + if (got == nil) != tt.wantNil { + t.Errorf("ToOpenAIImageGenerationRequest() nil = %v, want %v", got == nil, tt.wantNil) + return + } + if !tt.wantNil && tt.validate != nil { + tt.validate(t, got) + } + }) + } +} + +// TestOpenAIRequestJSONOutput tests that OpenAI request serializes to correct JSON +func TestOpenAIRequestJSONOutput(t *testing.T) { + t.Parallel() + + req := openai.ToOpenAIImageGenerationRequest(&schemas.BifrostImageGenerationRequest{ + Model: "gpt-image-1", + Input: &schemas.ImageGenerationInput{Prompt: "a cat"}, + Params: &schemas.ImageGenerationParameters{ + Size: schemas.Ptr("1024x1024"), + Quality: schemas.Ptr("auto"), + }, + }) + + jsonBytes, err := sonic.Marshal(req) + if err != nil { + t.Fatalf("Serialization failed: %v", err) + } + jsonStr := string(jsonBytes) + + // Verify JSON structure matches OpenAI API + if !strings.Contains(jsonStr, `"model":"gpt-image-1"`) { + t.Errorf("JSON should contain model field") + } + if !strings.Contains(jsonStr, `"prompt":"a cat"`) { + t.Errorf("JSON should contain prompt field") + } + if !strings.Contains(jsonStr, `"size":"1024x1024"`) { + t.Errorf("JSON should contain size field") + } +} + +// ============================================================================= +// 3. 
RESPONSE TRANSFORMATION (OpenAI → Bifrost) +// ============================================================================= + +// TestToBifrostImageResponse tests OpenAI to Bifrost response conversion +func TestToBifrostImageResponse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + openai *openai.OpenAIImageGenerationResponse + model string + latency time.Duration + validate func(t *testing.T, resp *schemas.BifrostImageGenerationResponse) + }{ + { + name: "full response converts correctly", + openai: &openai.OpenAIImageGenerationResponse{ + Created: 1699999999, + Data: []struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` + }{ + {URL: "https://example.com/1.png", RevisedPrompt: "revised prompt 1"}, + {B64JSON: "base64data", RevisedPrompt: "revised prompt 2"}, + }, + Usage: &openai.OpenAIImageGenerationUsage{ + InputTokens: 10, + TotalTokens: 50, + }, + }, + model: "dall-e-3", + latency: 500 * time.Millisecond, + validate: func(t *testing.T, resp *schemas.BifrostImageGenerationResponse) { + if resp.Created != 1699999999 { + t.Errorf("Created mismatch") + } + if len(resp.Data) != 2 { + t.Errorf("Expected 2 images, got %d", len(resp.Data)) + } + if resp.Data[0].URL != "https://example.com/1.png" { + t.Errorf("URL mismatch") + } + if resp.Data[0].Index != 0 { + t.Errorf("First image index should be 0") + } + if resp.Data[1].B64JSON != "base64data" { + t.Errorf("B64JSON mismatch") + } + if resp.Data[1].Index != 1 { + t.Errorf("Second image index should be 1") + } + if resp.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens should be mapped from InputTokens") + } + }, + }, + { + name: "nil response returns nil", + openai: nil, + validate: func(t *testing.T, resp *schemas.BifrostImageGenerationResponse) { + if resp != nil { + t.Errorf("Expected nil response") + } + }, + }, + { + name: "nil usage is preserved", + openai: 
&openai.OpenAIImageGenerationResponse{ + Created: 123, + Data: []struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` + }{}, + Usage: nil, + }, + validate: func(t *testing.T, resp *schemas.BifrostImageGenerationResponse) { + if resp.Usage != nil { + t.Errorf("Usage should be nil") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + resp := openai.ToBifrostImageResponse(tt.openai, tt.model, tt.latency) + tt.validate(t, resp) + }) + } +} + +// TestToBifrostImageGenerationRequest tests OpenAI to Bifrost request conversion +func TestToBifrostImageGenerationRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + request *openai.OpenAIImageGenerationRequest + validate func(t *testing.T, req *schemas.BifrostImageGenerationRequest) + }{ + { + name: "full request with all parameters converts correctly", + request: &openai.OpenAIImageGenerationRequest{ + Model: "openai/dall-e-3", + Prompt: "a beautiful sunset", + ImageGenerationParameters: schemas.ImageGenerationParameters{ + N: schemas.Ptr(2), + Size: schemas.Ptr("1024x1792"), + Quality: schemas.Ptr("hd"), + Style: schemas.Ptr("vivid"), + ResponseFormat: schemas.Ptr("b64_json"), + User: schemas.Ptr("user-123"), + }, + Fallbacks: []string{"azure/dall-e-3", "openai/dall-e-2"}, + }, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if req == nil { + t.Fatal("Expected non-nil request") + } + if req.Provider != schemas.OpenAI { + t.Errorf("Provider mismatch: expected %s, got %s", schemas.OpenAI, req.Provider) + } + if req.Model != "dall-e-3" { + t.Errorf("Model mismatch: expected dall-e-3, got %s", req.Model) + } + if req.Input == nil { + t.Fatal("Expected non-nil Input") + } + if req.Input.Prompt != "a beautiful sunset" { + t.Errorf("Prompt mismatch: expected 'a beautiful sunset', got %s", req.Input.Prompt) + } + if req.Params == nil 
{ + t.Fatal("Expected non-nil Params") + } + if *req.Params.N != 2 { + t.Errorf("N mismatch: expected 2, got %d", *req.Params.N) + } + if *req.Params.Size != "1024x1792" { + t.Errorf("Size mismatch: expected 1024x1792, got %s", *req.Params.Size) + } + if *req.Params.Quality != "hd" { + t.Errorf("Quality mismatch: expected hd, got %s", *req.Params.Quality) + } + if *req.Params.Style != "vivid" { + t.Errorf("Style mismatch: expected vivid, got %s", *req.Params.Style) + } + if *req.Params.ResponseFormat != "b64_json" { + t.Errorf("ResponseFormat mismatch: expected b64_json, got %s", *req.Params.ResponseFormat) + } + if *req.Params.User != "user-123" { + t.Errorf("User mismatch: expected user-123, got %s", *req.Params.User) + } + if len(req.Fallbacks) != 2 { + t.Errorf("Expected 2 fallbacks, got %d", len(req.Fallbacks)) + } + if req.Fallbacks[0].Provider != schemas.Azure || req.Fallbacks[0].Model != "dall-e-3" { + t.Errorf("First fallback mismatch: expected azure/dall-e-3, got %s/%s", req.Fallbacks[0].Provider, req.Fallbacks[0].Model) + } + if req.Fallbacks[1].Provider != schemas.OpenAI || req.Fallbacks[1].Model != "dall-e-2" { + t.Errorf("Second fallback mismatch: expected openai/dall-e-2, got %s/%s", req.Fallbacks[1].Provider, req.Fallbacks[1].Model) + } + }, + }, + { + name: "model without provider prefix defaults to OpenAI", + request: &openai.OpenAIImageGenerationRequest{ + Model: "dall-e-2", + Prompt: "minimal prompt", + }, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if req.Provider != schemas.OpenAI { + t.Errorf("Provider should default to OpenAI, got %s", req.Provider) + } + if req.Model != "dall-e-2" { + t.Errorf("Model mismatch: expected dall-e-2, got %s", req.Model) + } + if req.Input.Prompt != "minimal prompt" { + t.Errorf("Prompt mismatch") + } + }, + }, + { + name: "request with nil params still works", + request: &openai.OpenAIImageGenerationRequest{ + Model: "gpt-image-1", + Prompt: "test prompt", + 
ImageGenerationParameters: schemas.ImageGenerationParameters{}, + }, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if req.Params == nil { + t.Fatal("Expected non-nil Params even when empty") + } + if req.Params.N != nil { + t.Errorf("N should be nil when not set") + } + if req.Params.Size != nil { + t.Errorf("Size should be nil when not set") + } + }, + }, + { + name: "request with empty fallbacks", + request: &openai.OpenAIImageGenerationRequest{ + Model: "dall-e-3", + Prompt: "test", + Fallbacks: []string{}, + }, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if len(req.Fallbacks) != 0 { + t.Errorf("Expected empty fallbacks, got %d", len(req.Fallbacks)) + } + }, + }, + { + name: "request with nil fallbacks", + request: &openai.OpenAIImageGenerationRequest{ + Model: "dall-e-3", + Prompt: "test", + Fallbacks: nil, + }, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if len(req.Fallbacks) != 0 { + t.Errorf("Expected nil or empty fallbacks, got %d", len(req.Fallbacks)) + } + }, + }, + { + name: "request with partial parameters", + request: &openai.OpenAIImageGenerationRequest{ + Model: "dall-e-3", + Prompt: "partial params", + ImageGenerationParameters: schemas.ImageGenerationParameters{ + Size: schemas.Ptr("1024x1024"), + Quality: schemas.Ptr("standard"), + }, + }, + validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if req.Params.Size == nil || *req.Params.Size != "1024x1024" { + t.Errorf("Size should be preserved") + } + if req.Params.Quality == nil || *req.Params.Quality != "standard" { + t.Errorf("Quality should be preserved") + } + if req.Params.N != nil { + t.Errorf("N should be nil when not set") + } + if req.Params.Style != nil { + t.Errorf("Style should be nil when not set") + } + }, + }, + { + name: "azure provider prefix in model", + request: &openai.OpenAIImageGenerationRequest{ + Model: "azure/dall-e-3", + Prompt: "azure model", + }, 
+ validate: func(t *testing.T, req *schemas.BifrostImageGenerationRequest) { + if req.Provider != schemas.Azure { + t.Errorf("Provider should be Azure, got %s", req.Provider) + } + if req.Model != "dall-e-3" { + t.Errorf("Model should be dall-e-3, got %s", req.Model) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := tt.request.ToBifrostImageGenerationRequest() + tt.validate(t, req) + }) + } +} diff --git a/core/providers/openai/openai_test.go b/core/providers/openai/openai_test.go index 31a0fa02e..3d4be6952 100644 --- a/core/providers/openai/openai_test.go +++ b/core/providers/openai/openai_test.go @@ -38,6 +38,7 @@ func TestOpenAI(t *testing.T) { }, SpeechSynthesisModel: "gpt-4o-mini-tts", ReasoningModel: "o1", + ImageGenerationModel: "dall-e-3", Scenarios: testutil.TestScenarios{ TextCompletion: true, TextCompletionStream: true, @@ -60,6 +61,8 @@ func TestOpenAI(t *testing.T) { Embedding: true, Reasoning: true, ListModels: true, + ImageGeneration: true, + ImageGenerationStream: true, }, } diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 2cab76756..82fd10561 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -287,3 +287,50 @@ type OpenAIListModelsResponse struct { Object string `json:"object"` Data []OpenAIModel `json:"data"` } + +type ImageGenerationEventType string + +// OpenAIImageGenerationRequest is the struct for Image Generation requests by OpenAI. +type OpenAIImageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + + schemas.ImageGenerationParameters + + Stream *bool `json:"stream,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` +} + +// OpenAIImageGenerationResponse is the struct for Image Generation responses by OpenAI. 
+type OpenAIImageGenerationResponse struct { + Created int64 `json:"created"` + Data []struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` + } `json:"data"` + Background *string `json:"background,omitempty"` + OutputFormat *string `json:"output_format,omitempty"` + Size *string `json:"size,omitempty"` + Quality *string `json:"quality,omitempty"` + Usage *OpenAIImageGenerationUsage `json:"usage"` +} + +type OpenAIImageGenerationUsage struct { + TotalTokens int `json:"total_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + + InputTokensDetails *struct { + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` + } `json:"input_tokens_details,omitempty"` +} + +// OpenAIImageStreamResponse is the struct for Image Generation streaming responses by OpenAI. +type OpenAIImageStreamResponse struct { + Type ImageGenerationEventType `json:"type,omitempty"` + B64JSON *string `json:"b64_json,omitempty"` + PartialImageIndex int `json:"partial_image_index,omitempty"` + Usage *OpenAIImageGenerationUsage `json:"usage,omitempty"` +} diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 1271e26cc..d263d60c7 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -281,3 +281,13 @@ func (provider *OpenRouterProvider) Transcription(ctx context.Context, key schem func (provider *OpenRouterProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the OpenRouter provider. 
+func (provider *OpenRouterProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the OpenRouter provider. +func (provider *OpenRouterProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index df7764ec2..a20c7b19c 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -181,3 +181,13 @@ func (provider *ParasailProvider) Transcription(ctx context.Context, key schemas func (provider *ParasailProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Parasail provider. +func (provider *ParasailProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Parasail provider. 
+func (provider *ParasailProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index e77878e49..f700f7e8d 100644 --- a/core/providers/perplexity/perplexity.go +++ b/core/providers/perplexity/perplexity.go @@ -251,3 +251,13 @@ func (provider *PerplexityProvider) Transcription(ctx context.Context, key schem func (provider *PerplexityProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the Perplexity provider. +func (provider *PerplexityProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Perplexity provider. 
+func (provider *PerplexityProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index af33fc720..a0b024462 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -219,3 +219,13 @@ func (provider *SGLProvider) Transcription(ctx context.Context, key schemas.Key, func (provider *SGLProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } + +// ImageGeneration is not supported by the SGL provider. +func (provider *SGLProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the SGL provider. 
+func (provider *SGLProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index ba86f5cb6..ac83f9374 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -629,6 +629,7 @@ func ProcessAndSendResponse( streamResponse.BifrostResponsesStreamResponse = processedResponse.ResponsesStreamResponse streamResponse.BifrostSpeechStreamResponse = processedResponse.SpeechStreamResponse streamResponse.BifrostTranscriptionStreamResponse = processedResponse.TranscriptionStreamResponse + streamResponse.BifrostImageGenerationStreamResponse = processedResponse.ImageGenerationStreamResponse } if processedError != nil { streamResponse.BifrostError = processedError @@ -857,6 +858,7 @@ func GetBifrostResponseForStreamResponse( responsesStreamResponse *schemas.BifrostResponsesStreamResponse, speechStreamResponse *schemas.BifrostSpeechStreamResponse, transcriptionStreamResponse *schemas.BifrostTranscriptionStreamResponse, + imageGenerationStreamResponse *schemas.BifrostImageGenerationStreamResponse, ) *schemas.BifrostResponse { //TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{} @@ -877,6 +879,9 @@ func GetBifrostResponseForStreamResponse( case transcriptionStreamResponse != nil: bifrostResponse.TranscriptionStreamResponse = transcriptionStreamResponse return bifrostResponse + case imageGenerationStreamResponse != nil: + bifrostResponse.ImageGenerationStreamResponse = imageGenerationStreamResponse + return bifrostResponse } return nil } diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 569feb49f..c2c745c4d 100644 --- 
a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -1406,6 +1406,16 @@ func (provider *VertexProvider) TranscriptionStream(ctx context.Context, postHoo return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } +// ImageGeneration is not supported by the Vertex provider. +func (provider *VertexProvider) ImageGeneration(ctx context.Context, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +// ImageGenerationStream is not supported by the Vertex provider. +func (provider *VertexProvider) ImageGenerationStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} + // stripVertexGeminiUnsupportedFields removes fields that are not supported by Vertex AI's Gemini API. // Specifically, it removes the "id" field from function_call and function_response objects in contents. 
func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequest) { @@ -1415,7 +1425,6 @@ func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequ if part.FunctionCall != nil { part.FunctionCall.ID = "" } - // Remove id from function_response if part.FunctionResponse != nil { part.FunctionResponse.ID = "" diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 7b587480a..2dc23ad9b 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -83,18 +83,20 @@ var StandardProviders = []ModelProvider{ type RequestType string const ( - ListModelsRequest RequestType = "list_models" - TextCompletionRequest RequestType = "text_completion" - TextCompletionStreamRequest RequestType = "text_completion_stream" - ChatCompletionRequest RequestType = "chat_completion" - ChatCompletionStreamRequest RequestType = "chat_completion_stream" - ResponsesRequest RequestType = "responses" - ResponsesStreamRequest RequestType = "responses_stream" - EmbeddingRequest RequestType = "embedding" - SpeechRequest RequestType = "speech" - SpeechStreamRequest RequestType = "speech_stream" - TranscriptionRequest RequestType = "transcription" - TranscriptionStreamRequest RequestType = "transcription_stream" + ListModelsRequest RequestType = "list_models" + TextCompletionRequest RequestType = "text_completion" + TextCompletionStreamRequest RequestType = "text_completion_stream" + ChatCompletionRequest RequestType = "chat_completion" + ChatCompletionStreamRequest RequestType = "chat_completion_stream" + ResponsesRequest RequestType = "responses" + ResponsesStreamRequest RequestType = "responses_stream" + EmbeddingRequest RequestType = "embedding" + SpeechRequest RequestType = "speech" + SpeechStreamRequest RequestType = "speech_stream" + TranscriptionRequest RequestType = "transcription" + TranscriptionStreamRequest RequestType = "transcription_stream" + ImageGenerationRequest RequestType = "image_generation" + 
ImageGenerationStreamRequest RequestType = "image_generation_stream" ) // BifrostContextKey is a type for context keys used in Bifrost. @@ -143,17 +145,19 @@ type Fallback struct { // - EmbeddingRequest // - SpeechRequest // - TranscriptionRequest +// - ImageGenerationRequest // NOTE: Bifrost Request is submitted back to pool after every use so DO NOT keep references to this struct after use, especially in go routines. type BifrostRequest struct { RequestType RequestType - ListModelsRequest *BifrostListModelsRequest - TextCompletionRequest *BifrostTextCompletionRequest - ChatRequest *BifrostChatRequest - ResponsesRequest *BifrostResponsesRequest - EmbeddingRequest *BifrostEmbeddingRequest - SpeechRequest *BifrostSpeechRequest - TranscriptionRequest *BifrostTranscriptionRequest + ListModelsRequest *BifrostListModelsRequest + TextCompletionRequest *BifrostTextCompletionRequest + ChatRequest *BifrostChatRequest + ResponsesRequest *BifrostResponsesRequest + EmbeddingRequest *BifrostEmbeddingRequest + SpeechRequest *BifrostSpeechRequest + TranscriptionRequest *BifrostTranscriptionRequest + ImageGenerationRequest *BifrostImageGenerationRequest } // GetRequestFields returns the provider, model, and fallbacks from the request. 
@@ -171,6 +175,8 @@ func (br *BifrostRequest) GetRequestFields() (provider ModelProvider, model stri return br.SpeechRequest.Provider, br.SpeechRequest.Model, br.SpeechRequest.Fallbacks case br.TranscriptionRequest != nil: return br.TranscriptionRequest.Provider, br.TranscriptionRequest.Model, br.TranscriptionRequest.Fallbacks + case br.ImageGenerationRequest != nil: + return br.ImageGenerationRequest.Provider, br.ImageGenerationRequest.Model, br.ImageGenerationRequest.Fallbacks } return "", "", nil @@ -190,6 +196,8 @@ func (br *BifrostRequest) SetProvider(provider ModelProvider) { br.SpeechRequest.Provider = provider case br.TranscriptionRequest != nil: br.TranscriptionRequest.Provider = provider + case br.ImageGenerationRequest != nil: + br.ImageGenerationRequest.Provider = provider } } @@ -207,6 +215,8 @@ func (br *BifrostRequest) SetModel(model string) { br.SpeechRequest.Model = model case br.TranscriptionRequest != nil: br.TranscriptionRequest.Model = model + case br.ImageGenerationRequest != nil: + br.ImageGenerationRequest.Model = model } } @@ -224,6 +234,8 @@ func (br *BifrostRequest) SetFallbacks(fallbacks []Fallback) { br.SpeechRequest.Fallbacks = fallbacks case br.TranscriptionRequest != nil: br.TranscriptionRequest.Fallbacks = fallbacks + case br.ImageGenerationRequest != nil: + br.ImageGenerationRequest.Fallbacks = fallbacks } } @@ -241,6 +253,8 @@ func (br *BifrostRequest) SetRawRequestBody(rawRequestBody []byte) { br.SpeechRequest.RawRequestBody = rawRequestBody case br.TranscriptionRequest != nil: br.TranscriptionRequest.RawRequestBody = rawRequestBody + case br.ImageGenerationRequest != nil: + br.ImageGenerationRequest.RawRequestBody = rawRequestBody } } @@ -248,15 +262,17 @@ func (br *BifrostRequest) SetRawRequestBody(rawRequestBody []byte) { // BifrostResponse represents the complete result from any bifrost request. 
type BifrostResponse struct { - TextCompletionResponse *BifrostTextCompletionResponse - ChatResponse *BifrostChatResponse - ResponsesResponse *BifrostResponsesResponse - ResponsesStreamResponse *BifrostResponsesStreamResponse - EmbeddingResponse *BifrostEmbeddingResponse - SpeechResponse *BifrostSpeechResponse - SpeechStreamResponse *BifrostSpeechStreamResponse - TranscriptionResponse *BifrostTranscriptionResponse - TranscriptionStreamResponse *BifrostTranscriptionStreamResponse + TextCompletionResponse *BifrostTextCompletionResponse + ChatResponse *BifrostChatResponse + ResponsesResponse *BifrostResponsesResponse + ResponsesStreamResponse *BifrostResponsesStreamResponse + EmbeddingResponse *BifrostEmbeddingResponse + SpeechResponse *BifrostSpeechResponse + SpeechStreamResponse *BifrostSpeechStreamResponse + TranscriptionResponse *BifrostTranscriptionResponse + TranscriptionStreamResponse *BifrostTranscriptionStreamResponse + ImageGenerationResponse *BifrostImageGenerationResponse + ImageGenerationStreamResponse *BifrostImageGenerationStreamResponse } func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { @@ -279,6 +295,10 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { return &r.TranscriptionResponse.ExtraFields case r.TranscriptionStreamResponse != nil: return &r.TranscriptionStreamResponse.ExtraFields + case r.ImageGenerationResponse != nil: + return &r.ImageGenerationResponse.ExtraFields + case r.ImageGenerationStreamResponse != nil: + return &r.ImageGenerationStreamResponse.ExtraFields } return &BifrostResponseExtraFields{} @@ -326,6 +346,7 @@ type BifrostStream struct { *BifrostResponsesStreamResponse *BifrostSpeechStreamResponse *BifrostTranscriptionStreamResponse + *BifrostImageGenerationStreamResponse *BifrostError } @@ -342,6 +363,8 @@ func (bs BifrostStream) MarshalJSON() ([]byte, error) { return sonic.Marshal(bs.BifrostSpeechStreamResponse) } else if bs.BifrostTranscriptionStreamResponse != nil { return 
sonic.Marshal(bs.BifrostTranscriptionStreamResponse) + } else if bs.BifrostImageGenerationStreamResponse != nil { + return sonic.Marshal(bs.BifrostImageGenerationStreamResponse) } else if bs.BifrostError != nil { return sonic.Marshal(bs.BifrostError) } diff --git a/core/schemas/images.go b/core/schemas/images.go new file mode 100644 index 000000000..3e1218ba0 --- /dev/null +++ b/core/schemas/images.go @@ -0,0 +1,65 @@ +package schemas + +// BifrostImageGenerationRequest represents an image generation request in bifrost format +type BifrostImageGenerationRequest struct { + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input *ImageGenerationInput `json:"input"` + Params *ImageGenerationParameters `json:"params,omitempty"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` + RawRequestBody []byte `json:"-"` +} + +// GetRawRequestBody implements utils.RequestBodyGetter. +func (b *BifrostImageGenerationRequest) GetRawRequestBody() []byte { + return b.RawRequestBody +} + +type ImageGenerationInput struct { + Prompt string `json:"prompt"` +} + +type ImageGenerationParameters struct { + N *int `json:"n,omitempty"` // Number of images (1-10) + Size *string `json:"size,omitempty"` // "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792", "1536x1024", "1024x1536", "auto" + Quality *string `json:"quality,omitempty"` // "auto", "high", "medium", "low" + Style *string `json:"style,omitempty"` // "natural", "vivid" + ResponseFormat *string `json:"response_format,omitempty"` // "url", "b64_json" + User *string `json:"user,omitempty"` + ExtraParams map[string]interface{} `json:"extra_params,omitempty"` +} + +// BifrostImageGenerationResponse represents the image generation response in bifrost format +type BifrostImageGenerationResponse struct { + ID string `json:"id"` + Created int64 `json:"created"` + Model string `json:"model"` + Data []ImageData `json:"data"` + Usage *ImageUsage `json:"usage,omitempty"` + ExtraFields BifrostResponseExtraFields 
`json:"extra_fields,omitempty"` +} + +type ImageData struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` + Index int `json:"index"` +} + +type ImageUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Streaming Response +type BifrostImageGenerationStreamResponse struct { + ID string `json:"id"` + Type string `json:"type"` // "image_generation.partial_image", "image_generation.completed", "error" + Index int `json:"index"` // Which image (0-N) + ChunkIndex int `json:"chunk_index"` // Chunk order within image + PartialB64 string `json:"partial_b64,omitempty"` // Base64 chunk + RevisedPrompt string `json:"revised_prompt,omitempty"` // On first chunk + Usage *ImageUsage `json:"usage,omitempty"` // On final chunk + Error *BifrostError `json:"error,omitempty"` + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` +} diff --git a/core/schemas/provider.go b/core/schemas/provider.go index da0a8e51d..b0127e080 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -306,4 +306,10 @@ type Provider interface { Transcription(ctx context.Context, key Key, request *BifrostTranscriptionRequest) (*BifrostTranscriptionResponse, *BifrostError) // TranscriptionStream performs a transcription stream request TranscriptionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostTranscriptionRequest) (chan *BifrostStream, *BifrostError) + // ImageGeneration performs an image generation request + ImageGeneration(ctx context.Context, key Key, request *BifrostImageGenerationRequest) ( + *BifrostImageGenerationResponse, *BifrostError) + // ImageGenerationStream performs an image generation stream request + ImageGenerationStream(ctx context.Context, postHookRunner PostHookRunner, key Key, + request *BifrostImageGenerationRequest) (chan *BifrostStream, *BifrostError) } diff --git 
a/core/schemas/responses.go b/core/schemas/responses.go index 5b974fcaa..91d416dfd 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -536,9 +536,17 @@ type ResponsesToolMessageOutputStruct struct { ResponsesToolCallOutputStr *string // Common output string for tool calls and outputs (used by function, custom and local shell tool calls) ResponsesFunctionToolCallOutputBlocks []ResponsesMessageContentBlock ResponsesComputerToolCallOutput *ResponsesComputerToolCallOutputData + ResponsesImageGenerationCallOutput *ResponsesImageGenerationCallOutput +} + +type ResponsesImageGenerationCallOutput struct { + Result string `json:"result"` // JSON string with image data } func (output ResponsesToolMessageOutputStruct) MarshalJSON() ([]byte, error) { + if output.ResponsesImageGenerationCallOutput != nil { + return sonic.Marshal(output.ResponsesImageGenerationCallOutput) + } if output.ResponsesToolCallOutputStr != nil { return sonic.Marshal(*output.ResponsesToolCallOutputStr) } @@ -548,7 +556,7 @@ func (output ResponsesToolMessageOutputStruct) MarshalJSON() ([]byte, error) { if output.ResponsesComputerToolCallOutput != nil { return sonic.Marshal(output.ResponsesComputerToolCallOutput) } - return nil, fmt.Errorf("responses tool message output struct is neither a string nor an array of responses message content blocks nor a computer tool call output data") + return nil, fmt.Errorf("responses tool message output struct is neither a string nor an array of responses message content blocks nor a computer tool call output data nor an image generation call output") } func (output *ResponsesToolMessageOutputStruct) UnmarshalJSON(data []byte) error { var str string @@ -561,12 +569,25 @@ func (output *ResponsesToolMessageOutputStruct) UnmarshalJSON(data []byte) error output.ResponsesFunctionToolCallOutputBlocks = array return nil } + + // Peek at the object to distinguish image-generation vs computer tool outputs. 
+ var raw map[string]interface{} + if err := sonic.Unmarshal(data, &raw); err == nil { + if _, hasResult := raw["result"]; hasResult { + var imageGenerationCallOutput ResponsesImageGenerationCallOutput + if err := sonic.Unmarshal(data, &imageGenerationCallOutput); err == nil { + output.ResponsesImageGenerationCallOutput = &imageGenerationCallOutput + return nil + } + } + } + var computerToolCallOutput ResponsesComputerToolCallOutputData if err := sonic.Unmarshal(data, &computerToolCallOutput); err == nil { output.ResponsesComputerToolCallOutput = &computerToolCallOutput return nil } - return fmt.Errorf("responses tool message output struct is neither a string nor an array of responses message content blocks nor a computer tool call output data") + return fmt.Errorf("responses tool message output struct is neither a string nor an array of responses message content blocks nor a computer tool call output data nor an image generation call output") } // ============================================================================= diff --git a/core/utils.go b/core/utils.go index 340ca1b9b..405c11834 100644 --- a/core/utils.go +++ b/core/utils.go @@ -191,7 +191,7 @@ func IsStandardProvider(providerKey schemas.ModelProvider) bool { // IsStreamRequestType returns true if the given request type is a stream request. func IsStreamRequestType(reqType schemas.RequestType) bool { - return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest + return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.ImageGenerationStreamRequest } // IsFinalChunk returns true if the given context is a final chunk. 
diff --git a/docs/features/unified-interface.mdx b/docs/features/unified-interface.mdx index 77095dc9c..bb659dec7 100644 --- a/docs/features/unified-interface.mdx +++ b/docs/features/unified-interface.mdx @@ -85,24 +85,24 @@ response, err := client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ The following table summarizes which operations are supported by each provider via Bifrost’s unified interface. -| Provider | Models | Text | Text (stream) | Chat | Chat (stream) | Responses | Responses (stream) | Embeddings | TTS | TTS (stream) | STT | STT (stream) | -|----------|--------|------|----------------|------|---------------|-----------|--------------------|------------|-----|-------------|-----|--------------| -| Anthropic (`anthropic/`) | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| Azure (`azure/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | -| Bedrock (`bedrock/`) | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | -| Cerebras (`cerebras/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| Cohere (`cohere/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | -| Elevenlabs (`elevenlabs/`) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | -| Gemini (`gemini/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| Groq (`groq/`) | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| Mistral (`mistral/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | -| Ollama (`ollama/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | -| OpenAI (`openai/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| OpenRouter (`openrouter/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| Parasail (`parasail/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| Perplexity (`perplexity/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| SGL (`sgl/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | -| Vertex AI (`vertex/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | +| Provider | Models | Text | Text (stream) | Chat | Chat (stream) | Responses | 
Responses (stream) | Embeddings | TTS | TTS (stream) | STT | STT (stream) | Image Generation | Image Generation (stream) | +|----------|--------|------|----------------|------|---------------|-----------|--------------------|------------|-----|-------------|-----|--------------|------------------|---------------------------| +| Anthropic (`anthropic/`) | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Azure (`azure/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | +| Bedrock (`bedrock/`) | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Cerebras (`cerebras/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Cohere (`cohere/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Elevenlabs (`elevenlabs/`) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | +| Gemini (`gemini/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| Groq (`groq/`) | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Mistral (`mistral/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Ollama (`ollama/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| OpenAI (`openai/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| OpenRouter (`openrouter/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Parasail (`parasail/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Perplexity (`perplexity/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| SGL (`sgl/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Vertex AI (`vertex/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | - 🟡 Not supported by the downstream provider, but internally implemented by Bifrost as a fallback. - ❌ Not supported by the downstream provider, hence not supported by Bifrost. 
diff --git a/docs/quickstart/gateway/multimodal.mdx b/docs/quickstart/gateway/multimodal.mdx index a42bbe0c0..414803f9e 100644 --- a/docs/quickstart/gateway/multimodal.mdx +++ b/docs/quickstart/gateway/multimodal.mdx @@ -1,6 +1,6 @@ --- title: "Multimodal Support" -description: "Process multiple types of content including images, audio, and text with AI models. Bifrost supports vision analysis, speech synthesis, and audio transcription across various providers." +description: "Process multiple types of content including images, audio, and text with AI models. Bifrost supports vision analysis, image generation, speech synthesis, and audio transcription across various providers." icon: "images" --- @@ -46,6 +46,43 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ } ``` +## Image Generation: Generating Images with AI + +Generate images from text prompts using OpenAI-compatible image generation models. + +### Basic Image Generation + +Generate an image from a text prompt using `dall-e-3`. + +```bash +curl --location 'http://localhost:8080/v1/images/generations' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/dall-e-3", + "prompt": "A futuristic city skyline at sunset with flying cars", + "size": "1024x1024", + "response_format": "url" +}' +``` +**Response format:** +```json +{ + "id": "img-123", + "created": 1713833628, + "model": "dall-e-3", + "data": [ + { + "url": "https://oaidalleapiprodscus.blob.core.windows.net/...", + "revised_prompt": "A futuristic city skyline at sunset featuring advanced architecture and flying vehicles.", + "index": 0 + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 10 + } +} +``` ## Audio Understanding: Analyzing Audio with AI If your chat application supports text input, you can add audio input and output—just include audio in the modalities array and use an audio model, like gpt-4o-audio-preview. 
diff --git a/docs/quickstart/go-sdk/multimodal.mdx b/docs/quickstart/go-sdk/multimodal.mdx index 1a0a825cb..58c31cad1 100644 --- a/docs/quickstart/go-sdk/multimodal.mdx +++ b/docs/quickstart/go-sdk/multimodal.mdx @@ -1,6 +1,6 @@ --- title: "Multimodal Support" -description: "Process multiple types of content including images, audio, and text with AI models. Bifrost supports vision analysis, speech synthesis, and audio transcription across various providers." +description: "Process multiple types of content including images, audio, and text with AI models. Bifrost supports vision analysis, image generation, speech synthesis, and audio transcription across various providers." icon: "images" --- @@ -41,6 +41,53 @@ if err != nil { fmt.Println("Response:", *response.Choices[0].Message.Content.ContentStr) ``` +## Image Generation: Generating Images with AI + +Generate images from text prompts using OpenAI-compatible image generation models via the Go SDK. + +```go +response, err := client.ImageGenerationRequest(context.Background(), &schemas.BifrostImageGenerationRequest{ + Provider: schemas.OpenAI, + Model: "dall-e-3", + Input: &schemas.ImageGenerationInput{ + Prompt: "A futuristic city skyline at sunset with flying cars", + }, + Params: &schemas.ImageGenerationParameters{ + Size: schemas.Ptr("1024x1024"), + ResponseFormat: schemas.Ptr("url"), + }, +}) + +if err != nil { + panic(err) +} + +// Handle image generation response +if len(response.Data) > 0 { + imageData := response.Data[0] + + // Handle URL response (when response_format is "url") + if imageData.URL != "" { + fmt.Printf("Generated image URL: %s\n", imageData.URL) + } + + // Handle base64-encoded response (when response_format is "b64_json") + if imageData.B64JSON != "" { + fmt.Printf("Generated base64 image (length: %d)\n", len(imageData.B64JSON)) + } + + // Handle revised prompt if present + if imageData.RevisedPrompt != "" { + fmt.Printf("Revised prompt: %s\n", imageData.RevisedPrompt) + } +} + +// Handle 
usage metrics +if response.Usage != nil { + fmt.Printf("Usage: %d tokens\n", response.Usage.TotalTokens) +} +``` + ## Audio Understanding: Analyzing Audio with AI If your chat application supports text input, you can add audio input and output—just include audio in the modalities array and use an audio model, like gpt-4o-audio-preview. diff --git a/framework/changelog.md b/framework/changelog.md index e69de29bb..ec82318fd 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -0,0 +1 @@ +feat: add image generation streaming accumulation support diff --git a/framework/streaming/accumulator.go b/framework/streaming/accumulator.go index b19bfde0c..480e4f980 100644 --- a/framework/streaming/accumulator.go +++ b/framework/streaming/accumulator.go @@ -20,6 +20,7 @@ type Accumulator struct { responsesStreamChunkPool sync.Pool // Pool for reusing ResponsesStreamChunk structs audioStreamChunkPool sync.Pool // Pool for reusing AudioStreamChunk structs transcriptionStreamChunkPool sync.Pool // Pool for reusing TranscriptionStreamChunk structs + imageStreamChunkPool sync.Pool // Pool for reusing ImageStreamChunk structs pricingManager *modelcatalog.ModelCatalog @@ -101,12 +102,32 @@ func (a *Accumulator) putResponsesStreamChunk(chunk *ResponsesStreamChunk) { a.responsesStreamChunkPool.Put(chunk) } +// getImageStreamChunk gets an image stream chunk from the pool +func (a *Accumulator) getImageStreamChunk() *ImageStreamChunk { + return a.imageStreamChunkPool.Get().(*ImageStreamChunk) +} + +// putImageStreamChunk returns an image stream chunk to the pool +func (a *Accumulator) putImageStreamChunk(chunk *ImageStreamChunk) { + chunk.Timestamp = time.Time{} + chunk.Delta = nil + chunk.FinishReason = nil + chunk.ErrorDetails = nil + chunk.ChunkIndex = 0 + chunk.ImageIndex = 0 + chunk.Cost = nil + chunk.SemanticCacheDebug = nil + chunk.TokenUsage = nil + a.imageStreamChunkPool.Put(chunk) +} + // CreateStreamAccumulator creates a new stream accumulator for a request func 
(a *Accumulator) createStreamAccumulator(requestID string) *StreamAccumulator { sc := &StreamAccumulator{ RequestID: requestID, ChatStreamChunks: make([]*ChatStreamChunk, 0), ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0), + ImageStreamChunks: make([]*ImageStreamChunk, 0), IsComplete: false, Timestamp: time.Now(), } @@ -163,7 +184,7 @@ func (a *Accumulator) addTranscriptionStreamChunk(requestID string, chunk *Trans return nil } -// AddAudioStreamChunk adds an audio stream chunk to the stream accumulator +// addAudioStreamChunk adds an audio stream chunk to the stream accumulator func (a *Accumulator) addAudioStreamChunk(requestID string, chunk *AudioStreamChunk, isFinalChunk bool) error { accumulator := a.getOrCreateStreamAccumulator(requestID) // Lock the accumulator @@ -203,6 +224,24 @@ func (a *Accumulator) addResponsesStreamChunk(requestID string, chunk *Responses return nil } +// addImageStreamChunk adds an image stream chunk to the stream accumulator +func (a *Accumulator) addImageStreamChunk(requestID string, chunk *ImageStreamChunk, isFinalChunk bool) error { + acc := a.getOrCreateStreamAccumulator(requestID) + acc.mu.Lock() + defer acc.mu.Unlock() + + if acc.StartTimestamp.IsZero() { + acc.StartTimestamp = chunk.Timestamp + } + // Add chunk to the list (chunks arrive in order) + acc.ImageStreamChunks = append(acc.ImageStreamChunks, chunk) + + if isFinalChunk { + acc.FinalTimestamp = chunk.Timestamp + } + return nil +} + // cleanupStreamAccumulator removes the stream accumulator for a request. // IMPORTANT: Caller must hold accumulator.mu lock before calling this function // to prevent races when returning chunks to pools. 
@@ -223,6 +262,9 @@ func (a *Accumulator) cleanupStreamAccumulator(requestID string) { for _, chunk := range acc.TranscriptionStreamChunks { a.putTranscriptionStreamChunk(chunk) } + for _, chunk := range acc.ImageStreamChunks { + a.putImageStreamChunk(chunk) + } a.streamAccumulators.Delete(requestID) } } @@ -313,6 +355,8 @@ func (a *Accumulator) ProcessStreamingResponse(ctx *schemas.BifrostContext, resu isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest isChatStreaming := requestType == schemas.ChatCompletionStreamRequest || requestType == schemas.TextCompletionStreamRequest isResponsesStreaming := requestType == schemas.ResponsesStreamRequest + // Edit images/ Image variation requests will be added here + isImageStreaming := requestType == schemas.ImageGenerationStreamRequest if isChatStreaming { // Handle text-based streaming with ordered accumulation @@ -328,6 +372,9 @@ func (a *Accumulator) ProcessStreamingResponse(ctx *schemas.BifrostContext, resu } else if isResponsesStreaming { // Handle responses streaming with responses accumulation return a.processResponsesStreamingResponse(ctx, result, bifrostErr) + } else if isImageStreaming { + // Handle image streaming + return a.processImageStreamingResponse(ctx, result, bifrostErr) } return nil, fmt.Errorf("request type missing/invalid for accumulator: %s", requestType) } @@ -352,6 +399,9 @@ func (a *Accumulator) Cleanup() { for _, chunk := range accumulator.AudioStreamChunks { a.audioStreamChunkPool.Put(chunk) } + for _, chunk := range accumulator.ImageStreamChunks { + a.imageStreamChunkPool.Put(chunk) + } accumulator.mu.Unlock() a.streamAccumulators.Delete(key) @@ -441,6 +491,11 @@ func NewAccumulator(pricingManager *modelcatalog.ModelCatalog, logger schemas.Lo return &TranscriptionStreamChunk{} }, }, + imageStreamChunkPool: sync.Pool{ + New: func() any { + return &ImageStreamChunk{} + }, + }, pricingManager: pricingManager, logger: logger, ttl: 30 * 
time.Minute, @@ -455,6 +510,7 @@ func NewAccumulator(pricingManager *modelcatalog.ModelCatalog, logger schemas.Lo a.responsesStreamChunkPool.Put(&ResponsesStreamChunk{}) a.audioStreamChunkPool.Put(&AudioStreamChunk{}) a.transcriptionStreamChunkPool.Put(&TranscriptionStreamChunk{}) + a.imageStreamChunkPool.Put(&ImageStreamChunk{}) } go a.startAccumulatorMapCleanup() return a diff --git a/framework/streaming/images.go b/framework/streaming/images.go new file mode 100644 index 000000000..d95627388 --- /dev/null +++ b/framework/streaming/images.go @@ -0,0 +1,278 @@ +package streaming + +import ( + "fmt" + "sort" + "strings" + "time" + + bifrost "github.com/maximhq/bifrost/core" + schemas "github.com/maximhq/bifrost/core/schemas" +) + +// buildCompleteImageFromImageStreamChunks builds a complete image generation response from accumulated chunks +func (a *Accumulator) buildCompleteImageFromImageStreamChunks(chunks []*ImageStreamChunk) *schemas.BifrostImageGenerationResponse { + // Sort chunks by ImageIndex, then ChunkIndex + sort.Slice(chunks, func(i, j int) bool { + if chunks[i].ImageIndex != chunks[j].ImageIndex { + return chunks[i].ImageIndex < chunks[j].ImageIndex + } + return chunks[i].ChunkIndex < chunks[j].ChunkIndex + }) + + // Reconstruct complete images from chunks + images := make(map[int]*strings.Builder) + var model string + var revisedPrompts map[int]string = make(map[int]string) + + for _, chunk := range chunks { + if chunk.Delta == nil { + continue + } + + // Extract metadata + if model == "" && chunk.Delta.ExtraFields.ModelRequested != "" { + model = chunk.Delta.ExtraFields.ModelRequested + } + + // Store revised prompt if present (usually in first chunk) + if chunk.Delta.RevisedPrompt != "" { + revisedPrompts[chunk.ImageIndex] = chunk.Delta.RevisedPrompt + } + + // Reconstruct base64 for each image + if chunk.Delta.PartialB64 != "" { + if _, ok := images[chunk.ImageIndex]; !ok { + images[chunk.ImageIndex] = &strings.Builder{} + } + 
images[chunk.ImageIndex].WriteString(chunk.Delta.PartialB64) + } + } + + if len(images) == 0 { + return nil + } + // Build ImageData array in deterministic manner (if indexes are not in order) + imageIndexes := make([]int, 0, len(images)) + for idx := range images { + imageIndexes = append(imageIndexes, idx) + } + sort.Ints(imageIndexes) + + imageData := make([]schemas.ImageData, 0, len(images)) + for _, imageIndex := range imageIndexes { + builder := images[imageIndex] + if builder == nil { + continue + } + imageData = append(imageData, schemas.ImageData{ + B64JSON: builder.String(), + Index: imageIndex, + RevisedPrompt: revisedPrompts[imageIndex], + }) + } + + // Build final response + var responseID string + for _, chunk := range chunks { + if chunk.Delta != nil && chunk.Delta.ID != "" { + responseID = chunk.Delta.ID + break + } + } + + finalResponse := &schemas.BifrostImageGenerationResponse{ + ID: responseID, + Created: time.Now().Unix(), + Model: model, + Data: imageData, + } + + return finalResponse +} + +// processAccumulatedImageStreamingChunks processes all accumulated image chunks in order +func (a *Accumulator) processAccumulatedImageStreamingChunks(requestID string, bifrostErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) { + acc := a.getOrCreateStreamAccumulator(requestID) + // Lock the accumulator + acc.mu.Lock() + defer func() { + if isFinalChunk { + // Cleanup BEFORE unlocking to prevent other goroutines from accessing chunks being returned to pool + a.cleanupStreamAccumulator(requestID) + } + acc.mu.Unlock() + }() + + // Initialize accumulated data + data := &AccumulatedData{ + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: acc.StartTimestamp, + EndTimestamp: acc.FinalTimestamp, + Latency: 0, + OutputMessage: nil, + ToolCalls: nil, + ErrorDetails: nil, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, + } + + // Build complete message from accumulated chunks + completeImage := 
a.buildCompleteImageFromImageStreamChunks(acc.ImageStreamChunks) + if !isFinalChunk { + data.ImageGenerationOutput = completeImage + return data, nil + } + + // Update database with complete message + data.Status = "success" + if bifrostErr != nil { + data.Status = "error" + } + if acc.StartTimestamp.IsZero() || acc.FinalTimestamp.IsZero() { + data.Latency = 0 + } else { + data.Latency = acc.FinalTimestamp.Sub(acc.StartTimestamp).Nanoseconds() / 1e6 + } + data.EndTimestamp = acc.FinalTimestamp + data.ImageGenerationOutput = completeImage + data.ErrorDetails = bifrostErr + + // Update token usage from final chunk if available + if len(acc.ImageStreamChunks) > 0 { + lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1] + if lastChunk.Delta != nil && lastChunk.Delta.Usage != nil { + data.TokenUsage = &schemas.BifrostLLMUsage{ + PromptTokens: lastChunk.Delta.Usage.PromptTokens, + CompletionTokens: 0, // Image generation doesn't have completion tokens + TotalTokens: lastChunk.Delta.Usage.TotalTokens, + } + } + } + + // Update cost from final chunk if available + if len(acc.ImageStreamChunks) > 0 { + lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1] + if lastChunk.Cost != nil { + data.Cost = lastChunk.Cost + } + } + + // Update semantic cache debug from final chunk if available + if len(acc.ImageStreamChunks) > 0 { + lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1] + if lastChunk.SemanticCacheDebug != nil { + data.CacheDebug = lastChunk.SemanticCacheDebug + } + data.FinishReason = lastChunk.FinishReason + } + + return data, nil +} + +// processImageStreamingResponse processes an image streaming response +func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) { + a.logger.Debug("[streaming] processing image streaming response") + // Extract request ID from context + requestID, ok := 
(*ctx).Value(schemas.BifrostContextKeyRequestID).(string) + if !ok || requestID == "" { + // Log error but don't fail the request + return nil, fmt.Errorf("request-id not found in context or is empty") + } + _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + + isFinalChunk := bifrost.IsFinalChunk(ctx) + chunk := a.getImageStreamChunk() + chunk.Timestamp = time.Now() + chunk.ErrorDetails = bifrostErr + if bifrostErr != nil { + chunk.FinishReason = bifrost.Ptr("error") + } else if result != nil && result.ImageGenerationStreamResponse != nil { + // Create a deep copy of the delta to avoid pointing to stack memory + newDelta := &schemas.BifrostImageGenerationStreamResponse{ + ID: result.ImageGenerationStreamResponse.ID, + Type: result.ImageGenerationStreamResponse.Type, + Index: result.ImageGenerationStreamResponse.Index, + ChunkIndex: result.ImageGenerationStreamResponse.ChunkIndex, + PartialB64: result.ImageGenerationStreamResponse.PartialB64, + RevisedPrompt: result.ImageGenerationStreamResponse.RevisedPrompt, + Usage: result.ImageGenerationStreamResponse.Usage, + Error: result.ImageGenerationStreamResponse.Error, + ExtraFields: result.ImageGenerationStreamResponse.ExtraFields, + } + chunk.Delta = newDelta + chunk.ChunkIndex = result.ImageGenerationStreamResponse.ChunkIndex + chunk.ImageIndex = result.ImageGenerationStreamResponse.Index + + // Extract usage if available + if result.ImageGenerationStreamResponse.Usage != nil { + // Note: ImageUsage doesn't directly map to BifrostLLMUsage, but we can store it + // The actual usage will be extracted in processAccumulatedImageStreamingChunks + } + + if isFinalChunk { + if a.pricingManager != nil { + cost := a.pricingManager.CalculateCostWithCacheDebug(result) + chunk.Cost = bifrost.Ptr(cost) + } + chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug + chunk.FinishReason = bifrost.Ptr("completed") + } + } + + if addErr := a.addImageStreamChunk(requestID, chunk, isFinalChunk); addErr != nil { + 
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr) + } + + // If this is the final chunk, process accumulated chunks asynchronously + // Use the IsComplete flag to prevent duplicate processing + if isFinalChunk { + shouldProcess := false + // Get the accumulator to check if processing has already been triggered + accumulator := a.getOrCreateStreamAccumulator(requestID) + accumulator.mu.Lock() + shouldProcess = !accumulator.IsComplete + // Mark as complete when we're about to process + if shouldProcess { + accumulator.IsComplete = true + } + accumulator.mu.Unlock() + if shouldProcess { + data, processErr := a.processAccumulatedImageStreamingChunks(requestID, bifrostErr, isFinalChunk) + if processErr != nil { + a.logger.Error(fmt.Sprintf("failed to process accumulated chunks for request %s: %v", requestID, processErr)) + return nil, processErr + } + return &ProcessedStreamResponse{ + Type: StreamResponseTypeFinal, + RequestID: requestID, + StreamType: StreamTypeImage, + Provider: provider, + Model: model, + Data: data, + }, nil + } + + return nil, nil + } + + // This is going to be a delta response + data, processErr := a.processAccumulatedImageStreamingChunks(requestID, bifrostErr, isFinalChunk) + if processErr != nil { + a.logger.Error(fmt.Sprintf("failed to process accumulated chunks for request %s: %v", requestID, processErr)) + return nil, processErr + } + + // This is not the final chunk, so we will send back the delta + return &ProcessedStreamResponse{ + Type: StreamResponseTypeDelta, + RequestID: requestID, + StreamType: StreamTypeImage, + Provider: provider, + Model: model, + Data: data, + }, nil +} diff --git a/framework/streaming/types.go b/framework/streaming/types.go index 29bb62dfc..d1da471e3 100644 --- a/framework/streaming/types.go +++ b/framework/streaming/types.go @@ -13,6 +13,7 @@ const ( StreamTypeText StreamType = "text.completion" StreamTypeChat StreamType = "chat.completion" StreamTypeAudio StreamType = 
"audio.speech" + StreamTypeImage StreamType = "image.generation" StreamTypeTranscription StreamType = "audio.transcription" StreamTypeResponses StreamType = "responses" ) @@ -42,6 +43,7 @@ type AccumulatedData struct { Cost *float64 AudioOutput *schemas.BifrostSpeechResponse TranscriptionOutput *schemas.BifrostTranscriptionResponse + ImageGenerationOutput *schemas.BifrostImageGenerationResponse FinishReason *string RawResponse *string } @@ -98,6 +100,19 @@ type ResponsesStreamChunk struct { RawResponse *string } +// ImageStreamChunk represents a single image streaming chunk +type ImageStreamChunk struct { + Timestamp time.Time // When chunk was received + Delta *schemas.BifrostImageGenerationStreamResponse // The actual stream response + FinishReason *string // If this is the final chunk + ChunkIndex int // Index of the chunk in the stream + ImageIndex int // Index of the image in the stream + ErrorDetails *schemas.BifrostError // Error if any + Cost *float64 // Cost in dollars from pricing plugin + SemanticCacheDebug *schemas.BifrostCacheDebug // Semantic cache debug if available + TokenUsage *schemas.BifrostLLMUsage // Token usage if available +} + // StreamAccumulator manages accumulation of streaming chunks type StreamAccumulator struct { RequestID string @@ -106,6 +121,7 @@ type StreamAccumulator struct { ResponsesStreamChunks []*ResponsesStreamChunk TranscriptionStreamChunks []*TranscriptionStreamChunk AudioStreamChunks []*AudioStreamChunk + ImageStreamChunks []*ImageStreamChunk IsComplete bool FinalTimestamp time.Time mu sync.Mutex @@ -257,9 +273,21 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { ModelRequested: p.Model, Latency: p.Data.Latency, } - if p.RawRequest != nil { + if p.RawRequest != nil { resp.TranscriptionResponse.ExtraFields.RawRequest = p.RawRequest + } + case StreamTypeImage: + imageResp := p.Data.ImageGenerationOutput + if imageResp == nil { + imageResp = &schemas.BifrostImageGenerationResponse{} + } + 
resp.ImageGenerationResponse = imageResp + resp.ImageGenerationResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + RequestType: schemas.ImageGenerationStreamRequest, + Provider: p.Provider, + ModelRequested: p.Model, + Latency: p.Data.Latency, } - } + return resp } diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md index e69de29bb..0fa959dc5 100644 --- a/plugins/semanticcache/changelog.md +++ b/plugins/semanticcache/changelog.md @@ -0,0 +1 @@ +- feat: added semantic caching support for image generation diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 0da4592ff..a7c37562f 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" @@ -195,6 +196,19 @@ var VectorStoreProperties = map[string]vectorstore.VectorStoreProperties{ DataType: vectorstore.VectorStorePropertyTypeBoolean, Description: "Whether the cache entry was created by the BifrostSemanticCachePlugin", }, + // image specific fields + "image_urls": { + DataType: vectorstore.VectorStorePropertyTypeStringArray, + Description: "Cached image URLs from image generation responses", + }, + "image_b64": { + DataType: vectorstore.VectorStorePropertyTypeStringArray, + Description: "Cached base64 image data from image generation responses", + }, + "revised_prompts": { + DataType: vectorstore.VectorStorePropertyTypeStringArray, + Description: "Revised prompts from image generation responses", + }, } type PluginAccount struct { @@ -377,7 +391,7 @@ func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostR ctx.SetValue(requestIDKey, requestID) ctx.SetValue(requestModelKey, model) ctx.SetValue(requestProviderKey, provider) - + performDirectSearch, performSemanticSearch := true, true if (*ctx).Value(CacheTypeKey) != nil { cacheTypeVal, ok := 
(*ctx).Value(CacheTypeKey).(CacheType)
@@ -730,3 +744,55 @@ func (plugin *Plugin) ClearCacheForRequestID(requestID string) error {
 	return nil
 }
+
+// getImageCacheKey generates an image-specific cache key using xxhash.
+// Hash components: prompt, size, quality, style, n, response_format, and user.
+// Every component is written with a leading NUL separator so that values
+// cannot collide across field boundaries (e.g. prompt "ab" + size "c" vs
+// prompt "a" + size "bc"), and all components are written unconditionally so
+// that a value cannot migrate between fields; nil-valued fields hash the same
+// as empty strings, which is acceptable for cache keying.
+// Returns a key in the format "img_<xxhash64 hex>", or "" when req or
+// req.Input is nil.
+func (plugin *Plugin) getImageCacheKey(req *schemas.BifrostImageGenerationRequest) string {
+	if req == nil || req.Input == nil {
+		return ""
+	}
+
+	h := xxhash.New()
+	// writeField delimits each component; xxhash writes never return errors.
+	writeField := func(s string) {
+		h.Write([]byte{0})
+		h.WriteString(s)
+	}
+
+	writeField(req.Input.Prompt)
+
+	var size, quality, style, n, responseFormat, user string
+	if req.Params != nil {
+		if req.Params.Size != nil {
+			size = *req.Params.Size
+		}
+		if req.Params.Quality != nil {
+			quality = *req.Params.Quality
+		}
+		if req.Params.Style != nil {
+			style = *req.Params.Style
+		}
+		if req.Params.N != nil {
+			n = fmt.Sprintf("%d", *req.Params.N)
+		}
+		if req.Params.ResponseFormat != nil {
+			responseFormat = *req.Params.ResponseFormat
+		}
+		if req.Params.User != nil {
+			user = *req.Params.User
+		}
+	}
+
+	// Fixed field order; written even when empty to keep positions stable.
+	writeField(size)
+	writeField(quality)
+	writeField(style)
+	writeField(n)
+	writeField(responseFormat)
+	writeField(user)
+
+	return fmt.Sprintf("img_%x", h.Sum64())
+}
diff --git a/plugins/semanticcache/search.go b/plugins/semanticcache/search.go
index 1872e3624..407fbf48b 100644
--- a/plugins/semanticcache/search.go
+++ b/plugins/semanticcache/search.go
@@ -297,7 +297,7 @@ func (plugin *Plugin) buildStreamingResponseFromResult(ctx *schemas.BifrostConte
 	// Mark cache-hit once to avoid concurrent ctx writes
 	ctx.SetValue(isCacheHitKey, true)
 	ctx.SetValue(cacheHitTypeKey, cacheType)
-
+
 	// Create stream channel
 	streamChan := make(chan *schemas.BifrostStream)
@@ -354,11 +354,12 @@ func (plugin *Plugin) buildStreamingResponseFromResult(ctx *schemas.BifrostConte
 	// Send chunk to stream
 	streamChan <- &schemas.BifrostStream{
-		BifrostTextCompletionResponse: cachedResponse.TextCompletionResponse,
-		BifrostChatResponse: cachedResponse.ChatResponse,
-		BifrostResponsesStreamResponse: cachedResponse.ResponsesStreamResponse,
-		BifrostSpeechStreamResponse: cachedResponse.SpeechStreamResponse,
-		
BifrostTranscriptionStreamResponse: cachedResponse.TranscriptionStreamResponse, + BifrostTextCompletionResponse: cachedResponse.TextCompletionResponse, + BifrostChatResponse: cachedResponse.ChatResponse, + BifrostResponsesStreamResponse: cachedResponse.ResponsesStreamResponse, + BifrostSpeechStreamResponse: cachedResponse.SpeechStreamResponse, + BifrostTranscriptionStreamResponse: cachedResponse.TranscriptionStreamResponse, + BifrostImageGenerationStreamResponse: cachedResponse.ImageGenerationStreamResponse, } } }() diff --git a/plugins/semanticcache/stream.go b/plugins/semanticcache/stream.go index bd9f19e05..afe892b35 100644 --- a/plugins/semanticcache/stream.go +++ b/plugins/semanticcache/stream.go @@ -116,6 +116,13 @@ func (plugin *Plugin) processAccumulatedStream(ctx context.Context, requestID st if accumulator.Chunks[i].Response.TranscriptionStreamResponse != nil { return accumulator.Chunks[i].Response.TranscriptionStreamResponse.ExtraFields.ChunkIndex < accumulator.Chunks[j].Response.TranscriptionStreamResponse.ExtraFields.ChunkIndex } + if accumulator.Chunks[i].Response.ImageGenerationStreamResponse != nil { + // For image generation, sort by Index first, then ChunkIndex + if accumulator.Chunks[i].Response.ImageGenerationStreamResponse.Index != accumulator.Chunks[j].Response.ImageGenerationStreamResponse.Index { + return accumulator.Chunks[i].Response.ImageGenerationStreamResponse.Index < accumulator.Chunks[j].Response.ImageGenerationStreamResponse.Index + } + return accumulator.Chunks[i].Response.ImageGenerationStreamResponse.ChunkIndex < accumulator.Chunks[j].Response.ImageGenerationStreamResponse.ChunkIndex + } return false }) diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go index 08b37e2d1..49670e079 100644 --- a/plugins/semanticcache/utils.go +++ b/plugins/semanticcache/utils.go @@ -85,6 +85,13 @@ func (plugin *Plugin) generateEmbedding(ctx context.Context, text string) ([]flo // - string: Hexadecimal representation of the 
xxhash // - error: Any error that occurred during request normalization or hashing func (plugin *Plugin) generateRequestHash(req *schemas.BifrostRequest) (string, error) { + // Special handling for image generation (hash = prompt + size + quality + style + n) + if req.RequestType == schemas.ImageGenerationRequest || req.RequestType == schemas.ImageGenerationStreamRequest { + if req.ImageGenerationRequest != nil { + return plugin.getImageCacheKey(req.ImageGenerationRequest), nil + } + return "", fmt.Errorf("image generation request is nil") + } // Create a hash input structure that includes both input and parameters hashInput := struct { Input interface{} `json:"input"` @@ -165,6 +172,10 @@ func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest) (stri if req.TranscriptionRequest != nil && req.TranscriptionRequest.Params != nil { plugin.extractTranscriptionParametersToMetadata(req.TranscriptionRequest.Params, metadata) } + case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: + if req.ImageGenerationRequest != nil && req.ImageGenerationRequest.Params != nil { + plugin.extractImageGenerationParametersToMetadata(req.ImageGenerationRequest.Params, metadata) + } } switch { @@ -322,6 +333,16 @@ func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest) (stri // Skip semantic caching for transcription requests return "", "", fmt.Errorf("transcription requests are not supported for semantic caching") + case req.ImageGenerationRequest != nil: + if req.ImageGenerationRequest.Input == nil || req.ImageGenerationRequest.Input.Prompt == "" { + return "", "", fmt.Errorf("no prompt found in image generation request") + } + metadataHash, err := getMetadataHash(metadata) + if err != nil { + return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) + } + return normalizeText(req.ImageGenerationRequest.Input.Prompt), metadataHash, nil + default: return "", "", fmt.Errorf("unsupported input type for 
semantic caching") } @@ -371,6 +392,29 @@ func (plugin *Plugin) addSingleResponse(ctx context.Context, responseID string, metadata["response"] = string(responseData) metadata["stream_chunks"] = []string{} + // image specific metadata + if res.ImageGenerationResponse != nil { + var imageURLs []string + var imageB64 []string + var revisedPrompts []string + + for _, img := range res.ImageGenerationResponse.Data { + if img.URL != "" { + imageURLs = append(imageURLs, img.URL) + } + if img.B64JSON != "" { + imageB64 = append(imageB64, img.B64JSON) + } + if img.RevisedPrompt != "" { + revisedPrompts = append(revisedPrompts, img.RevisedPrompt) + } + } + + metadata["image_urls"] = imageURLs + metadata["image_b64"] = imageB64 + metadata["revised_prompts"] = revisedPrompts + } + // Store unified entry using new VectorStore interface if err := plugin.store.Add(ctx, plugin.config.VectorStoreNamespace, responseID, embedding, metadata); err != nil { return fmt.Errorf("failed to store unified cache entry: %w", err) @@ -469,6 +513,8 @@ func (plugin *Plugin) getInputForCaching(req *schemas.BifrostRequest) interface{ return req.EmbeddingRequest.Input case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: return req.TranscriptionRequest.Input + case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: + return req.ImageGenerationRequest.Input default: return nil } @@ -605,6 +651,13 @@ func (plugin *Plugin) getNormalizedInputForCaching(req *schemas.BifrostRequest) return copiedInput case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: return req.TranscriptionRequest.Input + case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: + if req.ImageGenerationRequest != nil && req.ImageGenerationRequest.Input != nil { + return &schemas.ImageGenerationInput{ + Prompt: normalizeText(req.ImageGenerationRequest.Input.Prompt), + } + } + return nil default: return nil } @@ -904,6 +957,34 @@ func (plugin *Plugin) 
extractTranscriptionParametersToMetadata(params *schemas.T } } +// extractImageGenerationParametersToMetadata extracts Image Generation parameters into metadata map +func (plugin *Plugin) extractImageGenerationParametersToMetadata(params *schemas.ImageGenerationParameters, metadata map[string]interface{}) { + if params == nil { + return + } + if params.N != nil { + metadata["n"] = *params.N + } + if params.Size != nil { + metadata["size"] = *params.Size + } + if params.Quality != nil { + metadata["quality"] = *params.Quality + } + if params.Style != nil { + metadata["style"] = *params.Style + } + if params.ResponseFormat != nil { + metadata["response_format"] = *params.ResponseFormat + } + if params.User != nil { + metadata["user"] = *params.User + } + if len(params.ExtraParams) > 0 { + maps.Copy(metadata, params.ExtraParams) + } +} + func (plugin *Plugin) isConversationHistoryThresholdExceeded(req *schemas.BifrostRequest) bool { switch { case req.ChatRequest != nil: diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 546d1285c..6b758ee20 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -5,6 +5,7 @@ package handlers import ( "bufio" "context" + "encoding/json" "fmt" "io" @@ -17,6 +18,7 @@ import ( "github.com/bytedance/sonic" "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" @@ -136,6 +138,19 @@ var speechParamsKnownFields = map[string]bool{ "speed": true, } +var imageParamsKnownFields = map[string]bool{ + "model": true, + "prompt": true, + "fallbacks": true, + "stream": true, + "n": true, + "size": true, + "quality": true, + "style": true, + "response_format": true, + "user": true, +} + var transcriptionParamsKnownFields = map[string]bool{ "model": true, "file": true, @@ -204,6 +219,12 @@ 
type ResponsesRequestInput struct { ResponsesRequestInputArray []schemas.ResponsesMessage } +type ImageGenerationHTTPRequest struct { + *schemas.ImageGenerationInput + *schemas.ImageGenerationParameters + BifrostParams +} + // UnmarshalJSON unmarshals the responses request input func (r *ResponsesRequestInput) UnmarshalJSON(data []byte) error { var str string @@ -346,6 +367,7 @@ func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...lib. r.POST("/v1/embeddings", lib.ChainMiddlewares(h.embeddings, middlewares...)) r.POST("/v1/audio/speech", lib.ChainMiddlewares(h.speech, middlewares...)) r.POST("/v1/audio/transcriptions", lib.ChainMiddlewares(h.transcription, middlewares...)) + r.POST("/v1/images/generations", lib.ChainMiddlewares(h.imageGeneration, middlewares...)) } // listModels handles GET /v1/models - Process list models requests @@ -1050,7 +1072,7 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge // Convert response to JSON chunkJSON, err := sonic.Marshal(chunk) if err != nil { - logger.Warn(fmt.Sprintf("Failed to marshal streaming response: %v", err)) + logger.Warn(fmt.Sprintf("Failed to marshal streaming response: %v, chunk: %v", err, chunk)) continue } @@ -1179,3 +1201,95 @@ func (h *CompletionHandler) validateAudioFile(fileHeader *multipart.FileHeader) return nil } + +// imageGeneration handles POST /v1/images/generations - Processes image generation requests +func (h *CompletionHandler) imageGeneration(ctx *fasthttp.RequestCtx) { + + var req ImageGenerationHTTPRequest + + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) + return + } + + // Parse model format provider/model + provider, modelName := schemas.ParseModelString(req.Model, "") + if provider == "" || modelName == "" { + SendError(ctx, fasthttp.StatusBadRequest, "model should be in provider/model format") + return + } + + if 
req.ImageGenerationInput == nil || req.Prompt == "" { + SendError(ctx, fasthttp.StatusBadRequest, "prompt can not be empty") + return + } + // Extract extra params + if req.ImageGenerationParameters == nil { + req.ImageGenerationParameters = &schemas.ImageGenerationParameters{} + } + + extraParams, err := extractExtraParams(ctx.PostBody(), imageParamsKnownFields) + if err != nil { + logger.Warn(fmt.Sprintf("Failed to extract extra params: %v", err)) + // Continue without extra params + } else { + req.ImageGenerationParameters.ExtraParams = extraParams + } + // Parse fallbacks + fallbacks, err := parseFallbacks(req.Fallbacks) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + // Create Bifrost request + bifrostReq := &schemas.BifrostImageGenerationRequest{ + Provider: schemas.ModelProvider(provider), + Model: modelName, + Input: &schemas.ImageGenerationInput{Prompt: req.Prompt}, + Params: req.ImageGenerationParameters, + Fallbacks: fallbacks, + } + + // Convert context + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + cancel() + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") + return + } + + // Handle streaming image generation + if req.BifrostParams.Stream != nil && *req.BifrostParams.Stream { + if req.ResponseFormat != nil && *req.ResponseFormat == "url" { + cancel() + SendError(ctx, fasthttp.StatusBadRequest, "streaming images must be requested in base64") + return + } + h.handleStreamingImageGeneration(ctx, bifrostReq, bifrostCtx, cancel) + return + } + defer cancel() + + // Execute request + resp, bifrostErr := h.client.ImageGenerationRequest(*bifrostCtx, bifrostReq) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + SendJSON(ctx, resp) +} + +// handleStreamingImageGeneration handles streaming image generation requests using Server-Sent Events (SSE) +func (h *CompletionHandler) 
handleStreamingImageGeneration(ctx *fasthttp.RequestCtx, req *schemas.BifrostImageGenerationRequest, bifrostCtx *context.Context, cancel context.CancelFunc) { + // Use the cancellable context from ConvertToBifrostContext + // See router.go for detailed explanation of why we need a cancellable context + streamCtx := *bifrostCtx + + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return h.client.ImageGenerationStreamRequest(streamCtx, req) + } + + h.handleStreamingResponse(ctx, getStream, cancel) +} diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 5dd81d616..cc08b7e27 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -57,6 +57,8 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ r.Model = setAzureModelName(r.Model, deploymentIDStr) case *openai.OpenAIEmbeddingRequest: r.Model = setAzureModelName(r.Model, deploymentIDStr) + case *openai.OpenAIImageGenerationRequest: + r.Model = setAzureModelName(r.Model, deploymentIDStr) case *schemas.BifrostListModelsRequest: r.Provider = schemas.Azure } @@ -373,6 +375,55 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) }) } + // Image Generation endpoint + for _, path := range []string{ + "/v1/images/generations", + "/images/generations", + "/openai/deployments/{deployment-id}/images/generations", + } { + routes = append(routes, RouteConfig{ + Type: RouteConfigTypeOpenAI, + Path: pathPrefix + path, + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &openai.OpenAIImageGenerationRequest{} + }, + RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + if imageGenReq, ok := req.(*openai.OpenAIImageGenerationRequest); ok { + return &schemas.BifrostRequest{ + ImageGenerationRequest: imageGenReq.ToBifrostImageGenerationRequest(), + }, nil + } + return 
nil, errors.New("invalid image generation request type") + }, + ImageGenerationResponseConverter: func(ctx *context.Context, resp *schemas.BifrostImageGenerationResponse) (interface{}, error) { + if resp.ExtraFields.Provider == schemas.OpenAI { + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + } + return resp, nil + }, + ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + return err + }, + StreamConfig: &StreamConfig{ + ImageGenerationStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostImageGenerationStreamResponse) (string, interface{}, error) { + if resp.ExtraFields.Provider == schemas.OpenAI { + if resp.ExtraFields.RawResponse != nil { + return resp.Type, resp.ExtraFields.RawResponse, nil + } + } + return resp.Type, resp, nil + }, + ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + return err + }, + }, + PreCallback: AzureEndpointPreHook(handlerStore), + }) + } + return routes } diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 974ff94c3..4cfaaf35f 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -130,6 +130,14 @@ type SpeechStreamResponseConverter func(ctx *context.Context, resp *schemas.Bifr // It takes a BifrostTranscriptionStreamResponse and returns the event type and the streaming format expected by the specific integration. type TranscriptionStreamResponseConverter func(ctx *context.Context, resp *schemas.BifrostTranscriptionStreamResponse) (string, interface{}, error) +// ImageGenerationResponseConverter is a function that converts BifrostImageGenerationResponse to integration-specific format. +// It takes a BifrostImageGenerationResponse and returns the format expected by the specific integration. 
+type ImageGenerationResponseConverter func(ctx *context.Context, resp *schemas.BifrostImageGenerationResponse) (interface{}, error) + +// ImageGenerationStreamResponseConverter is a function that converts BifrostImageGenerationStreamResponse to integration-specific streaming format. +// It takes a BifrostImageGenerationStreamResponse and returns the event type and the streaming format expected by the specific integration. +type ImageGenerationStreamResponseConverter func(ctx *context.Context, resp *schemas.BifrostImageGenerationStreamResponse) (string, interface{}, error) + // ErrorConverter is a function that converts BifrostError to integration-specific format. // It takes a BifrostError and returns the format expected by the specific integration. type ErrorConverter func(ctx *context.Context, err *schemas.BifrostError) interface{} @@ -181,6 +189,7 @@ type StreamConfig struct { ResponsesStreamResponseConverter ResponsesStreamResponseConverter // Function to convert BifrostResponsesResponse to streaming format SpeechStreamResponseConverter SpeechStreamResponseConverter // Function to convert BifrostSpeechResponse to streaming format TranscriptionStreamResponseConverter TranscriptionStreamResponseConverter // Function to convert BifrostTranscriptionResponse to streaming format + ImageGenerationStreamResponseConverter ImageGenerationStreamResponseConverter // Function to convert BifrostImageGenerationStreamResponse to streaming format ErrorConverter StreamErrorConverter // Function to convert BifrostError to streaming error format } @@ -209,6 +218,7 @@ type RouteConfig struct { EmbeddingResponseConverter EmbeddingResponseConverter // Function to convert BifrostEmbeddingResponse to integration format (SHOULD NOT BE NIL) SpeechResponseConverter SpeechResponseConverter // Function to convert BifrostSpeechResponse to integration format (SHOULD NOT BE NIL) TranscriptionResponseConverter TranscriptionResponseConverter // Function to convert BifrostTranscriptionResponse 
to integration format (SHOULD NOT BE NIL) + ImageGenerationResponseConverter ImageGenerationResponseConverter // Function to convert BifrostImageGenerationResponse to integration format (SHOULD NOT BE NIL) ErrorConverter ErrorConverter // Function to convert BifrostError to integration format (SHOULD NOT BE NIL) StreamConfig *StreamConfig // Optional: Streaming configuration (if nil, streaming not supported) PreCallback PreRequestCallback // Optional: called after parsing but before Bifrost processing @@ -566,6 +576,29 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.TranscriptionResponseConverter(bifrostCtx, transcriptionResponse) + case bifrostReq.ImageGenerationRequest != nil: + imageGenerationResponse, bifrostErr := g.client.ImageGenerationRequest(requestCtx, bifrostReq.ImageGenerationRequest) + if bifrostErr != nil { + g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) + return + } + + // Execute post-request callback if configured + // This is typically used for response modification or additional processing + if config.PostCallback != nil { + if err := config.PostCallback(ctx, req, imageGenerationResponse); err != nil { + g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to execute post-request callback")) + return + } + } + + if imageGenerationResponse == nil { + g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Bifrost response is nil after post-request callback")) + return + } + + // Convert Bifrost response to integration-specific format and send + response, err = config.ImageGenerationResponseConverter(bifrostCtx, imageGenerationResponse) default: g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid request type")) return @@ -617,6 +650,8 @@ func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config stream, 
bifrostErr = g.client.SpeechStreamRequest(streamCtx, bifrostReq.SpeechRequest) } else if bifrostReq.TranscriptionRequest != nil { stream, bifrostErr = g.client.TranscriptionStreamRequest(streamCtx, bifrostReq.TranscriptionRequest) + } else if bifrostReq.ImageGenerationRequest != nil { + stream, bifrostErr = g.client.ImageGenerationStreamRequest(streamCtx, bifrostReq.ImageGenerationRequest) } // Get the streaming channel from Bifrost @@ -801,6 +836,8 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *co eventType, convertedResponse, err = config.StreamConfig.SpeechStreamResponseConverter(bifrostCtx, chunk.BifrostSpeechStreamResponse) case chunk.BifrostTranscriptionStreamResponse != nil: eventType, convertedResponse, err = config.StreamConfig.TranscriptionStreamResponseConverter(bifrostCtx, chunk.BifrostTranscriptionStreamResponse) + case chunk.BifrostImageGenerationStreamResponse != nil: + eventType, convertedResponse, err = config.StreamConfig.ImageGenerationStreamResponseConverter(bifrostCtx, chunk.BifrostImageGenerationStreamResponse) default: requestType := safeGetRequestType(chunk) convertedResponse, err = nil, fmt.Errorf("no response converter found for request type: %s", requestType) diff --git a/transports/changelog.md b/transports/changelog.md index e69de29bb..f50725785 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -0,0 +1 @@ +feat: added http handlers for image generation endpoints \ No newline at end of file diff --git a/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx b/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx index 0959a9193..78c51c644 100644 --- a/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx +++ b/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx @@ -32,6 +32,8 @@ const PROVIDER_ENDPOINTS: Partial = [ { key: "speech_stream", label: "Speech Stream" }, { key: "transcription", label: "Transcription" }, { key: "transcription_stream", 
label: "Transcription Stream" }, + { key: "image_generation", label: "Image Generation" }, + { key: "image_generation_stream", label: "Image Generation Stream" }, ]; export function AllowedRequestsFields({ control, namePrefix = "allowed_requests", providerType }: AllowedRequestsFieldsProps) { diff --git a/ui/components/chat/ImageMessage.tsx b/ui/components/chat/ImageMessage.tsx new file mode 100644 index 000000000..e2e335d9a --- /dev/null +++ b/ui/components/chat/ImageMessage.tsx @@ -0,0 +1,57 @@ +import React from 'react'; +import { Card } from '@/components/ui/card'; +import { Skeleton } from '@/components/ui/skeleton'; + +interface ImageMessageProps { + images: Array<{ + url?: string; + b64_json?: string; + revised_prompt?: string; + index: number; + }>; + isStreaming?: boolean; + streamProgress?: number; // 0-100 +} + +export const ImageMessage: React.FC = ({ + images, + isStreaming, + streamProgress +}) => { + return ( +
+ {images.map((img, idx) => ( + + {isStreaming && !img.url && !img.b64_json ? ( +
+ +
+ Loading... {streamProgress}% +
+
+) : (img.url || img.b64_json) ? ( + <> + {img.revised_prompt + {img.revised_prompt && ( +
+ {img.revised_prompt} +
+ )} + +) : ( +
+

+ Image unavailable +

+
+)} +
+ ))} +
+	);
+};
\ No newline at end of file
diff --git a/ui/hooks/useImageStream.ts b/ui/hooks/useImageStream.ts
new file mode 100644
index 000000000..cb94f42e0
--- /dev/null
+++ b/ui/hooks/useImageStream.ts
@@ -0,0 +1,292 @@
+"use client";
+
+import { useState, useCallback, useRef, useEffect } from "react";
+import { getEndpointUrl } from "@/lib/utils/port";
+import { getTokenFromStorage } from "@/lib/store/apis/baseApi";
+
+// Matches backend BifrostImageGenerationStreamResponse
+interface ImageStreamChunk {
+	id: string;
+	type: string; // "image_generation.partial_image", "image_generation.completed", "error"
+	index: number; // Which image (0-N)
+	chunk_index: number; // Chunk order within image
+	partial_b64?: string; // Base64 chunk
+	revised_prompt?: string; // On first chunk
+	usage?: {
+		prompt_tokens: number;
+		total_tokens: number;
+	};
+	error?: {
+		message: string;
+		code?: string;
+	};
+}
+
+export interface StreamedImage {
+	url?: string;
+	b64_json?: string;
+	revised_prompt?: string;
+	index: number;
+}
+
+interface ImageStreamState {
+	images: StreamedImage[];
+	isStreaming: boolean;
+	progress: number; // 0-100
+	error: string | null;
+}
+
+interface UseImageStreamOptions {
+	onComplete?: (images: StreamedImage[]) => void;
+	onError?: (error: string) => void;
+}
+
+interface ImageStreamRequest {
+	model: string;
+	prompt: string;
+	n?: number;
+	size?: string;
+	quality?: string;
+	style?: string;
+	response_format?: string;
+}
+
+/**
+ * React hook that streams image-generation results over SSE from
+ * POST /v1/images/generations (with `stream: true`).
+ *
+ * Base64 image data arrives in indexed chunks; the hook accumulates them
+ * per image, assembles each image on its `image_generation.completed`
+ * event, and exposes `{ images, isStreaming, progress, error }` plus
+ * `stream` / `cancel` / `reset` controls.
+ */
+export function useImageStream(options: UseImageStreamOptions = {}) {
+	const [state, setState] = useState<ImageStreamState>({
+		images: [],
+		isStreaming: false,
+		progress: 0,
+		error: null,
+	});
+
+	const abortControllerRef = useRef<AbortController | null>(null);
+	// Per-image accumulation: chunk_index -> base64 fragment, plus the revised prompt.
+	const imageChunksRef = useRef<Map<number, { chunks: Map<number, string>; revisedPrompt?: string }>>(new Map());
+	// Fully assembled images, keyed by image index. Used to decide completion —
+	// the visible `images` array cannot be used for that because it is
+	// pre-seeded with placeholders (see `stream` below).
+	const completedImagesRef = useRef<Map<number, StreamedImage>>(new Map());
+	const totalChunksReceivedRef = useRef(0);
+	const expectedImagesRef = useRef(1);
+
+	// Reset accumulators and visible state for a new request.
+	const reset = useCallback(() => {
+		imageChunksRef.current.clear();
+		completedImagesRef.current.clear();
+		totalChunksReceivedRef.current = 0;
+		setState({
+			images: [],
+			isStreaming: false,
+			progress: 0,
+			error: null,
+		});
+	}, []);
+
+	// Abort the in-flight fetch (if any) and stop streaming.
+	const cancel = useCallback(() => {
+		if (abortControllerRef.current) {
+			abortControllerRef.current.abort();
+			abortControllerRef.current = null;
+		}
+		setState((prev) => ({ ...prev, isStreaming: false }));
+	}, []);
+
+	// Concatenate one image's accumulated base64 chunks in chunk_index order.
+	const buildImageFromChunks = useCallback((imageIndex: number): StreamedImage | null => {
+		const imageData = imageChunksRef.current.get(imageIndex);
+		if (!imageData) return null;
+
+		const sortedChunks = Array.from(imageData.chunks.entries())
+			.sort(([a], [b]) => a - b)
+			.map(([, chunk]) => chunk);
+
+		return {
+			b64_json: sortedChunks.join(""),
+			revised_prompt: imageData.revisedPrompt,
+			index: imageIndex,
+		};
+	}, []);
+
+	// Process a single parsed SSE chunk from the backend.
+	const processChunk = useCallback(
+		(chunk: ImageStreamChunk) => {
+			const { index, chunk_index, partial_b64, revised_prompt, type, error } = chunk;
+
+			// Surface stream-level errors and stop.
+			if (type === "error" || error) {
+				const errorMsg = error?.message || "Unknown streaming error";
+				setState((prev) => ({ ...prev, error: errorMsg, isStreaming: false }));
+				options.onError?.(errorMsg);
+				return;
+			}
+
+			if (!imageChunksRef.current.has(index)) {
+				imageChunksRef.current.set(index, { chunks: new Map() });
+			}
+			const imageData = imageChunksRef.current.get(index)!;
+
+			// Revised prompt usually arrives on an image's first chunk.
+			if (revised_prompt) {
+				imageData.revisedPrompt = revised_prompt;
+			}
+
+			if (partial_b64) {
+				imageData.chunks.set(chunk_index, partial_b64);
+				totalChunksReceivedRef.current++;
+			}
+
+			// Rough progress estimate (~10 chunks per image assumed); capped at 95
+			// so 100% is only ever reported on actual completion.
+			const estimatedTotalChunks = expectedImagesRef.current * 10;
+			const progress = Math.min(95, Math.round((totalChunksReceivedRef.current / estimatedTotalChunks) * 100));
+
+			if (type !== "image_generation.completed") {
+				setState((prev) => ({ ...prev, progress }));
+				return;
+			}
+
+			// Image finished: assemble it and record the completion. All bookkeeping
+			// happens in refs, outside the setState updater, so the updater stays
+			// pure — React may invoke updaters more than once (e.g. StrictMode),
+			// and firing onComplete from inside one would double-invoke it.
+			const completedImage = buildImageFromChunks(index);
+			if (completedImage) {
+				completedImagesRef.current.set(index, completedImage);
+				imageChunksRef.current.delete(index); // chunks no longer needed once assembled
+			}
+
+			// Count actual completions rather than the images array length — the
+			// array is pre-seeded with one placeholder per expected image, so its
+			// length alone would report "all complete" after the first image.
+			const allComplete = completedImagesRef.current.size >= expectedImagesRef.current;
+
+			// Completed images replace their placeholders; order stays index-sorted.
+			// NOTE(review): assumes the server only reports indices in [0, n) —
+			// verify against the backend stream contract.
+			const mergedImages: StreamedImage[] = Array.from(
+				{ length: expectedImagesRef.current },
+				(_, i) => completedImagesRef.current.get(i) ?? { index: i }
+			);
+
+			setState((prev) => ({
+				...prev,
+				images: mergedImages,
+				isStreaming: !allComplete,
+				progress: allComplete ? 100 : progress,
+			}));
+
+			if (allComplete) {
+				options.onComplete?.(mergedImages);
+			}
+		},
+		[buildImageFromChunks, options]
+	);
+
+	// POST the request with stream:true and consume the SSE response body.
+	const stream = useCallback(
+		async (request: ImageStreamRequest) => {
+			cancel();
+			reset();
+			expectedImagesRef.current = request.n || 1;
+
+			setState((prev) => ({
+				...prev,
+				isStreaming: true,
+				error: null,
+				// Seed one placeholder per expected image so the UI can render slots.
+				images: Array.from({ length: expectedImagesRef.current }, (_, i) => ({ index: i })),
+			}));
+
+			abortControllerRef.current = new AbortController();
+
+			try {
+				const token = await getTokenFromStorage();
+				const url = getEndpointUrl("/v1/images/generations");
+
+				const response = await fetch(url, {
+					method: "POST",
+					headers: {
+						"Content-Type": "application/json",
+						...(token ? { Authorization: `Bearer ${token}` } : {}),
+					},
+					body: JSON.stringify({
+						...request,
+						stream: true,
+					}),
+					signal: abortControllerRef.current.signal,
+				});
+
+				if (!response.ok) {
+					const errorText = await response.text();
+					throw new Error(errorText || `HTTP ${response.status}`);
+				}
+				if (!response.body) {
+					throw new Error("No response body");
+				}
+
+				const reader = response.body.getReader();
+				const decoder = new TextDecoder();
+				let buffer = "";
+
+				while (true) {
+					const { done, value } = await reader.read();
+					if (done) {
+						break;
+					}
+
+					buffer += decoder.decode(value, { stream: true });
+
+					// SSE frames are newline-delimited; keep any trailing partial line.
+					const lines = buffer.split("\n");
+					buffer = lines.pop() || "";
+
+					for (const line of lines) {
+						const trimmed = line.trim();
+						if (!trimmed.startsWith("data: ")) {
+							continue;
+						}
+						const data = trimmed.slice(6);
+
+						if (data === "[DONE]") {
+							setState((prev) => ({ ...prev, isStreaming: false, progress: 100 }));
+							return;
+						}
+
+						try {
+							const chunk: ImageStreamChunk = JSON.parse(data);
+							processChunk(chunk);
+						} catch (parseError) {
+							console.error("Failed to parse SSE chunk:", parseError);
+						}
+					}
+				}
+
+				// Stream ended without an explicit [DONE].
+				setState((prev) => ({ ...prev, isStreaming: false }));
+			} catch (err) {
+				if (err instanceof Error && err.name === "AbortError") {
+					// Cancelled by user — not an error.
+					return;
+				}
+				const errorMsg = err instanceof Error ? err.message : "Stream failed";
+				setState((prev) => ({ ...prev, error: errorMsg, isStreaming: false }));
+				options.onError?.(errorMsg);
+			}
+		},
+		[cancel, reset, processChunk, options]
+	);
+
+	// Abort any in-flight request on unmount.
+	useEffect(() => {
+		return () => {
+			if (abortControllerRef.current) {
+				abortControllerRef.current.abort();
+			}
+		};
+	}, []);
+
+	return {
+		...state,
+		stream,
+		cancel,
+		reset,
+	};
+}
\ No newline at end of file
diff --git a/ui/lib/constants/logs.ts b/ui/lib/constants/logs.ts
index 8d9cc066e..1b086aaeb 100644
--- a/ui/lib/constants/logs.ts
+++ b/ui/lib/constants/logs.ts
@@ -37,6 +37,8 @@ export const RequestTypes = [
 	"speech_stream",
 	"transcription",
 	"transcription_stream",
+	"image_generation",
+	"image_generation_stream",
 ] as const;
 
 export const ProviderLabels: Record = {
@@ -101,6 +103,8 @@ export const RequestTypeLabels = {
 	speech_stream: "Speech Stream",
 	transcription: "Transcription",
 	transcription_stream: "Transcription Stream",
+	image_generation: "Image Generation",
+	image_generation_stream: "Image Generation Stream",
 } as const;
 
 export const RequestTypeColors = {
@@ -128,6 +132,8 @@ export const RequestTypeColors = {
 	speech_stream: "bg-pink-100 text-pink-800",
 	transcription: "bg-orange-100 text-orange-800",
 	transcription_stream: "bg-lime-100 text-lime-800",
+	image_generation: "bg-indigo-100 text-indigo-800",
+	image_generation_stream: "bg-purple-100 text-purple-800",
 } as const;
 
 export type Status = (typeof Statuses)[number];
diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts
index 14920ff2b..144873947 100644
--- a/ui/lib/types/config.ts
+++ b/ui/lib/types/config.ts
@@ -131,7 +131,9 @@ export type RequestType =
 	| "speech"
 	| "speech_stream"
 	| "transcription"
-	| "transcription_stream";
+	| "transcription_stream"
+	| "image_generation"
+	| "image_generation_stream";
 
 // AllowedRequests matching Go's schemas.AllowedRequests
 export interface AllowedRequests {
@@ -146,6 +148,8 @@ export interface AllowedRequests {
 	speech_stream: boolean;
 	transcription: boolean;
 	transcription_stream: boolean;
+	image_generation: boolean;
+	image_generation_stream: boolean;
 	list_models: boolean;
 }