diff --git a/core/internal/testutil/tests.go b/core/internal/testutil/tests.go index 97a1d3681..ec7f6f9f5 100644 --- a/core/internal/testutil/tests.go +++ b/core/internal/testutil/tests.go @@ -34,6 +34,11 @@ func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context RunEnd2EndToolCallingTest, RunAutomaticFunctionCallingTest, RunWebSearchToolTest, + RunWebSearchToolStreamTest, + RunWebSearchToolWithDomainsTest, + RunWebSearchToolContextSizesTest, + RunWebSearchToolMultiTurnTest, + RunWebSearchToolMaxUsesTest, RunImageURLTest, RunImageBase64Test, RunMultipleImagesTest, diff --git a/core/internal/testutil/web_search_tool.go b/core/internal/testutil/web_search_tool.go index 9f3420df0..d328a07bd 100644 --- a/core/internal/testutil/web_search_tool.go +++ b/core/internal/testutil/web_search_tool.go @@ -167,3 +167,676 @@ func WebSearchExpectations() ResponseExpectations { ShouldHaveContent: true, } } + +// RunWebSearchToolStreamTest executes streaming web search test +func RunWebSearchToolStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.WebSearchTool { + t.Logf("Web search tool not supported for provider %s", testConfig.Provider) + return + } + + t.Run("WebSearchToolStream", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + responsesMessages := []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("What are the latest advancements in renewable energy? 
Use web search."), + } + + // Create web search tool with user location + webSearchTool := &schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + UserLocation: &schemas.ResponsesToolWebSearchUserLocation{ + Type: bifrost.Ptr("approximate"), + Country: bifrost.Ptr("US"), + City: bifrost.Ptr("San Francisco"), + Region: bifrost.Ptr("California"), + Timezone: bifrost.Ptr("America/Los_Angeles"), + }, + }, + } + + request := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: responsesMessages, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{*webSearchTool}, + MaxOutputTokens: bifrost.Ptr(1500), + }, + Fallbacks: testConfig.Fallbacks, + } + + retryConfig := StreamingRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "WebSearchToolStream", + ExpectedBehavior: map[string]interface{}{ + "should_stream_content": true, + "should_have_web_search_call": true, + "should_have_streaming_events": true, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + }, + } + + validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext, + func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ResponsesStreamRequest(bfCtx, request) + }, + func(responseChannel chan *schemas.BifrostStream) ResponsesStreamValidationResult { + var hasWebSearchCall, hasMessageContent bool + var webSearchQuery string + var searchSources []schemas.ResponsesWebSearchToolCallActionSearchSource + var chunkCount int + var errors []string + + streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + + for { + select { + case stream, ok := <-responseChannel: + if !ok { + goto ValidationComplete + } + if stream == nil { + continue + } + + chunkCount++ + + // Check 
streaming events for web_search_call and message content + if stream.BifrostResponsesStreamResponse != nil { + streamType := stream.BifrostResponsesStreamResponse.Type + + // Check for output_item.added with web_search_call + if streamType == schemas.ResponsesStreamResponseTypeOutputItemAdded { + if stream.BifrostResponsesStreamResponse.Item != nil { + if stream.BifrostResponsesStreamResponse.Item.Type != nil && + *stream.BifrostResponsesStreamResponse.Item.Type == schemas.ResponsesMessageTypeWebSearchCall { + hasWebSearchCall = true + t.Logf("✅ Found web_search_call in streaming event: %s", streamType) + + // Extract query and sources if available + if stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage != nil && + stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage.Action != nil { + action := stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage.Action + if action.ResponsesWebSearchToolCallAction != nil { + if action.ResponsesWebSearchToolCallAction.Query != nil { + webSearchQuery = *action.ResponsesWebSearchToolCallAction.Query + t.Logf("✅ Web search query: %s", webSearchQuery) + } + searchSources = append(searchSources, action.ResponsesWebSearchToolCallAction.Sources...) 
+ } + } + } + } + } + + // Also check other web_search_call streaming events + if streamType == schemas.ResponsesStreamResponseTypeWebSearchCallInProgress || + streamType == schemas.ResponsesStreamResponseTypeWebSearchCallSearching || + streamType == schemas.ResponsesStreamResponseTypeWebSearchCallCompleted { + hasWebSearchCall = true + t.Logf("✅ Found web_search_call streaming event: %s", streamType) + } + + // Check for message text content in streaming deltas + if streamType == schemas.ResponsesStreamResponseTypeOutputTextDelta { + if stream.BifrostResponsesStreamResponse.Delta != nil && *stream.BifrostResponsesStreamResponse.Delta != "" { + hasMessageContent = true + t.Logf("✅ Found message text delta: %s", *stream.BifrostResponsesStreamResponse.Delta) + } + } + } + + case <-streamCtx.Done(): + t.Logf("⚠️ Stream timeout after %d chunks", chunkCount) + goto ValidationComplete + } + } + + ValidationComplete: + if len(searchSources) > 0 { + t.Logf("✅ Found %d search sources", len(searchSources)) + } + + // Validate streaming requirements + if !hasWebSearchCall { + errors = append(errors, "No web_search_call found in stream") + } + + if !hasMessageContent { + errors = append(errors, "No message content found in stream") + } + + if chunkCount < 3 { + errors = append(errors, "Too few streaming chunks received") + } + + return ResponsesStreamValidationResult{ + Passed: len(errors) == 0, + Errors: errors, + ReceivedData: hasWebSearchCall || hasMessageContent, + } + }, + ) + + require.True(t, validationResult.Passed, "Stream validation failed: %v", validationResult.Errors) + t.Logf("🎉 WebSearchToolStream test passed!") + }) +} + +// RunWebSearchToolWithDomainsTest tests web search with domain filtering +func RunWebSearchToolWithDomainsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.WebSearchTool { + t.Logf("Web search tool not supported for provider %s", testConfig.Provider) + return + } + 
+ t.Run("WebSearchToolWithDomains", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + responsesMessages := []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("What is machine learning? Use web search tool."), + } + + // Create web search tool with domain filters + webSearchTool := &schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + Filters: &schemas.ResponsesToolWebSearchFilters{ + AllowedDomains: []string{"wikipedia.org", "en.wikipedia.org"}, + }, + }, + } + + retryConfig := WebSearchRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "WebSearchToolWithDomains", + ExpectedBehavior: map[string]interface{}{ + "expected_tool_type": "web_search", + "domain_filters": true, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + }, + } + + expectations := WebSearchExpectations() + + responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + responsesReq := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: responsesMessages, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{*webSearchTool}, + MaxOutputTokens: bifrost.Ptr(1200), + }, + Fallbacks: testConfig.Fallbacks, + } + + return client.ResponsesRequest(bfCtx, responsesReq) + } + + response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolWithDomains", responsesOperation) + + if err != nil { + t.Fatalf("❌ WebSearchToolWithDomains test failed: %s", GetErrorMessage(err)) + } + + require.NotNil(t, response, "Response should not be nil") + + // Validate web search was invoked and collect sources + webSearchCallFound := false + var sources []schemas.ResponsesWebSearchToolCallActionSearchSource + + if response.Output 
!= nil { + for _, output := range response.Output { + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall { + webSearchCallFound = true + if output.ResponsesToolMessage != nil && output.ResponsesToolMessage.Action != nil { + action := output.ResponsesToolMessage.Action + if action.ResponsesWebSearchToolCallAction != nil { + sources = action.ResponsesWebSearchToolCallAction.Sources + t.Logf("✅ Found %d search sources", len(sources)) + } + } + } + } + } + + require.True(t, webSearchCallFound, "Web search call should be present") + + // Validate sources respect domain filters + if len(sources) > 0 { + ValidateWebSearchSources(t, sources, []string{"wikipedia.org", "en.wikipedia.org"}) + } + + t.Logf("🎉 WebSearchToolWithDomains test passed!") + }) +} + +// RunWebSearchToolContextSizesTest tests different search context sizes +func RunWebSearchToolContextSizesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.WebSearchTool { + t.Logf("Web search tool not supported for provider %s", testConfig.Provider) + return + } + + t.Run("WebSearchToolContextSizes", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + contextSizes := []string{"low", "medium", "high"} + + for _, size := range contextSizes { + size := size // Capture loop variable + t.Run("ContextSize_"+size, func(t *testing.T) { + responsesMessages := []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("What is quantum computing? 
Use web search."), + } + + webSearchTool := &schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + SearchContextSize: &size, + }, + } + + retryConfig := WebSearchRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "WebSearchToolContextSize_" + size, + ExpectedBehavior: map[string]interface{}{ + "expected_tool_type": "web_search", + "context_size": size, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + "context_size": size, + }, + } + + expectations := WebSearchExpectations() + + responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + responsesReq := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: responsesMessages, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{*webSearchTool}, + MaxOutputTokens: bifrost.Ptr(1500), + }, + Fallbacks: testConfig.Fallbacks, + } + + return client.ResponsesRequest(bfCtx, responsesReq) + } + + response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolContextSize", responsesOperation) + + if err != nil { + t.Fatalf("❌ WebSearchToolContextSize (%s) test failed: %s", size, GetErrorMessage(err)) + } + + require.NotNil(t, response, "Response should not be nil") + + webSearchCallFound := false + hasTextResponse := false + + if response.Output != nil { + for _, output := range response.Output { + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall { + webSearchCallFound = true + t.Logf("✅ Web search call with context size: %s", size) + } + + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeMessage { + if output.Content != nil && len(output.Content.ContentBlocks) > 0 { + for _, block := range output.Content.ContentBlocks { + 
if block.Text != nil && *block.Text != "" { + hasTextResponse = true + t.Logf("✅ Response length for %s context: %d chars", size, len(*block.Text)) + } + } + } + } + } + } + + require.True(t, webSearchCallFound, "Web search call should be present") + require.True(t, hasTextResponse, "Response should contain text") + + t.Logf("🎉 WebSearchToolContextSize (%s) test passed!", size) + }) + } + }) +} + +// RunWebSearchToolMultiTurnTest tests multi-turn conversation with web search +func RunWebSearchToolMultiTurnTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.WebSearchTool { + t.Logf("Web search tool not supported for provider %s", testConfig.Provider) + return + } + + t.Run("WebSearchToolMultiTurn", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + webSearchTool := &schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}, + } + + // First turn + t.Log("🔄 Starting first turn...") + firstMessages := []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("What is renewable energy? 
Use web search tool."), + } + + retryConfig := WebSearchRetryConfig() + retryContext1 := TestRetryContext{ + ScenarioName: "WebSearchToolMultiTurn_Turn1", + ExpectedBehavior: map[string]interface{}{ + "expected_tool_type": "web_search", + "turn": 1, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + }, + } + + expectations := WebSearchExpectations() + + firstOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + responsesReq := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: firstMessages, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{*webSearchTool}, + MaxOutputTokens: bifrost.Ptr(1500), + }, + Fallbacks: testConfig.Fallbacks, + } + + return client.ResponsesRequest(bfCtx, responsesReq) + } + + firstResponse, err := WithResponsesTestRetry(t, retryConfig, retryContext1, expectations, "WebSearchToolMultiTurn_Turn1", firstOperation) + + if err != nil { + t.Fatalf("❌ First turn failed: %s", GetErrorMessage(err)) + } + + require.NotNil(t, firstResponse, "First response should not be nil") + + // Validate first turn has web search + firstTurnHasWebSearch := false + if firstResponse.Output != nil { + for _, output := range firstResponse.Output { + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall { + firstTurnHasWebSearch = true + t.Logf("✅ First turn: Web search executed") + break + } + } + } + + require.True(t, firstTurnHasWebSearch, "First turn should have web search call") + + // Second turn - add first response to conversation history + t.Log("🔄 Starting second turn...") + secondMessages := append(firstMessages, firstResponse.Output...) 
+ secondMessages = append(secondMessages, CreateBasicResponsesMessage("What are the main types of renewable energy?")) + + retryContext2 := TestRetryContext{ + ScenarioName: "WebSearchToolMultiTurn_Turn2", + ExpectedBehavior: map[string]interface{}{ + "expected_tool_type": "web_search", + "turn": 2, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + }, + } + + secondOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + responsesReq := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: secondMessages, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{*webSearchTool}, + MaxOutputTokens: bifrost.Ptr(1500), + }, + Fallbacks: testConfig.Fallbacks, + } + + return client.ResponsesRequest(bfCtx, responsesReq) + } + + secondResponse, err := WithResponsesTestRetry(t, retryConfig, retryContext2, expectations, "WebSearchToolMultiTurn_Turn2", secondOperation) + + if err != nil { + t.Fatalf("❌ Second turn failed: %s", GetErrorMessage(err)) + } + + require.NotNil(t, secondResponse, "Second response should not be nil") + + // Validate second turn + secondTurnHasMessage := false + if secondResponse.Output != nil { + for _, output := range secondResponse.Output { + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeMessage { + secondTurnHasMessage = true + t.Logf("✅ Second turn: Got response message") + break + } + } + } + + require.True(t, secondTurnHasMessage, "Second turn should have message response") + + t.Logf("🎉 WebSearchToolMultiTurn test passed!") + }) +} + +// RunWebSearchToolMaxUsesTest tests Anthropic-specific max uses parameter +func RunWebSearchToolMaxUsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.WebSearchTool { + t.Logf("Web search 
tool not supported for provider %s", testConfig.Provider) + return + } + + // This is Anthropic-specific functionality + if testConfig.Provider != "anthropic" { + t.Logf("Max uses parameter is Anthropic-specific, skipping for provider %s", testConfig.Provider) + return + } + + t.Run("WebSearchToolMaxUses", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + responsesMessages := []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("Compare the populations of Tokyo and New York City. Use web search."), + } + + // Create web search tool with max uses limit + maxUses := 3 + webSearchTool := &schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + MaxUses: &maxUses, + }, + } + + retryConfig := WebSearchRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "WebSearchToolMaxUses", + ExpectedBehavior: map[string]interface{}{ + "expected_tool_type": "web_search", + "max_uses": maxUses, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + }, + } + + expectations := WebSearchExpectations() + + responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + responsesReq := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: responsesMessages, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{*webSearchTool}, + MaxOutputTokens: bifrost.Ptr(2000), + }, + Fallbacks: testConfig.Fallbacks, + } + + return client.ResponsesRequest(bfCtx, responsesReq) + } + + response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolMaxUses", responsesOperation) + + if err != nil { + t.Fatalf("❌ WebSearchToolMaxUses test failed: %s", GetErrorMessage(err)) + } + + require.NotNil(t, response, "Response should 
not be nil") + + // Count web search calls + webSearchCallCount := 0 + if response.Output != nil { + for _, output := range response.Output { + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall { + webSearchCallCount++ + } + } + } + + t.Logf("✅ Web search called %d times (max: %d)", webSearchCallCount, maxUses) + require.True(t, webSearchCallCount <= maxUses, "Web search should not exceed max uses limit") + require.True(t, webSearchCallCount > 0, "Web search should be called at least once") + + t.Logf("🎉 WebSearchToolMaxUses test passed!") + }) +} + +// ValidateWebSearchSources validates web search sources structure and domain filtering +func ValidateWebSearchSources(t *testing.T, sources []schemas.ResponsesWebSearchToolCallActionSearchSource, allowedDomains []string) { + require.NotEmpty(t, sources, "Sources should not be empty") + + for i, source := range sources { + // Validate basic structure + require.NotEmpty(t, source.URL, "Source %d should have a URL", i+1) + + t.Logf(" Source %d: %s", i+1, source.URL) + + // If domain filters specified, validate sources match patterns + if len(allowedDomains) > 0 { + matchesFilter := false + for _, domain := range allowedDomains { + // Simple pattern matching for wildcard domains + // "wikipedia.org/*" matches any wikipedia.org URL + // "*.edu" matches any .edu domain + if matchesDomainPattern(source.URL, domain) { + matchesFilter = true + break + } + } + + if !matchesFilter { + t.Logf(" ⚠️ Source %d (%s) doesn't match allowed domain filters", i+1, source.URL) + } + } + } + + t.Logf("✅ Validated %d search sources", len(sources)) +} + +// matchesDomainPattern checks if a URL matches a domain pattern +func matchesDomainPattern(url, pattern string) bool { + // Simple pattern matching implementation + // "*.edu" matches URLs containing ".edu" + // "wikipedia.org/*" matches URLs containing "wikipedia.org" + + if len(pattern) > 0 && pattern[0] == '*' { + // Pattern like "*.edu" + suffix := 
// matchesDomainPattern reports whether url matches a domain filter pattern.
//
// Matching is a case-insensitive (ASCII) substring test, not real host
// parsing, so it is intentionally permissive:
//
//   - "*.edu"           matches any URL containing ".edu"
//   - "wikipedia.org/*" matches any URL containing "wikipedia.org"
//   - "wikipedia.org"   matches any URL containing "wikipedia.org"
//
// For a trailing wildcard only the "*" and, if present, the "/" directly
// before it are removed, so "example.com*" searches for "example.com" and a
// pattern such as "z*" searches for "z" rather than for the empty string.
func matchesDomainPattern(url, pattern string) bool {
	switch {
	case len(pattern) > 0 && pattern[0] == '*':
		// Leading wildcard, e.g. "*.edu": search for what follows the "*".
		return containsSubstring(url, pattern[1:])

	case len(pattern) > 0 && pattern[len(pattern)-1] == '*':
		// Trailing wildcard, e.g. "wikipedia.org/*": drop the "*" and an
		// optional path separator in front of it.
		prefix := pattern[:len(pattern)-1]
		if n := len(prefix); n > 0 && prefix[n-1] == '/' {
			prefix = prefix[:n-1]
		}
		return containsSubstring(url, prefix)

	default:
		// No wildcard: plain substring match.
		return containsSubstring(url, pattern)
	}
}

// containsSubstring reports whether s contains substr, ignoring ASCII case.
// An empty substr is contained in every string.
func containsSubstring(s, substr string) bool {
	return indexOfSubstring(toLower(s), toLower(substr)) >= 0
}

// toLower returns s with the ASCII letters 'A'-'Z' converted to lowercase.
// All other bytes, including every byte of a multi-byte UTF-8 sequence, are
// copied through unchanged, so the result has the same length as s.
func toLower(s string) string {
	lowered := []byte(s)
	for i, c := range lowered {
		// UTF-8 lead and continuation bytes are >= 0x80 and never fall in
		// this range, so byte-wise lowering cannot corrupt non-ASCII text.
		if c >= 'A' && c <= 'Z' {
			lowered[i] = c + ('a' - 'A')
		}
	}
	return string(lowered)
}

// indexOfSubstring returns the byte index of the first occurrence of substr
// in s, or -1 if substr does not occur. An empty substr is found at index 0.
func indexOfSubstring(s, substr string) int {
	if len(substr) == 0 {
		return 0
	}
	if len(substr) > len(s) {
		return -1
	}

	last := len(s) - len(substr)
	for i := 0; i <= last; i++ {
		if s[i:i+len(substr)] == substr {
			return i
		}
	}
	return -1
}
== nil { - argumentsJSON = string(jsonBytes) - shouldGenerateDeltas = true - } - } } - if shouldGenerateDeltas && argumentsJSON != "" { // Generate synthetic input_json_delta events by chunking the JSON var indexToUse *int @@ -1518,9 +1503,37 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp bifrostResp.Item.Type != nil && *bifrostResp.Item.Type == schemas.ResponsesMessageTypeWebSearchCall { - // Web search call complete - emit content_block_stop for query, then web_search_tool_result block + // Web search call complete - generate synthetic input_json_delta events, then emit content_block_stop var events []*AnthropicStreamEvent + // Extract query from web search action for synthetic delta generation + var queryJSON string + if bifrostResp.Item.ResponsesToolMessage != nil && + bifrostResp.Item.ResponsesToolMessage.Action != nil && + bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction != nil && + bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Query != nil { + + // Create input map with query + inputMap := map[string]interface{}{ + "query": *bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Query, + } + if jsonBytes, err := json.Marshal(inputMap); err == nil { + queryJSON = string(jsonBytes) + } + } + + // Generate synthetic input_json_delta events if we have a query + if queryJSON != "" { + var indexToUse *int + if bifrostResp.OutputIndex != nil { + indexToUse = bifrostResp.OutputIndex + } else if bifrostResp.ContentIndex != nil { + indexToUse = bifrostResp.ContentIndex + } + deltaEvents := generateSyntheticInputJSONDeltas(queryJSON, indexToUse) + events = append(events, deltaEvents...) + } + // 1. 
Emit content_block_stop for the query block (server_tool_use) stopEvent := &AnthropicStreamEvent{ Type: AnthropicStreamEventTypeContentBlockStop, @@ -1799,11 +1812,11 @@ func (request *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx context.Co for _, tool := range request.Tools { if tool.Type != nil && (*tool.Type == AnthropicToolTypeComputer20250124 || *tool.Type == AnthropicToolTypeComputer20251124) { params.Truncation = schemas.Ptr("auto") - break + } else if tool.Type != nil && (*tool.Type == AnthropicToolTypeWebSearch20250305) { + params.Include = []string{"web_search_call.action.sources"} } } - params.Include = []string{"web_search_call.action.sources"} } bifrostReq.Params = params diff --git a/core/providers/openai/openai_test.go b/core/providers/openai/openai_test.go index a30d04d39..15fdbad5b 100644 --- a/core/providers/openai/openai_test.go +++ b/core/providers/openai/openai_test.go @@ -25,7 +25,7 @@ func TestOpenAI(t *testing.T) { testConfig := testutil.ComprehensiveTestConfig{ Provider: schemas.OpenAI, TextModel: "gpt-3.5-turbo-instruct", - ChatModel: "gpt-4o-mini", + ChatModel: "gpt-4o", PromptCachingModel: "gpt-4.1", Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o"}, diff --git a/core/providers/openai/responses_test.go b/core/providers/openai/responses_test.go index 764bd1c1b..c9e88f77f 100644 --- a/core/providers/openai/responses_test.go +++ b/core/providers/openai/responses_test.go @@ -1,6 +1,7 @@ package openai import ( + "encoding/json" "strings" "testing" @@ -344,3 +345,1091 @@ func TestToOpenAIResponsesRequest_GPTOSS_SummaryToContentBlocks(t *testing.T) { }) } } + +// ============================================================================= +// ResponsesToolMessageActionStruct Marshal/Unmarshal Tests +// ============================================================================= + +func TestResponsesToolMessageActionStruct_MarshalUnmarshal_ComputerToolAction(t *testing.T) { + tests := []struct { + name 
string + action schemas.ResponsesToolMessageActionStruct + jsonData string + }{ + { + name: "computer tool action - click", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: &schemas.ResponsesComputerToolCallAction{ + Type: "click", + X: schemas.Ptr(100), + Y: schemas.Ptr(200), + }, + }, + jsonData: `{"type":"click","x":100,"y":200}`, + }, + { + name: "computer tool action - screenshot", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: &schemas.ResponsesComputerToolCallAction{ + Type: "screenshot", + }, + }, + jsonData: `{"type":"screenshot"}`, + }, + { + name: "computer tool action - type with text", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: &schemas.ResponsesComputerToolCallAction{ + Type: "type", + Text: schemas.Ptr("hello world"), + }, + }, + jsonData: `{"type":"type","text":"hello world"}`, + }, + { + name: "computer tool action - scroll", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: &schemas.ResponsesComputerToolCallAction{ + Type: "scroll", + ScrollX: schemas.Ptr(50), + ScrollY: schemas.Ptr(100), + }, + }, + jsonData: `{"type":"scroll","scroll_x":50,"scroll_y":100}`, + }, + { + name: "computer tool action - zoom with region", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: &schemas.ResponsesComputerToolCallAction{ + Type: "zoom", + Region: []int{0, 0, 1024, 768}, + }, + }, + jsonData: `{"type":"zoom","region":[0,0,1024,768]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name+" - marshal", func(t *testing.T) { + data, err := json.Marshal(tt.action) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + // Unmarshal both to compare as maps (ignoring field order) + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(tt.jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if 
err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", tt.jsonData, string(data)) + } + }) + + t.Run(tt.name+" - unmarshal", func(t *testing.T) { + var action schemas.ResponsesToolMessageActionStruct + if err := json.Unmarshal([]byte(tt.jsonData), &action); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if action.ResponsesComputerToolCallAction == nil { + t.Fatal("expected ResponsesComputerToolCallAction to be populated") + } + + if action.ResponsesComputerToolCallAction.Type != tt.action.ResponsesComputerToolCallAction.Type { + t.Errorf("type mismatch: expected %s, got %s", + tt.action.ResponsesComputerToolCallAction.Type, + action.ResponsesComputerToolCallAction.Type) + } + + // Verify all other fields are nil (union type should have only one set) + if action.ResponsesWebSearchToolCallAction != nil { + t.Error("expected ResponsesWebSearchToolCallAction to be nil") + } + if action.ResponsesLocalShellToolCallAction != nil { + t.Error("expected ResponsesLocalShellToolCallAction to be nil") + } + if action.ResponsesMCPApprovalRequestAction != nil { + t.Error("expected ResponsesMCPApprovalRequestAction to be nil") + } + }) + } +} + +func TestResponsesToolMessageActionStruct_MarshalUnmarshal_WebSearchAction(t *testing.T) { + tests := []struct { + name string + action schemas.ResponsesToolMessageActionStruct + jsonData string + }{ + { + name: "web search action - search", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + Query: schemas.Ptr("golang testing"), + }, + }, + jsonData: `{"type":"search","query":"golang testing"}`, + }, + { + name: "web search action - open_page", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: 
&schemas.ResponsesWebSearchToolCallAction{ + Type: "open_page", + URL: schemas.Ptr("https://example.com"), + }, + }, + jsonData: `{"type":"open_page","url":"https://example.com"}`, + }, + { + name: "web search action - find", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "find", + Pattern: schemas.Ptr("error.*occurred"), + }, + }, + jsonData: `{"type":"find","pattern":"error.*occurred"}`, + }, + { + name: "web search action - search with queries array", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + Queries: []string{"query1", "query2"}, + }, + }, + jsonData: `{"type":"search","queries":["query1","query2"]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name+" - marshal", func(t *testing.T) { + data, err := json.Marshal(tt.action) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(tt.jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", tt.jsonData, string(data)) + } + }) + + t.Run(tt.name+" - unmarshal", func(t *testing.T) { + var action schemas.ResponsesToolMessageActionStruct + if err := json.Unmarshal([]byte(tt.jsonData), &action); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if action.ResponsesWebSearchToolCallAction == nil { + t.Fatal("expected ResponsesWebSearchToolCallAction to be populated") + } + + if action.ResponsesWebSearchToolCallAction.Type != tt.action.ResponsesWebSearchToolCallAction.Type { + t.Errorf("type mismatch: expected %s, got %s", + 
tt.action.ResponsesWebSearchToolCallAction.Type, + action.ResponsesWebSearchToolCallAction.Type) + } + + // Verify all other fields are nil + if action.ResponsesComputerToolCallAction != nil { + t.Error("expected ResponsesComputerToolCallAction to be nil") + } + if action.ResponsesLocalShellToolCallAction != nil { + t.Error("expected ResponsesLocalShellToolCallAction to be nil") + } + if action.ResponsesMCPApprovalRequestAction != nil { + t.Error("expected ResponsesMCPApprovalRequestAction to be nil") + } + }) + } +} + +func TestResponsesToolMessageActionStruct_MarshalUnmarshal_LocalShellAction(t *testing.T) { + tests := []struct { + name string + action schemas.ResponsesToolMessageActionStruct + jsonData string + }{ + { + name: "local shell action - simple exec", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesLocalShellToolCallAction: &schemas.ResponsesLocalShellToolCallAction{ + Type: "exec", + Command: []string{"ls", "-la"}, + Env: []string{"PATH=/usr/bin"}, + }, + }, + jsonData: `{"type":"exec","command":["ls","-la"],"env":["PATH=/usr/bin"]}`, + }, + { + name: "local shell action - with timeout and working directory", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesLocalShellToolCallAction: &schemas.ResponsesLocalShellToolCallAction{ + Type: "exec", + Command: []string{"npm", "test"}, + Env: []string{}, + TimeoutMS: schemas.Ptr(5000), + WorkingDirectory: schemas.Ptr("/home/user/project"), + }, + }, + jsonData: `{"type":"exec","command":["npm","test"],"env":[],"timeout_ms":5000,"working_directory":"/home/user/project"}`, + }, + { + name: "local shell action - with user", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesLocalShellToolCallAction: &schemas.ResponsesLocalShellToolCallAction{ + Type: "exec", + Command: []string{"whoami"}, + Env: []string{}, + User: schemas.Ptr("testuser"), + }, + }, + jsonData: `{"type":"exec","command":["whoami"],"env":[],"user":"testuser"}`, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name+" - marshal", func(t *testing.T) { + data, err := json.Marshal(tt.action) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(tt.jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", tt.jsonData, string(data)) + } + }) + + t.Run(tt.name+" - unmarshal", func(t *testing.T) { + var action schemas.ResponsesToolMessageActionStruct + if err := json.Unmarshal([]byte(tt.jsonData), &action); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if action.ResponsesLocalShellToolCallAction == nil { + t.Fatal("expected ResponsesLocalShellToolCallAction to be populated") + } + + if action.ResponsesLocalShellToolCallAction.Type != "exec" { + t.Errorf("type mismatch: expected exec, got %s", action.ResponsesLocalShellToolCallAction.Type) + } + + // Verify all other fields are nil + if action.ResponsesComputerToolCallAction != nil { + t.Error("expected ResponsesComputerToolCallAction to be nil") + } + if action.ResponsesWebSearchToolCallAction != nil { + t.Error("expected ResponsesWebSearchToolCallAction to be nil") + } + if action.ResponsesMCPApprovalRequestAction != nil { + t.Error("expected ResponsesMCPApprovalRequestAction to be nil") + } + }) + } +} + +func TestResponsesToolMessageActionStruct_MarshalUnmarshal_MCPApprovalAction(t *testing.T) { + tests := []struct { + name string + action schemas.ResponsesToolMessageActionStruct + jsonData string + }{ + { + name: "mcp approval request action", + action: schemas.ResponsesToolMessageActionStruct{ + ResponsesMCPApprovalRequestAction: &schemas.ResponsesMCPApprovalRequestAction{ + ID: "approval-123", + Type: "mcp_approval_request", + Name: "test_tool", + 
ServerLabel: "test-server", + Arguments: `{"key":"value"}`, + }, + }, + jsonData: `{"id":"approval-123","type":"mcp_approval_request","name":"test_tool","server_label":"test-server","arguments":"{\"key\":\"value\"}"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name+" - marshal", func(t *testing.T) { + data, err := json.Marshal(tt.action) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(tt.jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", tt.jsonData, string(data)) + } + }) + + t.Run(tt.name+" - unmarshal", func(t *testing.T) { + var action schemas.ResponsesToolMessageActionStruct + if err := json.Unmarshal([]byte(tt.jsonData), &action); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if action.ResponsesMCPApprovalRequestAction == nil { + t.Fatal("expected ResponsesMCPApprovalRequestAction to be populated") + } + + if action.ResponsesMCPApprovalRequestAction.Type != "mcp_approval_request" { + t.Errorf("type mismatch: expected mcp_approval_request, got %s", action.ResponsesMCPApprovalRequestAction.Type) + } + + // Verify all other fields are nil + if action.ResponsesComputerToolCallAction != nil { + t.Error("expected ResponsesComputerToolCallAction to be nil") + } + if action.ResponsesWebSearchToolCallAction != nil { + t.Error("expected ResponsesWebSearchToolCallAction to be nil") + } + if action.ResponsesLocalShellToolCallAction != nil { + t.Error("expected ResponsesLocalShellToolCallAction to be nil") + } + }) + } +} + +func TestResponsesToolMessageActionStruct_EdgeCases(t *testing.T) { + t.Run("empty action struct - marshal should error", func(t *testing.T) { + action := 
schemas.ResponsesToolMessageActionStruct{} + _, err := json.Marshal(action) + if err == nil { + t.Error("expected error when marshaling empty action struct") + } + }) + + t.Run("unknown action type - unmarshal to computer tool (default)", func(t *testing.T) { + jsonData := `{"type":"unknown_action"}` + var action schemas.ResponsesToolMessageActionStruct + if err := json.Unmarshal([]byte(jsonData), &action); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + // Default behavior is to unmarshal to computer tool + if action.ResponsesComputerToolCallAction == nil { + t.Error("expected ResponsesComputerToolCallAction to be populated for unknown type") + } + }) + + t.Run("round trip - computer action", func(t *testing.T) { + original := schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: &schemas.ResponsesComputerToolCallAction{ + Type: "click", + X: schemas.Ptr(150), + Y: schemas.Ptr(250), + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var unmarshaled schemas.ResponsesToolMessageActionStruct + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if unmarshaled.ResponsesComputerToolCallAction == nil { + t.Fatal("expected ResponsesComputerToolCallAction to be populated") + } + if unmarshaled.ResponsesComputerToolCallAction.Type != "click" { + t.Errorf("type mismatch: expected click, got %s", unmarshaled.ResponsesComputerToolCallAction.Type) + } + if unmarshaled.ResponsesComputerToolCallAction.X == nil || *unmarshaled.ResponsesComputerToolCallAction.X != 150 { + t.Errorf("X coordinate mismatch") + } + }) +} + +// ============================================================================= +// ResponsesTool Marshal/Unmarshal Tests +// ============================================================================= + +func TestResponsesTool_MarshalUnmarshal_FunctionTool(t *testing.T) { + tests := []struct { + name string 
+ tool schemas.ResponsesTool + jsonData string + }{ + { + name: "function tool with name and description", + tool: schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr("get_weather"), + Description: schemas.Ptr("Get the current weather"), + ResponsesToolFunction: &schemas.ResponsesToolFunction{ + Strict: schemas.Ptr(true), + }, + }, + jsonData: `{"type":"function","name":"get_weather","description":"Get the current weather","strict":true}`, + }, + { + name: "function tool with cache control", + tool: schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr("search_db"), + Description: schemas.Ptr("Search database"), + CacheControl: &schemas.CacheControl{ + Type: "ephemeral", + }, + ResponsesToolFunction: &schemas.ResponsesToolFunction{ + Strict: schemas.Ptr(false), + }, + }, + jsonData: `{"type":"function","name":"search_db","description":"Search database","cache_control":{"type":"ephemeral"},"strict":false}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name+" - marshal", func(t *testing.T) { + data, err := json.Marshal(tt.tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(tt.jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", tt.jsonData, string(data)) + } + }) + + t.Run(tt.name+" - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(tt.jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeFunction { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeFunction, tool.Type) + } + + if 
tool.ResponsesToolFunction == nil { + t.Fatal("expected ResponsesToolFunction to be populated") + } + + if tool.Name == nil || *tool.Name != *tt.tool.Name { + t.Error("name mismatch") + } + if tool.Description == nil || *tool.Description != *tt.tool.Description { + t.Error("description mismatch") + } + }) + } +} + +func TestResponsesTool_MarshalUnmarshal_FileSearchTool(t *testing.T) { + jsonData := `{"type":"file_search","vector_store_ids":null}` + + t.Run("file search tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeFileSearch, + ResponsesToolFileSearch: &schemas.ResponsesToolFileSearch{}, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("file search tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeFileSearch { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeFileSearch, tool.Type) + } + + if tool.ResponsesToolFileSearch == nil { + t.Fatal("expected ResponsesToolFileSearch to be populated") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_ComputerUseTool(t *testing.T) { + jsonData := `{"type":"computer_use_preview","display_height":1080,"display_width":1920,"environment":"browser"}` + + t.Run("computer use preview tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: 
schemas.ResponsesToolTypeComputerUsePreview, + ResponsesToolComputerUsePreview: &schemas.ResponsesToolComputerUsePreview{ + DisplayWidth: 1920, + DisplayHeight: 1080, + Environment: "browser", + }, + } + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("computer use preview tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeComputerUsePreview { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeComputerUsePreview, tool.Type) + } + + if tool.ResponsesToolComputerUsePreview == nil { + t.Fatal("expected ResponsesToolComputerUsePreview to be populated") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_WebSearchTool(t *testing.T) { + jsonData := `{"type":"web_search","search_context_size":"medium"}` + + t.Run("web search tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + SearchContextSize: schemas.Ptr("medium"), + }, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed 
to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("web search tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeWebSearch { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeWebSearch, tool.Type) + } + + if tool.ResponsesToolWebSearch == nil { + t.Fatal("expected ResponsesToolWebSearch to be populated") + } + + if tool.ResponsesToolWebSearch.SearchContextSize == nil || *tool.ResponsesToolWebSearch.SearchContextSize != "medium" { + t.Error("search_context_size mismatch") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_MCPTool(t *testing.T) { + jsonData := `{"type":"mcp","name":"test_mcp_tool","server_label":"mcp-server-1"}` + + t.Run("mcp tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeMCP, + Name: schemas.Ptr("test_mcp_tool"), + ResponsesToolMCP: &schemas.ResponsesToolMCP{ + ServerLabel: "mcp-server-1", + }, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("mcp tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != 
schemas.ResponsesToolTypeMCP { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeMCP, tool.Type) + } + + if tool.ResponsesToolMCP == nil { + t.Fatal("expected ResponsesToolMCP to be populated") + } + + if tool.ResponsesToolMCP.ServerLabel != "mcp-server-1" { + t.Error("server_label mismatch") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_CodeInterpreterTool(t *testing.T) { + jsonData := `{"type":"code_interpreter","container":null}` + + t.Run("code interpreter tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeCodeInterpreter, + ResponsesToolCodeInterpreter: &schemas.ResponsesToolCodeInterpreter{}, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("code interpreter tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeCodeInterpreter { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeCodeInterpreter, tool.Type) + } + + if tool.ResponsesToolCodeInterpreter == nil { + t.Fatal("expected ResponsesToolCodeInterpreter to be populated") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_ImageGenerationTool(t *testing.T) { + jsonData := `{"type":"image_generation"}` + + t.Run("image generation tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: 
schemas.ResponsesToolTypeImageGeneration, + ResponsesToolImageGeneration: &schemas.ResponsesToolImageGeneration{}, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("image generation tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeImageGeneration { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeImageGeneration, tool.Type) + } + + if tool.ResponsesToolImageGeneration == nil { + t.Fatal("expected ResponsesToolImageGeneration to be populated") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_LocalShellTool(t *testing.T) { + jsonData := `{"type":"local_shell"}` + + t.Run("local shell tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeLocalShell, + ResponsesToolLocalShell: &schemas.ResponsesToolLocalShell{}, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + 
} + }) + + t.Run("local shell tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeLocalShell { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeLocalShell, tool.Type) + } + + if tool.ResponsesToolLocalShell == nil { + t.Fatal("expected ResponsesToolLocalShell to be populated") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_CustomTool(t *testing.T) { + jsonData := `{"type":"custom","name":"custom_tool","description":"A custom tool"}` + + t.Run("custom tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeCustom, + Name: schemas.Ptr("custom_tool"), + Description: schemas.Ptr("A custom tool"), + ResponsesToolCustom: &schemas.ResponsesToolCustom{}, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("custom tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeCustom { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeCustom, tool.Type) + } + + if tool.ResponsesToolCustom == nil { + t.Fatal("expected ResponsesToolCustom to be populated") + } + + if tool.Name == nil || *tool.Name != "custom_tool" { + t.Error("name mismatch") + } + if 
tool.Description == nil || *tool.Description != "A custom tool" { + t.Error("description mismatch") + } + }) +} + +func TestResponsesTool_MarshalUnmarshal_WebSearchPreviewTool(t *testing.T) { + jsonData := `{"type":"web_search_preview","search_context_size":"high"}` + + t.Run("web search preview tool - marshal", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearchPreview, + ResponsesToolWebSearchPreview: &schemas.ResponsesToolWebSearchPreview{ + SearchContextSize: schemas.Ptr("high"), + }, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var expected, actual map[string]interface{} + if err := json.Unmarshal([]byte(jsonData), &expected); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + if err := json.Unmarshal(data, &actual); err != nil { + t.Fatalf("failed to unmarshal actual JSON: %v", err) + } + + if !mapsEqual(expected, actual) { + t.Errorf("marshaled JSON mismatch\nexpected: %s\nactual: %s", jsonData, string(data)) + } + }) + + t.Run("web search preview tool - unmarshal", func(t *testing.T) { + var tool schemas.ResponsesTool + if err := json.Unmarshal([]byte(jsonData), &tool); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if tool.Type != schemas.ResponsesToolTypeWebSearchPreview { + t.Errorf("type mismatch: expected %s, got %s", schemas.ResponsesToolTypeWebSearchPreview, tool.Type) + } + + if tool.ResponsesToolWebSearchPreview == nil { + t.Fatal("expected ResponsesToolWebSearchPreview to be populated") + } + }) +} + +func TestResponsesTool_EdgeCases(t *testing.T) { + t.Run("missing type field - unmarshal should error", func(t *testing.T) { + jsonData := `{"name":"test"}` + var tool schemas.ResponsesTool + err := json.Unmarshal([]byte(jsonData), &tool) + if err == nil { + t.Error("expected error when unmarshaling tool without type field") + } + }) + + t.Run("round trip - function tool with all fields", func(t 
*testing.T) { + original := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr("get_weather"), + Description: schemas.Ptr("Get weather info"), + CacheControl: &schemas.CacheControl{ + Type: "ephemeral", + }, + ResponsesToolFunction: &schemas.ResponsesToolFunction{ + Strict: schemas.Ptr(true), + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var unmarshaled schemas.ResponsesTool + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if unmarshaled.Type != schemas.ResponsesToolTypeFunction { + t.Error("type mismatch") + } + if unmarshaled.Name == nil || *unmarshaled.Name != "get_weather" { + t.Error("name mismatch") + } + if unmarshaled.Description == nil || *unmarshaled.Description != "Get weather info" { + t.Error("description mismatch") + } + if unmarshaled.CacheControl == nil || unmarshaled.CacheControl.Type != "ephemeral" { + t.Error("cache_control mismatch") + } + if unmarshaled.ResponsesToolFunction == nil || unmarshaled.ResponsesToolFunction.Strict == nil || !*unmarshaled.ResponsesToolFunction.Strict { + t.Error("strict field mismatch") + } + }) + + t.Run("round trip - web search tool with user location", func(t *testing.T) { + original := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + SearchContextSize: schemas.Ptr("medium"), + UserLocation: &schemas.ResponsesToolWebSearchUserLocation{ + City: schemas.Ptr("San Francisco"), + Country: schemas.Ptr("US"), + Timezone: schemas.Ptr("America/Los_Angeles"), + }, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var unmarshaled schemas.ResponsesTool + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if unmarshaled.ResponsesToolWebSearch == nil { + t.Fatal("expected 
ResponsesToolWebSearch to be populated") + } + if unmarshaled.ResponsesToolWebSearch.UserLocation == nil { + t.Fatal("expected UserLocation to be populated") + } + if unmarshaled.ResponsesToolWebSearch.UserLocation.City == nil || *unmarshaled.ResponsesToolWebSearch.UserLocation.City != "San Francisco" { + t.Error("city mismatch") + } + }) + + t.Run("nil embedded struct - should marshal type only", func(t *testing.T) { + tool := schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr("test"), + // ResponsesToolFunction is nil + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + + if result["type"] != "function" { + t.Error("type mismatch") + } + if result["name"] != "test" { + t.Error("name mismatch") + } + }) +} + +// ============================================================================= +// Helper Functions +// ============================================================================= + +// mapsEqual compares two maps for equality (including nested maps and arrays) +func mapsEqual(a, b map[string]interface{}) bool { + if len(a) != len(b) { + return false + } + + for k, v1 := range a { + v2, ok := b[k] + if !ok { + return false + } + + if !valuesEqual(v1, v2) { + return false + } + } + + return true +} + +// valuesEqual compares two values for equality (handles nested structures) +func valuesEqual(v1, v2 interface{}) bool { + switch val1 := v1.(type) { + case map[string]interface{}: + val2, ok := v2.(map[string]interface{}) + if !ok { + return false + } + return mapsEqual(val1, val2) + + case []interface{}: + val2, ok := v2.([]interface{}) + if !ok { + return false + } + if len(val1) != len(val2) { + return false + } + for i := range val1 { + if !valuesEqual(val1[i], val2[i]) { + return false + } + } + return true + + default: + // 
For primitives, use direct comparison + return v1 == v2 + } +} diff --git a/core/schemas/responses.go b/core/schemas/responses.go index 97ede3ee7..3e9de964e 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -538,22 +538,6 @@ func (action *ResponsesToolMessageActionStruct) UnmarshalJSON(data []byte) error // Based on the type, unmarshal into the appropriate variant switch typeStruct.Type { - case "click", "double_click", "drag", "keypress", "move", "screenshot", "scroll", "type", "wait", "zoom": - var computerToolCallAction ResponsesComputerToolCallAction - if err := Unmarshal(data, &computerToolCallAction); err != nil { - return fmt.Errorf("failed to unmarshal computer tool call action: %w", err) - } - action.ResponsesComputerToolCallAction = &computerToolCallAction - return nil - - case "search", "open_page", "find": - var webSearchToolCallAction ResponsesWebSearchToolCallAction - if err := Unmarshal(data, &webSearchToolCallAction); err != nil { - return fmt.Errorf("failed to unmarshal web search tool call action: %w", err) - } - action.ResponsesWebSearchToolCallAction = &webSearchToolCallAction - return nil - case "exec": var localShellToolCallAction ResponsesLocalShellToolCallAction if err := Unmarshal(data, &localShellToolCallAction); err != nil { @@ -570,9 +554,31 @@ func (action *ResponsesToolMessageActionStruct) UnmarshalJSON(data []byte) error action.ResponsesMCPApprovalRequestAction = &mcpApprovalRequestAction return nil + case "search", "open_page", "find": + var webSearchToolCallAction ResponsesWebSearchToolCallAction + if err := Unmarshal(data, &webSearchToolCallAction); err != nil { + return fmt.Errorf("failed to unmarshal web search tool call action: %w", err) + } + action.ResponsesWebSearchToolCallAction = &webSearchToolCallAction + return nil + + case "click", "double_click", "drag", "keypress", "move", "screenshot", "scroll", "type", "wait", "zoom": + var computerToolCallAction ResponsesComputerToolCallAction + if err :=
Unmarshal(data, &computerToolCallAction); err != nil { + return fmt.Errorf("failed to unmarshal computer tool call action: %w", err) + } + action.ResponsesComputerToolCallAction = &computerToolCallAction + return nil + default: - return fmt.Errorf("unknown action type: %s", typeStruct.Type) + // Unrecognized type: fall back to the computer tool call action, since it can carry many possible action types. + var computerToolCallAction ResponsesComputerToolCallAction + if err := Unmarshal(data, &computerToolCallAction); err != nil { + return fmt.Errorf("failed to unmarshal computer tool call action: %w", err) + } + action.ResponsesComputerToolCallAction = &computerToolCallAction + return nil } } type ResponsesToolMessageOutputStruct struct { diff --git a/tests/integrations/python/tests/test_openai.py b/tests/integrations/python/tests/test_openai.py index 5b0f9dc8c..37f093579 100644 --- a/tests/integrations/python/tests/test_openai.py +++ b/tests/integrations/python/tests/test_openai.py @@ -582,8 +582,7 @@ def test_13_streaming(self, test_config, provider, model, vk_enabled): model=format_provider_model(provider, model), messages=STREAMING_CHAT_MESSAGES, max_tokens=200, - stream=True, - extra_body={"reasoning": {"effort": "high"}} + stream=True ) content, chunk_count, tool_calls_detected = collect_streaming_content( @@ -3165,7 +3164,7 @@ def test_56_web_search_wildcard_domains(self, provider, model, vk_enabled): model=format_provider_model(provider, model), tools=[{ "type": "web_search", - "allowed_domains": ["wikipedia.org/*", "*.edu"] + "allowed_domains": ["wikipedia.org", "en.wikipedia.org"] }], input="What is machine learning use web search tool?", include=["web_search_call.action.sources"],