diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index 9f2da6f1c..40850ad4f 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -58,6 +58,7 @@ type TestScenarios struct { CountTokens bool // Count tokens functionality ChatAudio bool // Chat completion with audio input/output functionality StructuredOutputs bool // Structured outputs (JSON schema) functionality + WebSearchTool bool // Web search tool functionality } // ComprehensiveTestConfig extends TestConfig with additional scenarios diff --git a/core/internal/testutil/tests.go b/core/internal/testutil/tests.go index eb6c2dbe3..97a1d3681 100644 --- a/core/internal/testutil/tests.go +++ b/core/internal/testutil/tests.go @@ -33,6 +33,7 @@ func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context RunMultipleToolCallsTest, RunEnd2EndToolCallingTest, RunAutomaticFunctionCallingTest, + RunWebSearchToolTest, RunImageURLTest, RunImageBase64Test, RunMultipleImagesTest, @@ -105,6 +106,7 @@ func printTestSummary(t *testing.T, testConfig ComprehensiveTestConfig) { {"FileBase64", testConfig.Scenarios.FileBase64}, {"FileURL", testConfig.Scenarios.FileURL}, {"CompleteEnd2End", testConfig.Scenarios.CompleteEnd2End}, + {"WebSearchTool", testConfig.Scenarios.WebSearchTool}, {"SpeechSynthesis", testConfig.Scenarios.SpeechSynthesis}, {"SpeechSynthesisStream", testConfig.Scenarios.SpeechSynthesisStream}, {"Transcription", testConfig.Scenarios.Transcription}, diff --git a/core/internal/testutil/web_search_tool.go b/core/internal/testutil/web_search_tool.go new file mode 100644 index 000000000..9f3420df0 --- /dev/null +++ b/core/internal/testutil/web_search_tool.go @@ -0,0 +1,169 @@ +package testutil + +import ( + "context" + "os" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// This test verifies that the web search tool is properly invoked and returns results +func RunWebSearchToolTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.WebSearchTool { + t.Logf("Web search tool not supported for provider %s", testConfig.Provider) + return + } + + t.Run("WebSearchTool", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Create a simple query that should trigger web search + responsesMessages := []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("What is the current weather in New York City?"), + } + + // Create web search tool for Responses API + webSearchTool := &schemas.ResponsesTool{ + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + UserLocation: &schemas.ResponsesToolWebSearchUserLocation{ + Type: bifrost.Ptr("approximate"), + Country: bifrost.Ptr("US"), + City: bifrost.Ptr("New York"), + }, + }, + } + + // Use specialized web search retry configuration + retryConfig := WebSearchRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "WebSearchTool", + ExpectedBehavior: map[string]interface{}{ + "expected_tool_type": "web_search", + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + }, + } + + // Create expectations for web search + expectations := WebSearchExpectations() + + // Create operation for Responses API + responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + 
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + responsesReq := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: responsesMessages, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{*webSearchTool}, + }, + Fallbacks: testConfig.Fallbacks, + } + + return client.ResponsesRequest(bfCtx, responsesReq) + } + + // Execute test with retry - Responses API only for web search + response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchTool", responsesOperation) + + // Validate success + if err != nil { + t.Fatalf("❌ WebSearchTool test failed: %s", GetErrorMessage(err)) + } + + require.NotNil(t, response, "Response should not be nil") + + // Validate web search was invoked + webSearchCallFound := false + hasTextResponse := false + + if response.Output != nil { + for _, output := range response.Output { + // Check for web_search_call + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall { + webSearchCallFound = true + t.Logf("✅ Found web_search_call in output") + + // Validate the search action + if output.ResponsesToolMessage != nil && output.ResponsesToolMessage.Action != nil { + action := output.ResponsesToolMessage.Action + if action.ResponsesWebSearchToolCallAction != nil { + query := action.ResponsesWebSearchToolCallAction.Query + if query != nil { + t.Logf("✅ Web search query: %s", *query) + } + + // Validate sources if present + if len(action.ResponsesWebSearchToolCallAction.Sources) > 0 { + t.Logf("✅ Found %d search result sources", len(action.ResponsesWebSearchToolCallAction.Sources)) + + // Log first few sources + for i, source := range action.ResponsesWebSearchToolCallAction.Sources { + if i >= 3 { + break + } + t.Logf(" Source %d: %s", i+1, source.URL) + } + } + } + } + } + + // Check for text response (message with actual answer) + if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeMessage { + if output.Content != nil && len(output.Content.ContentBlocks) > 0 { + for _, block := range output.Content.ContentBlocks { + if block.Text != nil && *block.Text != "" { + hasTextResponse = true + + // Check for citations + if block.ResponsesOutputMessageContentText != nil && len(block.ResponsesOutputMessageContentText.Annotations) > 0 { + t.Logf("✅ Found %d citations in response", len(block.ResponsesOutputMessageContentText.Annotations)) + } else { + t.Logf("✅ Found text response") + } + } + } + } + } + } + } + + require.True(t, webSearchCallFound, "Web search call should be present in response output") + require.True(t, hasTextResponse, "Response should contain text answer based on web search results") + + t.Logf("🎉 WebSearchTool test passed!") + }) +} + +// WebSearchRetryConfig returns specialized retry configuration for web search tests +func WebSearchRetryConfig() ResponsesRetryConfig { + return ResponsesRetryConfig{ + MaxAttempts: 5, + BaseDelay: 2 * time.Second, + MaxDelay: 10 * time.Second, + Conditions: []ResponsesRetryCondition{ + &ResponsesEmptyCondition{}, + &ResponsesGenericResponseCondition{}, + }, + OnRetry: func(attempt int, reason string, t *testing.T) { + t.Logf("🔄 Retrying web search test (attempt %d): %s", attempt, reason) + }, + } +} + +// WebSearchExpectations returns validation expectations for web search responses +func WebSearchExpectations() ResponseExpectations { + return ResponseExpectations{ + ShouldHaveContent: true, + } +} diff --git a/core/providers/anthropic/anthropic.go 
b/core/providers/anthropic/anthropic.go index 087041e5e..09baa1a3e 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -2093,7 +2093,7 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (any, error) { return ToAnthropicResponsesRequest(request) }, + func() (any, error) { return ToAnthropicResponsesRequest(ctx, request) }, provider.GetProviderKey(), ) if bifrostErr != nil { diff --git a/core/providers/anthropic/anthropic_test.go b/core/providers/anthropic/anthropic_test.go index 056580ca5..d8d1a9952 100644 --- a/core/providers/anthropic/anthropic_test.go +++ b/core/providers/anthropic/anthropic_test.go @@ -42,6 +42,7 @@ func TestAnthropic(t *testing.T) { MultipleToolCalls: true, End2EndToolCalling: true, AutomaticFunctionCall: true, + WebSearchTool: true, ImageURL: true, ImageBase64: true, MultipleImages: true, diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index 85f407569..89e511ab2 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -17,27 +17,33 @@ import ( // AnthropicResponsesStreamState tracks state during streaming conversion for responses API type AnthropicResponsesStreamState struct { - ChunkIndex *int // index of the chunk in the stream - AccumulatedJSON string // deltas of any event + ChunkIndex *int // index of the chunk in the stream (reused for computer AND web search) + AccumulatedJSON string // deltas of any event (reused for computer AND web search) // Computer tool accumulation ComputerToolID *string + // Web search tool accumulation (minimal fields) + WebSearchToolID *string // Tool ID of active web search + WebSearchOutputIndex *int // Output index for this search + WebSearchResult *AnthropicContentBlock // Result block when it arrives + // OpenAI Responses API mapping state - ContentIndexToOutputIndex map[int]int // Maps Anthropic content_index to OpenAI output_index - ToolArgumentBuffers map[int]string // Maps output_index to accumulated tool argument JSON - MCPCallOutputIndices map[int]bool // Tracks which output indices are MCP calls - ItemIDs map[int]string // Maps output_index to item ID for stable IDs - ReasoningSignatures map[int]string // Maps output_index to reasoning signature - TextContentIndices map[int]bool // Tracks which content indices are text blocks - ReasoningContentIndices map[int]bool // Tracks which content indices are reasoning blocks - CurrentOutputIndex int // Current output index counter - MessageID *string // Message ID from message_start - Model *string // Model name from message_start - StopReason *string // Stop reason for the message - CreatedAt int // Timestamp for created_at consistency - HasEmittedCreated bool // Whether we've emitted response.created - HasEmittedInProgress bool // Whether we've emitted response.in_progress + ContentIndexToOutputIndex map[int]int // Maps Anthropic content_index to OpenAI output_index + ContentIndexToBlockType map[int]AnthropicContentBlockType // Tracks content block types + ToolArgumentBuffers map[int]string // Maps output_index to accumulated tool argument JSON + MCPCallOutputIndices map[int]bool // Tracks which output indices are MCP calls + ItemIDs map[int]string // Maps output_index to item ID for stable IDs + ReasoningSignatures map[int]string // Maps output_index to reasoning signature + TextContentIndices map[int]bool // Tracks which content 
indices are text blocks + ReasoningContentIndices map[int]bool // Tracks which content indices are reasoning blocks + CurrentOutputIndex int // Current output index counter + MessageID *string // Message ID from message_start + Model *string // Model name from message_start + StopReason *string // Stop reason for the message + CreatedAt int // Timestamp for created_at consistency + HasEmittedCreated bool // Whether we've emitted response.created + HasEmittedInProgress bool // Whether we've emitted response.in_progress } // anthropicResponsesStreamStatePool provides a pool for Anthropic responses stream state objects. @@ -73,6 +79,11 @@ func acquireAnthropicResponsesStreamState() *AnthropicResponsesStreamState { } else { clear(state.ContentIndexToOutputIndex) } + if state.ContentIndexToBlockType == nil { + state.ContentIndexToBlockType = make(map[int]AnthropicContentBlockType) + } else { + clear(state.ContentIndexToBlockType) + } if state.ToolArgumentBuffers == nil { state.ToolArgumentBuffers = make(map[int]string) } else { @@ -107,6 +118,9 @@ func acquireAnthropicResponsesStreamState() *AnthropicResponsesStreamState { state.ChunkIndex = nil state.AccumulatedJSON = "" state.ComputerToolID = nil + state.WebSearchToolID = nil + state.WebSearchOutputIndex = nil + state.WebSearchResult = nil state.CurrentOutputIndex = 0 state.MessageID = nil state.StopReason = nil @@ -130,7 +144,11 @@ func (state *AnthropicResponsesStreamState) flush() { state.ChunkIndex = nil state.AccumulatedJSON = "" state.ComputerToolID = nil + state.WebSearchToolID = nil + state.WebSearchOutputIndex = nil + state.WebSearchResult = nil state.ContentIndexToOutputIndex = make(map[int]int) + state.ContentIndexToBlockType = make(map[int]AnthropicContentBlockType) state.ToolArgumentBuffers = make(map[int]string) state.MCPCallOutputIndices = make(map[int]bool) state.ItemIDs = make(map[int]string) @@ -252,6 +270,99 @@ func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, }}, nil, false } + // Handle web_search server_tool_use (query block) + if chunk.ContentBlock.Type == AnthropicContentBlockTypeServerToolUse && + chunk.ContentBlock.Name != nil && + *chunk.ContentBlock.Name == string(AnthropicToolNameWebSearch) && + chunk.ContentBlock.ID != nil { + + // Start accumulating web search query (reuse shared accumulation fields) + state.ChunkIndex = chunk.Index + state.AccumulatedJSON = "" + state.WebSearchToolID = chunk.ContentBlock.ID + // Store output index value (allocate new int to avoid pointer-to-local-variable issue) + state.WebSearchOutputIndex = schemas.Ptr(outputIndex) + + // Store item ID + state.ItemIDs[outputIndex] = *chunk.ContentBlock.ID + + // Emit output_item.added for web_search_call + item := &schemas.ResponsesMessage{ + ID: chunk.ContentBlock.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeWebSearchCall), + Status: schemas.Ptr("in_progress"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: chunk.ContentBlock.ID, + Action: &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + }, + }, + }, + } + + var responses []*schemas.BifrostResponsesStreamResponse + + // Emit output_item.added + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }) + + // Emit web_search_call.in_progress + responses 
= append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeWebSearchCallInProgress, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ItemID: chunk.ContentBlock.ID, + }) + + // Emit web_search_call.searching + responses = append(responses, &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeWebSearchCallSearching, + SequenceNumber: sequenceNumber + len(responses), + OutputIndex: schemas.Ptr(outputIndex), + ItemID: chunk.ContentBlock.ID, + }) + + return responses, nil, false + } + + // Handle web_search_tool_result block (results arrive) + if chunk.ContentBlock.Type == AnthropicContentBlockTypeWebSearchToolResult && + chunk.ContentBlock.ToolUseID != nil { + + // Track that this content index is a web search result block + if chunk.Index != nil { + state.ContentIndexToBlockType[*chunk.Index] = AnthropicContentBlockTypeWebSearchToolResult + } + + // Check if this matches our active web search + if state.WebSearchToolID != nil && *state.WebSearchToolID == *chunk.ContentBlock.ToolUseID { + + // Store the result block (arrives complete with all sources) + state.WebSearchResult = chunk.ContentBlock + + if chunk.Index != nil { + delete(state.ContentIndexToBlockType, *chunk.Index) + } + + // Emit web_search_call.completed + return []*schemas.BifrostResponsesStreamResponse{{ + Type: schemas.ResponsesStreamResponseTypeWebSearchCallCompleted, + SequenceNumber: sequenceNumber, + OutputIndex: state.WebSearchOutputIndex, + ItemID: chunk.ContentBlock.ToolUseID, + }}, nil, false + } + + // If no matching tool ID, skip (shouldn't happen in normal flow) + return nil, nil, false + } + switch chunk.ContentBlock.Type { case AnthropicContentBlockTypeText: // Text block - emit output_item.added with type "message" @@ -480,10 +591,9 @@ func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, case AnthropicStreamDeltaTypeInputJSON: // Function call arguments delta if chunk.Delta.PartialJSON != nil { - // Check if we're accumulating a computer tool - if state.ComputerToolID != nil && - state.ChunkIndex != nil && - *state.ChunkIndex == *chunk.Index { + // Check if we're accumulating any tool (computer or web search) + // Both use the shared ChunkIndex and AccumulatedJSON fields + if state.ChunkIndex != nil && *state.ChunkIndex == *chunk.Index { // Accumulate the JSON and don't emit anything state.AccumulatedJSON += *chunk.Delta.PartialJSON return nil, nil, false @@ -554,6 +664,28 @@ func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, return []*schemas.BifrostResponsesStreamResponse{response}, nil, false } return nil, nil, false + + case AnthropicStreamDeltaTypeCitations: + // Handle citations delta - convert Anthropic citation to OpenAI annotation + if chunk.Delta.Citation != nil { + // For streaming, we don't compute indices yet (pass empty string) + annotation := convertAnthropicCitationToAnnotation(*chunk.Delta.Citation, "") + + // Emit output_text.annotation.added event + itemID := state.ItemIDs[outputIndex] + response := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputTextAnnotationAdded, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Annotation: &annotation, + } + if itemID != "" { + response.ItemID = &itemID + } + return []*schemas.BifrostResponsesStreamResponse{response}, nil, false + } + return nil, nil, false } } @@ -562,59 +694,160 @@ 
func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, if chunk.Index != nil { outputIndex := state.getOrCreateOutputIndex(chunk.Index) - // Check if this is the end of a computer tool accumulation - if state.ComputerToolID != nil && - state.ChunkIndex != nil && - *state.ChunkIndex == *chunk.Index { + // Check if this is the end of a tool accumulation (computer or web search query) + if state.ChunkIndex != nil && *state.ChunkIndex == *chunk.Index { + + // Computer tool completion + if state.ComputerToolID != nil { + // Parse accumulated JSON and convert to OpenAI format + var inputMap map[string]interface{} + var action *schemas.ResponsesComputerToolCallAction + + if state.AccumulatedJSON != "" { + if err := json.Unmarshal([]byte(state.AccumulatedJSON), &inputMap); err == nil { + action = convertAnthropicToResponsesComputerAction(inputMap) + } + } + + // Create computer_call item with action + statusCompleted := "completed" + item := &schemas.ResponsesMessage{ + ID: state.ComputerToolID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeComputerCall), + Status: &statusCompleted, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: state.ComputerToolID, + ResponsesComputerToolCall: &schemas.ResponsesComputerToolCall{ + PendingSafetyChecks: []schemas.ResponsesComputerToolCallPendingSafetyCheck{}, + }, + }, + } + + // Add action if we successfully parsed it + if action != nil { + item.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesComputerToolCallAction: action, + } + } - // Parse accumulated JSON and convert to OpenAI format - var inputMap map[string]interface{} - var action *schemas.ResponsesComputerToolCallAction + // Clear computer tool state + state.ComputerToolID = nil + state.ChunkIndex = nil + state.AccumulatedJSON = "" + + // Return output_item.done + return []*schemas.BifrostResponsesStreamResponse{ + { + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + SequenceNumber: sequenceNumber, + OutputIndex: schemas.Ptr(outputIndex), + ContentIndex: chunk.Index, + Item: item, + }, + }, nil, false + } + // Web search query block ended (don't emit output_item.done yet - wait for result) + if state.WebSearchToolID != nil { + // Clear ChunkIndex (done accumulating query) + // Keep WebSearchToolID, WebSearchOutputIndex, and AccumulatedJSON (need them for final item) + state.ChunkIndex = nil + return nil, nil, false + } + } + + // Check if this is the end of a web_search_tool_result block + if state.WebSearchResult != nil && state.WebSearchToolID != nil { + + // Parse the query from AccumulatedJSON + var query string + var queries []string if state.AccumulatedJSON != "" { + var inputMap map[string]interface{} if err := json.Unmarshal([]byte(state.AccumulatedJSON), &inputMap); err == nil { - action = convertAnthropicToResponsesComputerAction(inputMap) + if q, ok := inputMap["query"].(string); ok { + query = q + queries = []string{q} + } } } - // Create computer_call item with action + // Extract sources from the result block + var sources []schemas.ResponsesWebSearchToolCallActionSearchSource + if state.WebSearchResult.Content != nil && len(state.WebSearchResult.Content.ContentBlocks) > 0 { + for _, resultBlock := range state.WebSearchResult.Content.ContentBlocks { + if resultBlock.Type == AnthropicContentBlockTypeWebSearchResult && resultBlock.URL != nil { + sources = append(sources, schemas.ResponsesWebSearchToolCallActionSearchSource{ + Type: "url", + URL: *resultBlock.URL, + Title: resultBlock.Title, + 
EncryptedContent: resultBlock.EncryptedContent, + PageAge: resultBlock.PageAge, + }) + } + } + } + + // Create complete web_search_call item with action including query and sources statusCompleted := "completed" + action := &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + Sources: sources, + } + // Only set query fields if query is not empty + if query != "" { + action.Query = &query + action.Queries = queries + } + item := &schemas.ResponsesMessage{ - ID: state.ComputerToolID, - Type: schemas.Ptr(schemas.ResponsesMessageTypeComputerCall), + ID: state.WebSearchToolID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeWebSearchCall), Status: &statusCompleted, ResponsesToolMessage: &schemas.ResponsesToolMessage{ - CallID: state.ComputerToolID, - ResponsesComputerToolCall: &schemas.ResponsesComputerToolCall{ - PendingSafetyChecks: []schemas.ResponsesComputerToolCallPendingSafetyCheck{}, + CallID: state.WebSearchToolID, + Action: &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: action, }, }, } - // Add action if we successfully parsed it - if action != nil { - item.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesComputerToolCallAction: action, - } - } + outputIdx := state.WebSearchOutputIndex - // Clear computer tool state - state.ComputerToolID = nil - state.ChunkIndex = nil + // Clear all web search state + state.WebSearchToolID = nil + state.WebSearchOutputIndex = nil + state.WebSearchResult = nil state.AccumulatedJSON = "" - // Return output_item.done + if chunk.Index != nil { + delete(state.ContentIndexToBlockType, *chunk.Index) + } + + // Return output_item.done for the web_search_call (not the result block) return []*schemas.BifrostResponsesStreamResponse{ { Type: schemas.ResponsesStreamResponseTypeOutputItemDone, SequenceNumber: sequenceNumber, - OutputIndex: schemas.Ptr(outputIndex), + OutputIndex: outputIdx, ContentIndex: chunk.Index, Item: item, }, }, nil, false } + // Skip generic output_item.done if this is a web_search_tool_result block + // (the web search handler above already emitted the proper done event) + if chunk.Index != nil { + if blockType, exists := state.ContentIndexToBlockType[*chunk.Index]; exists { + if blockType == AnthropicContentBlockTypeWebSearchToolResult { + // Clean up the tracking + delete(state.ContentIndexToBlockType, *chunk.Index) + return nil, nil, false + } + } + } + // Check if this is a text block - emit output_text.done and content_part.done var responses []*schemas.BifrostResponsesStreamResponse itemID := state.ItemIDs[outputIndex] @@ -844,7 +1077,7 @@ func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, } // ToAnthropicResponsesStreamResponse converts a Bifrost Responses stream response to Anthropic SSE string format -func ToAnthropicResponsesStreamResponse(ctx context.Context, bifrostResp *schemas.BifrostResponsesStreamResponse) []*AnthropicStreamEvent { +func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp *schemas.BifrostResponsesStreamResponse) []*AnthropicStreamEvent { if bifrostResp == nil { return nil } @@ -916,6 +1149,30 @@ func ToAnthropicResponsesStreamResponse(ctx context.Context, bifrostResp *schema // Always start with empty input for streaming compatibility contentBlock.Input = map[string]interface{}{} + streamResp.ContentBlock = contentBlock + } else if bifrostResp.Item != nil && + bifrostResp.Item.Type != nil && + *bifrostResp.Item.Type == schemas.ResponsesMessageTypeWebSearchCall { + + // Web search 
call - emit content_block_start with server_tool_use + streamResp.Type = AnthropicStreamEventTypeContentBlockStart + + if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } + + // Build the content_block as server_tool_use + contentBlock := &AnthropicContentBlock{ + Type: AnthropicContentBlockTypeServerToolUse, + ID: bifrostResp.Item.ID, // The tool use ID + Name: schemas.Ptr(string(AnthropicToolNameWebSearch)), // "web_search" + } + + // Start with empty input for streaming compatibility + contentBlock.Input = map[string]interface{}{} + streamResp.ContentBlock = contentBlock } else { // Text or other content blocks - emit content_block_start @@ -1049,35 +1306,24 @@ func ToAnthropicResponsesStreamResponse(ctx context.Context, bifrostResp *schema shouldGenerateDeltas = true } } - } - - // Sanitize websearch tool arguments to remove both allowed_domains and blocked_domains - // Anthropic only allows one or the other, not both - if shouldGenerateDeltas && argumentsJSON != "" { - // Check if this is a websearch tool - if bifrostResp.Item.ResponsesToolMessage.Name != nil && - *bifrostResp.Item.ResponsesToolMessage.Name == "WebSearch" { - // Parse the JSON to check for conflicting domain filters - var toolArgs map[string]interface{} - if err := json.Unmarshal([]byte(argumentsJSON), &toolArgs); err == nil { - _, hasAllowed := toolArgs["allowed_domains"] - _, hasBlocked := toolArgs["blocked_domains"] - - // If both domain filters exist, remove blocked_domains and keep allowed_domains - // This prioritizes the allowed list over the blocked list - if hasAllowed && hasBlocked { - delete(toolArgs, "blocked_domains") - - // Re-marshal the sanitized arguments - if sanitizedBytes, err := json.Marshal(toolArgs); err == nil { - argumentsJSON = string(sanitizedBytes) - } - } + case schemas.ResponsesMessageTypeWebSearchCall: + // Extract query from web search action + if bifrostResp.Item.ResponsesToolMessage.Action != nil && + bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction != nil && + bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Query != nil { + // Create input map with query + inputMap := map[string]interface{}{ + "query": *bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Query, + } + if jsonBytes, err := json.Marshal(inputMap); err == nil { + argumentsJSON = string(jsonBytes) + shouldGenerateDeltas = true } } + } + if shouldGenerateDeltas && argumentsJSON != "" { // Generate synthetic input_json_delta events by chunking the JSON - // Use OutputIndex for proper Anthropic indexing, fallback to ContentIndex var indexToUse *int if bifrostResp.OutputIndex != nil { indexToUse = bifrostResp.OutputIndex @@ -1162,67 +1408,77 @@ func ToAnthropicResponsesStreamResponse(ctx context.Context, bifrostResp *schema } } + case schemas.ResponsesStreamResponseTypeOutputTextAnnotationAdded: + // Convert OpenAI annotation to Anthropic citation + if bifrostResp.Annotation != nil { + streamResp.Type = AnthropicStreamEventTypeContentBlockDelta + if bifrostResp.OutputIndex != nil { + streamResp.Index = bifrostResp.OutputIndex + } else if bifrostResp.ContentIndex != nil { + streamResp.Index = bifrostResp.ContentIndex + } + + citation := convertAnnotationToAnthropicCitation(*bifrostResp.Annotation) + + streamResp.Delta = &AnthropicStreamDelta{ + Type: AnthropicStreamDeltaTypeCitations, + Citation: &citation, + } + } + 
case schemas.ResponsesStreamResponseTypeContentPartDone: return nil case schemas.ResponsesStreamResponseTypeOutputItemDone: // Handle WebSearch tool completion with sanitization and synthetic delta generation - if bifrostResp.Item != nil && - bifrostResp.Item.Type != nil && - *bifrostResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall && - bifrostResp.Item.ResponsesToolMessage != nil && - bifrostResp.Item.ResponsesToolMessage.Name != nil && - *bifrostResp.Item.ResponsesToolMessage.Name == "WebSearch" && - bifrostResp.Item.ResponsesToolMessage.Arguments != nil { - - argumentsJSON := *bifrostResp.Item.ResponsesToolMessage.Arguments - - // Parse the arguments JSON - var toolArgs map[string]interface{} - if err := json.Unmarshal([]byte(argumentsJSON), &toolArgs); err == nil { - _, hasAllowed := toolArgs["allowed_domains"] - _, hasBlocked := toolArgs["blocked_domains"] - - // If both domain filters exist, remove blocked_domains and keep allowed_domains - if hasAllowed && hasBlocked { - delete(toolArgs, "blocked_domains") - - // Re-marshal the sanitized arguments - if sanitizedBytes, err := json.Marshal(toolArgs); err == nil { - argumentsJSON = string(sanitizedBytes) + + // check for claude-cli user agent + if ctx != nil { + if userAgent, ok := ctx.Value(schemas.BifrostContextKeyUserAgent).(string); ok { + if strings.Contains(userAgent, "claude-cli") { + // check for WebSearch tool + if bifrostResp.Item != nil && + bifrostResp.Item.Type != nil && + *bifrostResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall && + bifrostResp.Item.ResponsesToolMessage != nil && + bifrostResp.Item.ResponsesToolMessage.Name != nil && + *bifrostResp.Item.ResponsesToolMessage.Name == "WebSearch" && + bifrostResp.Item.ResponsesToolMessage.Arguments != nil { + + argumentsJSON := sanitizeWebSearchArguments(*bifrostResp.Item.ResponsesToolMessage.Arguments) bifrostResp.Item.ResponsesToolMessage.Arguments = &argumentsJSON - } - } - } - // Generate synthetic input_json_delta events for the sanitized WebSearch arguments - // This replaces the delta events that were skipped earlier - var events []*AnthropicStreamEvent + // Generate synthetic input_json_delta events for the sanitized WebSearch arguments + // This replaces the delta events that were skipped earlier + var events []*AnthropicStreamEvent - // Use OutputIndex for proper Anthropic indexing, fallback to ContentIndex - var indexToUse *int - if bifrostResp.OutputIndex != nil { - indexToUse = bifrostResp.OutputIndex - } else if bifrostResp.ContentIndex != nil { - indexToUse = bifrostResp.ContentIndex - } + // Use OutputIndex for proper Anthropic indexing, fallback to ContentIndex + var indexToUse *int + if bifrostResp.OutputIndex != nil { + indexToUse = bifrostResp.OutputIndex + } else if bifrostResp.ContentIndex != nil { + indexToUse = bifrostResp.ContentIndex + } - deltaEvents := generateSyntheticInputJSONDeltas(argumentsJSON, indexToUse) - events = append(events, deltaEvents...) + deltaEvents := generateSyntheticInputJSONDeltas(argumentsJSON, indexToUse) + events = append(events, deltaEvents...) 
- // Add the content_block_stop event at the end - stopEvent := &AnthropicStreamEvent{ - Type: AnthropicStreamEventTypeContentBlockStop, - Index: indexToUse, - } - events = append(events, stopEvent) + // Add the content_block_stop event at the end + stopEvent := &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeContentBlockStop, + Index: indexToUse, + } + events = append(events, stopEvent) - // Clean up the tracking for this WebSearch item - if bifrostResp.Item.ID != nil { - webSearchItemIDs.Delete(*bifrostResp.Item.ID) - } + // Clean up the tracking for this WebSearch item + if bifrostResp.Item.ID != nil { + webSearchItemIDs.Delete(*bifrostResp.Item.ID) + } - return events + return events + } + } + } } if bifrostResp.Item != nil && @@ -1258,6 +1514,80 @@ func ToAnthropicResponsesStreamResponse(ctx context.Context, bifrostResp *schema } } } + } else if bifrostResp.Item != nil && + bifrostResp.Item.Type != nil && + *bifrostResp.Item.Type == schemas.ResponsesMessageTypeWebSearchCall { + + // Web search call complete - emit content_block_stop for query, then web_search_tool_result block + var events []*AnthropicStreamEvent + + // 1. Emit content_block_stop for the query block (server_tool_use) + stopEvent := &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeContentBlockStop, + } + if bifrostResp.ContentIndex != nil { + stopEvent.Index = bifrostResp.ContentIndex + } else if bifrostResp.OutputIndex != nil { + stopEvent.Index = bifrostResp.OutputIndex + } + events = append(events, stopEvent) + + // 2. Extract sources and create web_search_tool_result block if sources exist + if bifrostResp.Item.ResponsesToolMessage != nil && + bifrostResp.Item.ResponsesToolMessage.Action != nil && + bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction != nil && + len(bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Sources) > 0 { + + // Calculate next index for result block + var resultIndex *int + if bifrostResp.OutputIndex != nil { + nextIdx := *bifrostResp.OutputIndex + 1 + resultIndex = &nextIdx + } else if bifrostResp.ContentIndex != nil { + nextIdx := *bifrostResp.ContentIndex + 1 + resultIndex = &nextIdx + } + + // Create content blocks for each source + var resultContentBlocks []AnthropicContentBlock + for _, source := range bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Sources { + block := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeWebSearchResult, + URL: &source.URL, + EncryptedContent: source.EncryptedContent, + PageAge: source.PageAge, + } + if source.Title != nil { + block.Title = source.Title + } else if source.URL != "" { + block.Title = schemas.Ptr(source.URL) + } + resultContentBlocks = append(resultContentBlocks, block) + } + + // Emit content_block_start for web_search_tool_result + resultStartEvent := &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeContentBlockStart, + Index: resultIndex, + ContentBlock: &AnthropicContentBlock{ + Type: AnthropicContentBlockTypeWebSearchToolResult, + ToolUseID: bifrostResp.Item.ID, // Link to the server_tool_use block + Content: &AnthropicContent{ + ContentBlocks: resultContentBlocks, + }, + }, + } + events = append(events, resultStartEvent) + + // Emit content_block_stop for the result block + resultStopEvent := &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeContentBlockStop, + Index: resultIndex, + } + events = append(events, resultStopEvent) + } + + return events } else { // For text blocks and other content blocks, emit 
content_block_stop streamResp.Type = AnthropicStreamEventTypeContentBlockStop @@ -1268,6 +1598,13 @@ func ToAnthropicResponsesStreamResponse(ctx context.Context, bifrostResp *schema streamResp.Index = bifrostResp.ContentIndex } } + case schemas.ResponsesStreamResponseTypeWebSearchCallInProgress, + schemas.ResponsesStreamResponseTypeWebSearchCallSearching, + schemas.ResponsesStreamResponseTypeWebSearchCallCompleted: + // Web search lifecycle events - these are OpenAI-style events that don't have Anthropic equivalents + // Skip them to avoid cluttering the stream + return nil + case schemas.ResponsesStreamResponseTypePing: streamResp.Type = AnthropicStreamEventTypePing @@ -1465,6 +1802,8 @@ func (request *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx context.Co break } } + + params.Include = []string{"web_search_call.action.sources"} } bifrostReq.Params = params @@ -1520,7 +1859,7 @@ func (request *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx context.Co } // ToAnthropicResponsesRequest converts a BifrostRequest with Responses structure back to AnthropicMessageRequest -func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*AnthropicMessageRequest, error) { +func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostResponsesRequest) (*AnthropicMessageRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost request is nil") } @@ -1547,7 +1886,31 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (* } } if bifrostReq.Params.Text != nil { - anthropicReq.OutputFormat = convertResponsesTextConfigToAnthropicOutputFormat(bifrostReq.Params.Text) + // Citations cannot be used together with Structured Outputs in anthropic. + hasCitationsEnabled := false + // loop over input messages and check if any message has citations enabled + for _, message := range bifrostReq.Input { + if message.Content == nil || message.Content.ContentBlocks == nil { + continue + } + if message.Content.ContentBlocks != nil { + for _, block := range message.Content.ContentBlocks { + if block.Type == schemas.ResponsesInputMessageContentBlockTypeFile && + block.Citations != nil && + block.Citations.Enabled != nil && + *block.Citations.Enabled { + hasCitationsEnabled = true + break + } + } + } + if hasCitationsEnabled { + break + } + } + if !hasCitationsEnabled { + anthropicReq.OutputFormat = convertResponsesTextConfigToAnthropicOutputFormat(bifrostReq.Params.Text) + } } if bifrostReq.Params.Reasoning != nil { if bifrostReq.Params.Reasoning.MaxTokens != nil { @@ -1629,7 +1992,7 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (* } if bifrostReq.Input != nil { - anthropicMessages, systemContent := ConvertBifrostMessagesToAnthropicMessages(bifrostReq.Input) + anthropicMessages, systemContent := ConvertBifrostMessagesToAnthropicMessages(ctx, bifrostReq.Input) // Set system message if present if systemContent != nil { @@ -1710,7 +2073,7 @@ func (response *AnthropicMessageResponse) ToBifrostResponsesResponse() *schemas. 
} // ToAnthropicResponsesResponse converts a BifrostResponse with Responses structure back to AnthropicMessageResponse -func ToAnthropicResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *AnthropicMessageResponse { +func ToAnthropicResponsesResponse(ctx *schemas.BifrostContext, bifrostResp *schemas.BifrostResponsesResponse) *AnthropicMessageResponse { anthropicResp := &AnthropicMessageResponse{ Type: "message", Role: "assistant", @@ -1737,7 +2100,7 @@ func ToAnthropicResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) // Convert output messages to Anthropic content blocks using the new conversion method var contentBlocks []AnthropicContentBlock if bifrostResp.Output != nil { - anthropicMessages, _ := ConvertBifrostMessagesToAnthropicMessages(bifrostResp.Output) + anthropicMessages, _ := ConvertBifrostMessagesToAnthropicMessages(ctx, bifrostResp.Output) // Extract content blocks from the converted messages for _, msg := range anthropicMessages { if msg.Content.ContentBlocks != nil { @@ -1799,7 +2162,7 @@ func ConvertAnthropicMessagesToBifrostMessages(anthropicMessages []AnthropicMess // ConvertBifrostMessagesToAnthropicMessages converts an array of Bifrost ResponsesMessage to Anthropic message format // This is the main conversion method from Bifrost to Anthropic - handles all message types and returns messages + system content -func ConvertBifrostMessagesToAnthropicMessages(bifrostMessages []schemas.ResponsesMessage) ([]AnthropicMessage, *AnthropicContent) { +func ConvertBifrostMessagesToAnthropicMessages(ctx *schemas.BifrostContext, bifrostMessages []schemas.ResponsesMessage) ([]AnthropicMessage, *AnthropicContent) { var anthropicMessages []AnthropicMessage var systemContent *AnthropicContent var pendingToolCalls []AnthropicContentBlock @@ -1977,7 +2340,7 @@ func ConvertBifrostMessagesToAnthropicMessages(bifrostMessages []schemas.Respons pendingReasoningContentBlocks = nil } - toolUseBlock := convertBifrostFunctionCallToAnthropicToolUse(&msg) + toolUseBlock := convertBifrostFunctionCallToAnthropicToolUse(ctx, &msg) if toolUseBlock != nil { // If there was a previous assistant message (text only) that was just added, // and we have no pending tool calls yet, we should merge the tool call into it. @@ -2161,10 +2524,44 @@ func ConvertBifrostMessagesToAnthropicMessages(bifrostMessages []schemas.Respons } } + case schemas.ResponsesMessageTypeWebSearchCall: + // Flush any pending tool results before processing web search calls + flushPendingToolResults() + + // Web search calls need special handling: create server_tool_use + web_search_tool_result blocks + webSearchBlocks := convertBifrostWebSearchCallToAnthropicBlocks(&msg) + if len(webSearchBlocks) > 0 { + // For web search, we create both server_tool_use and web_search_tool_result + // These should appear in an assistant message + if currentAssistantMessage == nil { + currentAssistantMessage = &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + } + } + + // Prepend any pending reasoning blocks to ensure they come BEFORE tool blocks + if len(pendingReasoningContentBlocks) > 0 { + copied := make([]AnthropicContentBlock, len(pendingReasoningContentBlocks)) + copy(copied, pendingReasoningContentBlocks) + pendingToolCalls = append(copied, pendingToolCalls...) + pendingReasoningContentBlocks = nil + } + + // Add the web search blocks (server_tool_use + web_search_tool_result) + pendingToolCalls = append(pendingToolCalls, webSearchBlocks...) 
+ + // Track the tool call ID for the server_tool_use block (first block) + if len(webSearchBlocks) > 0 && webSearchBlocks[0].ID != nil { + if currentToolCallIDs == nil { + currentToolCallIDs = make(map[string]bool) + } + currentToolCallIDs[*webSearchBlocks[0].ID] = true + } + } + // Handle other tool call types that are not natively supported by Anthropic case schemas.ResponsesMessageTypeFileSearchCall, schemas.ResponsesMessageTypeCodeInterpreterCall, - schemas.ResponsesMessageTypeWebSearchCall, schemas.ResponsesMessageTypeLocalShellCall, schemas.ResponsesMessageTypeCustomToolCall, schemas.ResponsesMessageTypeImageGenerationCall: @@ -2458,6 +2855,13 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant } } + case AnthropicContentBlockTypeServerToolUse: + // Accumulate server tool use blocks + if block.ID != nil && block.Name != nil { + blockCopy := block + pendingToolUseBlocks = append(pendingToolUseBlocks, &blockCopy) + } + case AnthropicContentBlockTypeMCPToolUse: // Accumulate MCP tool use blocks if block.ID != nil && block.Name != nil { @@ -2468,6 +2872,12 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant case AnthropicContentBlockTypeMCPToolResult: // Handle MCP tool results directly without flushing other blocks // MCP results will be emitted as separate messages + + case AnthropicContentBlockTypeWebSearchResult: + // Find the corresponding web_search_call by tool_use_id and attach sources + if block.ToolUseID != nil { + attachWebSearchSourcesToCall(bifrostMessages, *block.ToolUseID, block, true) + } } } @@ -2510,6 +2920,21 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant ResponsesComputerToolCallAction: convertAnthropicToResponsesComputerAction(inputMap), } } + } else if toolBlock.Name != nil && *toolBlock.Name == string(AnthropicToolNameWebSearch) { + // Handle web_search tool + bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeWebSearchCall) + bifrostMsg.ResponsesToolMessage.Name = nil + if inputMap, ok := toolBlock.Input.(map[string]interface{}); ok { + if query, ok := inputMap["query"].(string); ok { + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + Query: schemas.Ptr(query), + Queries: []string{query}, // Anthropic uses single query + }, + } + } + } } else { bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(schemas.JsonifyInput(toolBlock.Input)) } @@ -2534,18 +2959,34 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC var bifrostMsg schemas.ResponsesMessage if isOutputMessage { // For output messages, use ContentBlocks with ResponsesOutputMessageContentTypeText + contentBlock := schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: block.Text, + CacheControl: block.CacheControl, + } + + // Convert Anthropic citations to OpenAI annotations + if block.Citations != nil && len(block.Citations.TextCitations) > 0 { + annotations := make([]schemas.ResponsesOutputMessageContentTextAnnotation, len(block.Citations.TextCitations)) + fullText := "" + if block.Text != nil { + fullText = *block.Text + } + for i, citation := range block.Citations.TextCitations { + annotations[i] = convertAnthropicCitationToAnnotation(citation, fullText) + } + + contentBlock.ResponsesOutputMessageContentText = &schemas.ResponsesOutputMessageContentText{ + Annotations: annotations, + } + 
} + bifrostMsg = schemas.ResponsesMessage{ ID: schemas.Ptr("msg_" + providerUtils.GetRandomString(50)), Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: role, Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: block.Text, - CacheControl: block.CacheControl, - }, - }, + ContentBlocks: []schemas.ResponsesMessageContentBlock{contentBlock}, }, } } else { @@ -2689,6 +3130,42 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC bifrostMessages = append(bifrostMessages, bifrostMsg) } } + + case AnthropicContentBlockTypeServerToolUse: + // Check if it's a web_search tool + if block.Name != nil && *block.Name == string(AnthropicToolNameWebSearch) { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeWebSearchCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{}, + } + + // Extract query from input + if block.Input != nil { + if inputMap, ok := block.Input.(map[string]interface{}); ok { + if query, ok := inputMap["query"].(string); ok { + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + Query: schemas.Ptr(query), + Queries: []string{query}, // Anthropic uses single query + }, + } + } + } + } + + if isOutputMessage { + bifrostMsg.ID = block.ID + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + } + + case AnthropicContentBlockTypeWebSearchToolResult: + // Find the corresponding web_search_call by tool_use_id + if block.ToolUseID != nil { + attachWebSearchSourcesToCall(bifrostMessages, *block.ToolUseID, block, true) + } case AnthropicContentBlockTypeMCPToolUse: // Convert MCP tool use to MCP call (assistant's tool call) if block.ID != nil && block.Name != nil { @@ -2894,7 +3371,7 @@ func convertBifrostReasoningToAnthropicThinking(msg *schemas.ResponsesMessage) [ } // convertBifrostFunctionCallToAnthropicToolUse converts a Bifrost function call to Anthropic tool use -func convertBifrostFunctionCallToAnthropicToolUse(msg *schemas.ResponsesMessage) *AnthropicContentBlock { +func convertBifrostFunctionCallToAnthropicToolUse(ctx *schemas.BifrostContext, msg *schemas.ResponsesMessage) *AnthropicContentBlock { if msg.ResponsesToolMessage != nil { toolUseBlock := AnthropicContentBlock{ Type: AnthropicContentBlockTypeToolUse, @@ -2913,19 +3390,12 @@ func convertBifrostFunctionCallToAnthropicToolUse(msg *schemas.ResponsesMessage) // Sanitize WebSearch tool arguments to remove both allowed_domains and blocked_domains // Anthropic only allows one or the other, not both - if msg.ResponsesToolMessage.Name != nil && *msg.ResponsesToolMessage.Name == "WebSearch" { - var toolArgs map[string]interface{} - if err := json.Unmarshal([]byte(argumentsJSON), &toolArgs); err == nil { - _, hasAllowed := toolArgs["allowed_domains"] - _, hasBlocked := toolArgs["blocked_domains"] - - // If both domain filters exist, remove blocked_domains and keep allowed_domains - if hasAllowed && hasBlocked { - delete(toolArgs, "blocked_domains") - - // Re-marshal the sanitized arguments - if sanitizedBytes, err := json.Marshal(toolArgs); err == nil { - argumentsJSON = string(sanitizedBytes) + // Only do this for Claude CLI + if ctx != nil { + if userAgent, ok := ctx.Value(schemas.BifrostContextKeyUserAgent).(string); ok { + if strings.Contains(userAgent, 
"claude-cli") { + if msg.ResponsesToolMessage.Name != nil && *msg.ResponsesToolMessage.Name == "WebSearch" { + argumentsJSON = sanitizeWebSearchArguments(argumentsJSON) } } } @@ -3132,6 +3602,70 @@ func convertBifrostMCPApprovalToAnthropicToolUse(msg *schemas.ResponsesMessage) return nil } +// convertBifrostWebSearchCallToAnthropicBlocks converts a Bifrost web_search_call to Anthropic server_tool_use and web_search_tool_result blocks +func convertBifrostWebSearchCallToAnthropicBlocks(msg *schemas.ResponsesMessage) []AnthropicContentBlock { + if msg.ResponsesToolMessage == nil || msg.ResponsesToolMessage.Action == nil || msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction == nil { + return nil + } + + var blocks []AnthropicContentBlock + action := msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction + + // 1. Create server_tool_use block for the web search + serverToolUseBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeServerToolUse, + Name: schemas.Ptr("web_search"), + } + + if msg.ID != nil { + serverToolUseBlock.ID = msg.ID + } + + // Extract the query from the action + if action.Query != nil { + input := map[string]interface{}{ + "query": *action.Query, + } + serverToolUseBlock.Input = input + } + + blocks = append(blocks, serverToolUseBlock) + + // 2. Create web_search_tool_result block if sources are present + if len(action.Sources) > 0 { + var resultBlocks []AnthropicContentBlock + for _, source := range action.Sources { + if source.URL != "" { + resultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeWebSearchResult, + URL: schemas.Ptr(source.URL), + EncryptedContent: source.EncryptedContent, + PageAge: source.PageAge, + } + if source.Title != nil { + resultBlock.Title = source.Title + } else if source.URL != "" { + resultBlock.Title = schemas.Ptr(source.URL) + } + resultBlocks = append(resultBlocks, resultBlock) + } + } + + if len(resultBlocks) > 0 { + webSearchResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeWebSearchToolResult, + ToolUseID: msg.ID, + Content: &AnthropicContent{ + ContentBlocks: resultBlocks, + }, + } + blocks = append(blocks, webSearchResultBlock) + } + } + + return blocks +} + // convertBifrostUnsupportedToolCallToAnthropicMessage converts unsupported tool calls to text messages func convertBifrostUnsupportedToolCallToAnthropicMessage(msg *schemas.ResponsesMessage, msgType schemas.ResponsesMessageType) *AnthropicMessage { if msg.ResponsesToolMessage != nil { @@ -3230,14 +3764,17 @@ func convertAnthropicToolToBifrost(tool *AnthropicTool) *schemas.ResponsesTool { case AnthropicToolTypeWebSearch20250305: bifrostTool := &schemas.ResponsesTool{ Type: schemas.ResponsesToolTypeWebSearch, - Name: &tool.Name, } if tool.AnthropicToolWebSearch != nil { bifrostTool.ResponsesToolWebSearch = &schemas.ResponsesToolWebSearch{ Filters: &schemas.ResponsesToolWebSearchFilters{ AllowedDomains: tool.AnthropicToolWebSearch.AllowedDomains, + BlockedDomains: tool.AnthropicToolWebSearch.BlockedDomains, }, } + if tool.AnthropicToolWebSearch.MaxUses != nil { + bifrostTool.ResponsesToolWebSearch.MaxUses = tool.AnthropicToolWebSearch.MaxUses + } if tool.AnthropicToolWebSearch.UserLocation != nil { bifrostTool.ResponsesToolWebSearch.UserLocation = &schemas.ResponsesToolWebSearchUserLocation{ Type: tool.AnthropicToolWebSearch.UserLocation.Type, @@ -3247,6 +3784,7 @@ func convertAnthropicToolToBifrost(tool *AnthropicTool) *schemas.ResponsesTool { } } } + return bifrostTool case AnthropicToolTypeBash20250124: @@ 
-3417,8 +3955,12 @@ func convertBifrostToolToAnthropic(model string, tool *schemas.ResponsesTool) *A AnthropicToolWebSearch: &AnthropicToolWebSearch{}, } if tool.ResponsesToolWebSearch != nil { + if tool.ResponsesToolWebSearch.MaxUses != nil { + anthropicTool.AnthropicToolWebSearch.MaxUses = tool.ResponsesToolWebSearch.MaxUses + } if tool.ResponsesToolWebSearch.Filters != nil { anthropicTool.AnthropicToolWebSearch.AllowedDomains = tool.ResponsesToolWebSearch.Filters.AllowedDomains + anthropicTool.AnthropicToolWebSearch.BlockedDomains = tool.ResponsesToolWebSearch.Filters.BlockedDomains } if tool.ResponsesToolWebSearch.UserLocation != nil { anthropicTool.AnthropicToolWebSearch.UserLocation = &AnthropicToolWebSearchUserLocation{ @@ -3536,12 +4078,22 @@ func convertResponsesToolChoiceToAnthropic(toolChoice *schemas.ResponsesToolChoi func convertContentBlockToAnthropic(block schemas.ResponsesMessageContentBlock) *AnthropicContentBlock { switch block.Type { case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: + anthropicBlock := AnthropicContentBlock{} if block.Text != nil { - return &AnthropicContentBlock{ + anthropicBlock = AnthropicContentBlock{ Type: AnthropicContentBlockTypeText, Text: block.Text, CacheControl: block.CacheControl, } + if block.ResponsesOutputMessageContentText != nil && len(block.ResponsesOutputMessageContentText.Annotations) > 0 { + anthropicBlock.Citations = &AnthropicCitations{ + TextCitations: make([]AnthropicTextCitation, len(block.ResponsesOutputMessageContentText.Annotations)), + } + for i, annotation := range block.ResponsesOutputMessageContentText.Annotations { + anthropicBlock.Citations.TextCitations[i] = convertAnnotationToAnthropicCitation(annotation) + } + } + return &anthropicBlock } case schemas.ResponsesInputMessageContentBlockTypeImage: if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { @@ -3562,6 +4114,7 @@ func convertContentBlockToAnthropic(block schemas.ResponsesMessageContentBlock) anthropicBlock := ConvertResponsesFileBlockToAnthropic( block.ResponsesInputMessageContentBlockFile, block.CacheControl, + block.Citations, ) return &anthropicBlock } @@ -3611,6 +4164,10 @@ func (block AnthropicContentBlock) toBifrostResponsesDocumentBlock() schemas.Res ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{}, } + if block.Citations != nil && block.Citations.Config != nil { + resultBlock.Citations = block.Citations.Config + } + // Set filename from title if available if block.Title != nil { resultBlock.ResponsesInputMessageContentBlockFile.Filename = block.Title @@ -3718,6 +4275,140 @@ func convertBifrostMCPToolToAnthropicServer(tool *schemas.ResponsesTool) *Anthro return mcpServer } +// convertAnthropicCitationToAnnotation converts an Anthropic citation to an OpenAI annotation +// fullText is the complete text content of the message block, used to compute citation indices for web search results +func convertAnthropicCitationToAnnotation(citation AnthropicTextCitation, fullText string) schemas.ResponsesOutputMessageContentTextAnnotation { + annotation := schemas.ResponsesOutputMessageContentTextAnnotation{ + Type: string(citation.Type), + Index: citation.DocumentIndex, + Text: schemas.Ptr(citation.CitedText), + } + + // Map type-specific fields based on citation type + switch citation.Type { + case AnthropicCitationTypeCharLocation: + // Character location fields + annotation.StartCharIndex = 
citation.StartCharIndex + annotation.EndCharIndex = citation.EndCharIndex + annotation.Filename = citation.DocumentTitle + annotation.FileID = citation.FileID + + case AnthropicCitationTypePageLocation: + // Page location fields + annotation.StartPageNumber = citation.StartPageNumber + annotation.EndPageNumber = citation.EndPageNumber + annotation.Filename = citation.DocumentTitle + annotation.FileID = citation.FileID + + case AnthropicCitationTypeContentBlockLocation: + // Content block location fields + annotation.StartBlockIndex = citation.StartBlockIndex + annotation.EndBlockIndex = citation.EndBlockIndex + annotation.Filename = citation.DocumentTitle + annotation.FileID = citation.FileID + + case AnthropicCitationTypeWebSearchResultLocation: + // Web search result fields - map to OpenAI url_citation format + annotation.Type = "url_citation" + annotation.Title = citation.Title + annotation.URL = citation.URL + annotation.EncryptedIndex = citation.EncryptedIndex + + // Compute start_index and end_index by findin + if fullText != "" && citation.URL != nil && *citation.URL != "" { + startIdx := strings.Index(fullText, *citation.URL) + if startIdx != -1 { + endIdx := startIdx + len(*citation.URL) + annotation.StartIndex = schemas.Ptr(startIdx) + annotation.EndIndex = schemas.Ptr(endIdx) + } else { + // assign start_index and end_index to the entire text + annotation.StartIndex = schemas.Ptr(0) + annotation.EndIndex = schemas.Ptr(len(fullText)) + } + } + + case AnthropicCitationTypeSearchResultLocation: + // Search result location fields + annotation.StartBlockIndex = citation.StartBlockIndex + annotation.EndBlockIndex = citation.EndBlockIndex + annotation.Title = citation.Title + annotation.Source = citation.Source + } + + return annotation +} + +// convertAnnotationToAnthropicCitation converts an OpenAI annotation to an Anthropic citation +func convertAnnotationToAnthropicCitation(annotation schemas.ResponsesOutputMessageContentTextAnnotation) AnthropicTextCitation { + citation := AnthropicTextCitation{ + Type: AnthropicCitationType(annotation.Type), + CitedText: "", + } + + // Map common fields + if annotation.Text != nil { + citation.CitedText = *annotation.Text + } + + // Map type-specific fields based on annotation type + switch annotation.Type { + case string(AnthropicCitationTypeCharLocation): + // Character location + citation.StartCharIndex = annotation.StartCharIndex + citation.EndCharIndex = annotation.EndCharIndex + citation.DocumentTitle = annotation.Filename + citation.DocumentIndex = annotation.Index + citation.FileID = annotation.FileID + + case string(AnthropicCitationTypePageLocation): + // Page location + citation.StartPageNumber = annotation.StartPageNumber + citation.EndPageNumber = annotation.EndPageNumber + citation.DocumentTitle = annotation.Filename + citation.DocumentIndex = annotation.Index + citation.FileID = annotation.FileID + + case string(AnthropicCitationTypeContentBlockLocation): + // Content block location + citation.StartBlockIndex = annotation.StartBlockIndex + citation.EndBlockIndex = annotation.EndBlockIndex + citation.DocumentTitle = annotation.Filename + citation.DocumentIndex = annotation.Index + citation.FileID = annotation.FileID + + case string(AnthropicCitationTypeWebSearchResultLocation): + // Web search result + citation.Title = annotation.Title + citation.URL = annotation.URL + citation.EncryptedIndex = annotation.EncryptedIndex + + case string(AnthropicCitationTypeSearchResultLocation): + // Search result location + citation.StartBlockIndex 
= annotation.StartBlockIndex + citation.EndBlockIndex = annotation.EndBlockIndex + citation.Title = annotation.Title + citation.Source = annotation.Source + + case "url_citation": + citation.Type = AnthropicCitationTypeWebSearchResultLocation + citation.URL = annotation.URL + citation.Title = annotation.Title + citation.EncryptedIndex = annotation.EncryptedIndex + + case "file_citation", "container_file_citation", "file_path", "text_annotation": + // OpenAI native types - map to char_location + citation.Type = "char_location" + citation.StartCharIndex = annotation.StartIndex + citation.EndCharIndex = annotation.EndIndex + citation.DocumentTitle = annotation.Filename + citation.Title = annotation.Title + citation.FileID = annotation.FileID + } + + return citation +} + // convertResponsesToAnthropicComputerAction converts ResponsesComputerToolCallAction to Anthropic input map func convertResponsesToAnthropicComputerAction(action *schemas.ResponsesComputerToolCallAction) map[string]any { input := map[string]any{} diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 47f7c515f..d6be1cf68 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -195,23 +195,32 @@ func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { return nil } + // Try to unmarshal as a single ContentBlock object (e.g., web_search_tool_result_error) + // If successful, wrap it in an array + var singleBlock AnthropicContentBlock + if err := sonic.Unmarshal(data, &singleBlock); err == nil && singleBlock.Type != "" { + mc.ContentBlocks = []AnthropicContentBlock{singleBlock} + return nil + } + return fmt.Errorf("content field is neither a string nor an array of ContentBlock") } type AnthropicContentBlockType string const ( - AnthropicContentBlockTypeText AnthropicContentBlockType = "text" - AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" - AnthropicContentBlockTypeDocument AnthropicContentBlockType = "document" - AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" - AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" - AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" - AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" - AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" - AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" - AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" - AnthropicContentBlockTypeRedactedThinking AnthropicContentBlockType = "redacted_thinking" + AnthropicContentBlockTypeText AnthropicContentBlockType = "text" + AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" + AnthropicContentBlockTypeDocument AnthropicContentBlockType = "document" + AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" + AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" + AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" + AnthropicContentBlockTypeWebSearchToolResult AnthropicContentBlockType = "web_search_tool_result" + AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" + AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" + AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" + AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" + 
AnthropicContentBlockTypeRedactedThinking AnthropicContentBlockType = "redacted_thinking" ) // AnthropicContentBlock represents content in Anthropic message format @@ -516,6 +525,7 @@ const ( AnthropicStreamDeltaTypeInputJSON AnthropicStreamDeltaType = "input_json_delta" AnthropicStreamDeltaTypeThinking AnthropicStreamDeltaType = "thinking_delta" AnthropicStreamDeltaTypeSignature AnthropicStreamDeltaType = "signature_delta" + AnthropicStreamDeltaTypeCitations AnthropicStreamDeltaType = "citations_delta" ) // AnthropicStreamDelta represents incremental updates to content blocks during streaming (legacy) @@ -525,6 +535,7 @@ type AnthropicStreamDelta struct { PartialJSON *string `json:"partial_json,omitempty"` Thinking *string `json:"thinking,omitempty"` Signature *string `json:"signature,omitempty"` + Citation *AnthropicTextCitation `json:"citation,omitempty"` // For citations_delta StopReason *AnthropicStopReason `json:"stop_reason,omitempty"` // only not present in "message_start" events StopSequence *string `json:"stop_sequence"` } diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index cb166c080..bab94abde 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -1,7 +1,6 @@ package anthropic import ( - "context" "encoding/json" "fmt" "strings" @@ -28,7 +27,7 @@ var ( } ) -func getRequestBodyForResponses(ctx context.Context, request *schemas.BifrostResponsesRequest, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { +func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { var jsonBody []byte var err error @@ -54,7 +53,7 @@ func getRequestBodyForResponses(ctx context.Context, request *schemas.BifrostRes } } else { // Convert request to Anthropic format - reqBody, err := ToAnthropicResponsesRequest(request) + reqBody, err := ToAnthropicResponsesRequest(ctx, request) if err != nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerName) } @@ -157,6 +156,10 @@ func ConvertToAnthropicDocumentBlock(block schemas.ChatContentBlock) AnthropicCo Source: &AnthropicSource{}, } + if block.Citations != nil { + documentBlock.Citations = &AnthropicCitations{Config: block.Citations} + } + if block.File == nil { return documentBlock } @@ -223,13 +226,17 @@ func ConvertToAnthropicDocumentBlock(block schemas.ChatContentBlock) AnthropicCo } // ConvertResponsesFileBlockToAnthropic converts a Responses file block directly to Anthropic document format -func ConvertResponsesFileBlockToAnthropic(fileBlock *schemas.ResponsesInputMessageContentBlockFile, cacheControl *schemas.CacheControl) AnthropicContentBlock { +func ConvertResponsesFileBlockToAnthropic(fileBlock *schemas.ResponsesInputMessageContentBlockFile, cacheControl *schemas.CacheControl, citations *schemas.Citations) AnthropicContentBlock { documentBlock := AnthropicContentBlock{ Type: AnthropicContentBlockTypeDocument, CacheControl: cacheControl, Source: &AnthropicSource{}, } + if citations != nil { + documentBlock.Citations = &AnthropicCitations{Config: citations} + } + if fileBlock == nil { return documentBlock } @@ -784,3 +791,118 @@ func convertAnthropicOutputFormatToResponsesTextConfig(outputFormat interface{}) Format: format, } } + +// sanitizeWebSearchArguments sanitizes WebSearch tool arguments by removing conflicting domain filters. 
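+// For example, {"allowed_domains":["example.com"],"blocked_domains":[]} is rewritten to {"allowed_domains":["example.com"]}.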
+// Anthropic only allows one of allowed_domains or blocked_domains, not both. +// This function handles empty and non-empty arrays: +// - If one array is empty, delete that one +// - If both arrays are filled, delete blocked_domains +// - If both arrays are empty, delete blocked_domains +func sanitizeWebSearchArguments(argumentsJSON string) string { + var toolArgs map[string]interface{} + if err := json.Unmarshal([]byte(argumentsJSON), &toolArgs); err != nil { + return argumentsJSON // Return original if parse fails + } + + allowedVal, hasAllowed := toolArgs["allowed_domains"] + blockedVal, hasBlocked := toolArgs["blocked_domains"] + + // Only process if both fields exist + if hasAllowed && hasBlocked { + // Helper function to check if array is empty + isEmptyArray := func(val interface{}) bool { + if arr, ok := val.([]interface{}); ok { + return len(arr) == 0 + } + return false + } + + allowedEmpty := isEmptyArray(allowedVal) + blockedEmpty := isEmptyArray(blockedVal) + + var shouldDelete string + if allowedEmpty && !blockedEmpty { + // Delete allowed_domains if it's empty and blocked is not + shouldDelete = "allowed_domains" + } else if blockedEmpty && !allowedEmpty { + // Delete blocked_domains if it's empty and allowed is not + shouldDelete = "blocked_domains" + } else { + // Both are filled or both are empty: delete blocked_domains + shouldDelete = "blocked_domains" + } + + delete(toolArgs, shouldDelete) + + // Re-marshal the sanitized arguments + if sanitizedBytes, err := json.Marshal(toolArgs); err == nil { + return string(sanitizedBytes) + } + } + + return argumentsJSON +} + +// attachWebSearchSourcesToCall finds a web_search_call by tool_use_id and attaches sources to it. +// It searches backwards through bifrostMessages to find the matching call and updates its action. +func attachWebSearchSourcesToCall(bifrostMessages []schemas.ResponsesMessage, toolUseID string, resultBlock AnthropicContentBlock, includeExtendedFields bool) { + // Search backwards to find matching web_search_call + for i := len(bifrostMessages) - 1; i >= 0; i-- { + msg := &bifrostMessages[i] + if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeWebSearchCall && + msg.ID != nil && + *msg.ID == toolUseID { + + if msg.ResponsesToolMessage == nil { + msg.ResponsesToolMessage = &schemas.ResponsesToolMessage{} + } + + // Found the matching web_search_call, add sources + if resultBlock.Content != nil && len(resultBlock.Content.ContentBlocks) > 0 { + sources := extractWebSearchSources(resultBlock.Content.ContentBlocks, includeExtendedFields) + + // Initialize action if needed + if msg.ResponsesToolMessage.Action == nil { + msg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{} + } + if msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction == nil { + msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction = &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + } + } + msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Sources = sources + } + break + } + } +} + +// extractWebSearchSources extracts search sources from Anthropic content blocks. +// When includeExtendedFields is true, it includes EncryptedContent, PageAge, and Title fields. 
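+// When a result carries no title, the URL itself is reused as the Title fallback.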
+func extractWebSearchSources(contentBlocks []AnthropicContentBlock, includeExtendedFields bool) []schemas.ResponsesWebSearchToolCallActionSearchSource { + sources := make([]schemas.ResponsesWebSearchToolCallActionSearchSource, 0, len(contentBlocks)) + + for _, result := range contentBlocks { + if result.Type == AnthropicContentBlockTypeWebSearchResult && result.URL != nil { + source := schemas.ResponsesWebSearchToolCallActionSearchSource{ + Type: "url", + URL: *result.URL, + } + + if includeExtendedFields { + source.EncryptedContent = result.EncryptedContent + source.PageAge = result.PageAge + + if result.Title != nil { + source.Title = result.Title + } else { + source.Title = schemas.Ptr(*result.URL) + } + } + + sources = append(sources, source) + } + } + + return sources +} diff --git a/core/providers/azure/utils.go b/core/providers/azure/utils.go index 03c7fad1f..c12d8c662 100644 --- a/core/providers/azure/utils.go +++ b/core/providers/azure/utils.go @@ -1,7 +1,6 @@ package azure import ( - "context" "fmt" "github.com/bytedance/sonic" @@ -10,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx context.Context, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { var jsonBody []byte var err error @@ -39,7 +38,7 @@ func getRequestBodyForAnthropicResponses(ctx context.Context, request *schemas.B } else { // Convert request to Anthropic format request.Model = deployment - reqBody, err := anthropic.ToAnthropicResponsesRequest(request) + reqBody, err := anthropic.ToAnthropicResponsesRequest(ctx, request) if err != nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerName) } diff --git a/core/providers/openai/openai_test.go b/core/providers/openai/openai_test.go index 076a90c7d..a30d04d39 100644 --- a/core/providers/openai/openai_test.go +++ b/core/providers/openai/openai_test.go @@ -50,6 +50,7 @@ func TestOpenAI(t *testing.T) { MultipleToolCalls: true, End2EndToolCalling: true, AutomaticFunctionCall: true, + WebSearchTool: true, ImageURL: true, ImageBase64: true, MultipleImages: true, diff --git a/core/providers/openai/responses.go b/core/providers/openai/responses.go index edcdac3cf..8184a1090 100644 --- a/core/providers/openai/responses.go +++ b/core/providers/openai/responses.go @@ -196,9 +196,45 @@ func (req *OpenAIResponsesRequest) filterUnsupportedTools() { if supportedTypes[tool.Type] { // check for computer use preview if tool.Type == schemas.ResponsesToolTypeComputerUsePreview && tool.ResponsesToolComputerUsePreview != nil && tool.ResponsesToolComputerUsePreview.EnableZoom != nil { - // create new tool and assign it to the filtered tools newTool := tool - newTool.ResponsesToolComputerUsePreview.EnableZoom = nil + newComputerUse := &schemas.ResponsesToolComputerUsePreview{ + DisplayHeight: tool.ResponsesToolComputerUsePreview.DisplayHeight, + DisplayWidth: tool.ResponsesToolComputerUsePreview.DisplayWidth, + Environment: tool.ResponsesToolComputerUsePreview.Environment, + // EnableZoom is intentionally omitted (nil) - OpenAI doesn't support it + } + newTool.ResponsesToolComputerUsePreview = newComputerUse + filteredTools = append(filteredTools, newTool) + } else if 
tool.Type == schemas.ResponsesToolTypeWebSearch && tool.ResponsesToolWebSearch != nil { + // Create a proper deep copy with new nested pointers to avoid mutating the original + newTool := tool + newWebSearch := &schemas.ResponsesToolWebSearch{} + + // MaxUses is intentionally omitted (nil) - OpenAI doesn't support it + + // Handle Filters: OpenAI doesn't support BlockedDomains + if tool.ResponsesToolWebSearch.Filters != nil { + hasAllowedDomains := len(tool.ResponsesToolWebSearch.Filters.AllowedDomains) > 0 + + if hasAllowedDomains { + // Keep only AllowedDomains (copy the slice to avoid sharing) + newWebSearch.Filters = &schemas.ResponsesToolWebSearchFilters{ + AllowedDomains: append([]string(nil), tool.ResponsesToolWebSearch.Filters.AllowedDomains...), + // BlockedDomains is intentionally omitted - OpenAI doesn't support it + } + } + // If only blocked domains or both empty, Filters stays nil + } + + // Copy other fields if they exist + if tool.ResponsesToolWebSearch.UserLocation != nil { + newWebSearch.UserLocation = tool.ResponsesToolWebSearch.UserLocation + } + if tool.ResponsesToolWebSearch.SearchContextSize != nil { + newWebSearch.SearchContextSize = tool.ResponsesToolWebSearch.SearchContextSize + } + + newTool.ResponsesToolWebSearch = newWebSearch filteredTools = append(filteredTools, newTool) } else { filteredTools = append(filteredTools, tool) diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 916efc1e9..226f4b692 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -116,10 +116,11 @@ func (r *OpenAIChatRequest) MarshalJSON() ([]byte, error) { contentCopy := *msg.Content contentCopy.ContentBlocks = make([]schemas.ChatContentBlock, len(msg.Content.ContentBlocks)) for j, block := range msg.Content.ContentBlocks { - needsBlockCopy := block.CacheControl != nil || (block.File != nil && block.File.FileType != nil) + needsBlockCopy := block.CacheControl != nil || block.Citations != nil || (block.File != nil && block.File.FileType != nil) if needsBlockCopy { blockCopy := block blockCopy.CacheControl = nil + blockCopy.Citations = nil // Strip FileType and FileURL from file block if blockCopy.File != nil && (blockCopy.File.FileType != nil || blockCopy.File.FileURL != nil) { fileCopy := *blockCopy.File @@ -291,17 +292,33 @@ func (r *OpenAIResponsesRequestInput) MarshalJSON() ([]byte, error) { // Copy only this message messagesCopy[i] = msg - // Strip CacheControl and FileType from content blocks if needed + // Strip CacheControl, FileType, and filter unsupported citation types from content blocks if needed if msg.Content != nil && msg.Content.ContentBlocks != nil { contentCopy := *msg.Content contentCopy.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, len(msg.Content.ContentBlocks)) hasContentModification := false for j, block := range msg.Content.ContentBlocks { - needsBlockCopy := block.CacheControl != nil || (block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileType != nil) + needsBlockCopy := block.CacheControl != nil || block.Citations != nil || (block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileType != nil) || (block.ResponsesOutputMessageContentText != nil && len(block.ResponsesOutputMessageContentText.Annotations) > 0) if needsBlockCopy { hasContentModification = true blockCopy := block blockCopy.CacheControl = nil + blockCopy.Citations = nil + + // Filter out unsupported citation types from annotations + if 
blockCopy.ResponsesOutputMessageContentText != nil && len(blockCopy.ResponsesOutputMessageContentText.Annotations) > 0 { + textCopy := *blockCopy.ResponsesOutputMessageContentText + filteredAnnotations := filterSupportedAnnotations(textCopy.Annotations) + if len(filteredAnnotations) > 0 { + textCopy.Annotations = filteredAnnotations + blockCopy.ResponsesOutputMessageContentText = &textCopy + } else { + // If no supported annotations remain, remove the annotations array + textCopy.Annotations = nil + blockCopy.ResponsesOutputMessageContentText = &textCopy + } + } + // Strip FileType from file block if blockCopy.ResponsesInputMessageContentBlockFile != nil && blockCopy.ResponsesInputMessageContentBlockFile.FileType != nil { fileCopy := *blockCopy.ResponsesInputMessageContentBlockFile @@ -318,19 +335,54 @@ func (r *OpenAIResponsesRequestInput) MarshalJSON() ([]byte, error) { } } - // Strip CacheControl and FileType from tool message output blocks if needed - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Output != nil { - if msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + // Strip unsupported fields from tool message + if msg.ResponsesToolMessage != nil { + toolMsgCopy := *msg.ResponsesToolMessage + toolMsgModified := false + + // Strip unsupported fields from web search sources + if msg.ResponsesToolMessage.Action != nil && msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction != nil { + sources := msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Sources + if len(sources) > 0 { + needsSourceCopy := false + for _, source := range sources { + if source.Title != nil || source.EncryptedContent != nil || source.PageAge != nil { + needsSourceCopy = true + break + } + } + + if needsSourceCopy { + actionCopy := *msg.ResponsesToolMessage.Action + webSearchActionCopy := *msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction + strippedSources := make([]schemas.ResponsesWebSearchToolCallActionSearchSource, len(sources)) + for j, source := range sources { + // Only keep Type and URL for OpenAI + strippedSources[j] = schemas.ResponsesWebSearchToolCallActionSearchSource{ + Type: source.Type, + URL: source.URL, + // Title, EncryptedContent, and PageAge are omitted + } + } + webSearchActionCopy.Sources = strippedSources + actionCopy.ResponsesWebSearchToolCallAction = &webSearchActionCopy + toolMsgCopy.Action = &actionCopy + toolMsgModified = true + } + } + } + + // Strip CacheControl and FileType from tool message output blocks if needed + if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { hasToolModification := false for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { - if block.CacheControl != nil || (block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileType != nil) { + if block.CacheControl != nil || block.Citations != nil || (block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileType != nil) { hasToolModification = true break } } if hasToolModification { - toolMsgCopy := *msg.ResponsesToolMessage outputCopy := *msg.ResponsesToolMessage.Output outputCopy.ResponsesFunctionToolCallOutputBlocks = make([]schemas.ResponsesMessageContentBlock, len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks)) for j, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { @@ 
-338,6 +390,7 @@ func (r *OpenAIResponsesRequestInput) MarshalJSON() ([]byte, error) { if needsBlockCopy { blockCopy := block blockCopy.CacheControl = nil + blockCopy.Citations = nil // Strip FileType from file block if blockCopy.ResponsesInputMessageContentBlockFile != nil && blockCopy.ResponsesInputMessageContentBlockFile.FileType != nil { fileCopy := *blockCopy.ResponsesInputMessageContentBlockFile @@ -350,9 +403,13 @@ func (r *OpenAIResponsesRequestInput) MarshalJSON() ([]byte, error) { } } toolMsgCopy.Output = &outputCopy - messagesCopy[i].ResponsesToolMessage = &toolMsgCopy + toolMsgModified = true } } + + if toolMsgModified { + messagesCopy[i].ResponsesToolMessage = &toolMsgCopy + } } } return sonic.Marshal(messagesCopy) @@ -367,6 +424,9 @@ func hasFieldsToStripInChatMessage(msg OpenAIMessage) bool { if block.CacheControl != nil { return true } + if block.Citations != nil { + return true + } if block.File != nil && (block.File.FileType != nil || block.File.FileURL != nil) { return true } @@ -382,13 +442,28 @@ func hasFieldsToStripInResponsesMessage(msg schemas.ResponsesMessage) bool { if block.CacheControl != nil { return true } + if block.Citations != nil { + return true + } if block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileType != nil { return true } + if block.ResponsesOutputMessageContentText != nil && len(block.ResponsesOutputMessageContentText.Annotations) > 0 { + return true + } } } - if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Output != nil { - if msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + if msg.ResponsesToolMessage != nil { + // Check if we need to strip fields from web search sources + if msg.ResponsesToolMessage.Action != nil && msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction != nil { + for _, source := range msg.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Sources { + if source.Title != nil || source.EncryptedContent != nil || source.PageAge != nil { + return true + } + } + } + // Check output blocks + if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { if block.CacheControl != nil { return true @@ -402,6 +477,35 @@ func hasFieldsToStripInResponsesMessage(msg schemas.ResponsesMessage) bool { return false } +// filterSupportedAnnotations filters out unsupported (non-OpenAI native) citation types. +// Kept types: file_citation, url_citation, container_file_citation, file_path, and text_annotation; all other types are dropped +func filterSupportedAnnotations(annotations []schemas.ResponsesOutputMessageContentTextAnnotation) []schemas.ResponsesOutputMessageContentTextAnnotation { + if len(annotations) == 0 { + return annotations + } + + supportedAnnotations := make([]schemas.ResponsesOutputMessageContentTextAnnotation, 0, len(annotations)) + for _, annotation := range annotations { + switch annotation.Type { + case "url_citation": + supportedAnnotations = append(supportedAnnotations, schemas.ResponsesOutputMessageContentTextAnnotation{ + Type: "url_citation", + URL: annotation.URL, + Title: annotation.Title, + StartIndex: annotation.StartIndex, + EndIndex: annotation.EndIndex, + }) + case "file_citation", "container_file_citation", "file_path", "text_annotation": + // OpenAI native types - keep them + supportedAnnotations = append(supportedAnnotations, annotation) + default: + continue + } + } + + return
supportedAnnotations +} + type OpenAIResponsesRequest struct { Model string `json:"model"` Input OpenAIResponsesRequestInput `json:"input"` diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index 831bc8f95..829701bb9 100644 --- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -1,7 +1,6 @@ package vertex import ( - "context" "fmt" "github.com/bytedance/sonic" @@ -10,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx context.Context, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { var jsonBody []byte var err error @@ -43,7 +42,7 @@ func getRequestBodyForAnthropicResponses(ctx context.Context, request *schemas.B } else { // Convert request to Anthropic format request.Model = deployment - reqBody, err := anthropic.ToAnthropicResponsesRequest(request) + reqBody, err := anthropic.ToAnthropicResponsesRequest(ctx, request) if err != nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerName) } diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 56f1dc184..ffca17b3a 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -1524,7 +1524,7 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch ctx, request, func() (any, error) { - return anthropic.ToAnthropicResponsesRequest(request) + return anthropic.ToAnthropicResponsesRequest(ctx, request) }, providerName, ) diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index 6cbd28b39..089f42bc2 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -708,6 +708,7 @@ type ChatContentBlock struct { // Not in OpenAI's schemas, but sent by a few providers (Anthropic, Bedrock are some of them) CacheControl *CacheControl `json:"cache_control,omitempty"` + Citations *Citations `json:"citations,omitempty"` } type CacheControlType string @@ -807,8 +808,8 @@ func (cm *ChatAssistantMessage) UnmarshalJSON(data []byte) error { // ChatAssistantMessageAnnotation represents an annotation in a response. type ChatAssistantMessageAnnotation struct { - Type string `json:"type"` - Citation ChatAssistantMessageAnnotationCitation `json:"url_citation"` + Type string `json:"type"` + URLCitation ChatAssistantMessageAnnotationCitation `json:"url_citation"` } // ChatAssistantMessageAnnotationCitation represents a citation in a response. 
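A minimal sketch of how the filterSupportedAnnotations helper above behaves, written as a test in the openai package (the test name is hypothetical, it assumes the package's existing schemas import plus "testing", and the literal values are made up for illustration):

func TestFilterSupportedAnnotationsSketch(t *testing.T) {
	annotations := []schemas.ResponsesOutputMessageContentTextAnnotation{
		// Anthropic-specific location type: dropped for OpenAI requests
		{Type: "page_location", StartPageNumber: schemas.Ptr(1), EndPageNumber: schemas.Ptr(2)},
		// OpenAI-native type: kept, trimmed to the fields OpenAI accepts
		{Type: "url_citation", URL: schemas.Ptr("https://example.com"), Title: schemas.Ptr("Example"), StartIndex: schemas.Ptr(0), EndIndex: schemas.Ptr(10)},
	}
	filtered := filterSupportedAnnotations(annotations)
	if len(filtered) != 1 || filtered[0].Type != "url_citation" {
		t.Fatalf("expected a single url_citation annotation, got %+v", filtered)
	}
}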
diff --git a/core/schemas/responses.go b/core/schemas/responses.go index 201b1b40a..97ede3ee7 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -416,12 +416,12 @@ type ResponsesMessageContentBlock struct { // Not in OpenAI's schemas, but sent by a few providers (Anthropic, Bedrock are some of them) CacheControl *CacheControl `json:"cache_control,omitempty"` + Citations *Citations `json:"citations,omitempty"` } type Citations struct { Enabled *bool `json:"enabled,omitempty"` } - type ResponsesInputMessageContentBlockImage struct { ImageURL *string `json:"image_url,omitempty"` Detail *string `json:"detail,omitempty"` // "low" | "high" | "auto" @@ -459,6 +459,16 @@ type ResponsesOutputMessageContentTextAnnotation struct { Title *string `json:"title,omitempty"` URL *string `json:"url,omitempty"` ContainerID *string `json:"container_id,omitempty"` + + // Anthropic specific fields + StartCharIndex *int `json:"start_char_index,omitempty"` + EndCharIndex *int `json:"end_char_index,omitempty"` + StartPageNumber *int `json:"start_page_number,omitempty"` + EndPageNumber *int `json:"end_page_number,omitempty"` + StartBlockIndex *int `json:"start_block_index,omitempty"` + EndBlockIndex *int `json:"end_block_index,omitempty"` + Source *string `json:"source,omitempty"` + EncryptedIndex *string `json:"encrypted_index,omitempty"` } // ResponsesOutputMessageContentTextLogProb represents log probability information for content. @@ -514,31 +524,55 @@ func (action ResponsesToolMessageActionStruct) MarshalJSON() ([]byte, error) { if action.ResponsesMCPApprovalRequestAction != nil { return Marshal(action.ResponsesMCPApprovalRequestAction) } - return nil, fmt.Errorf("responses tool message action struct is neither a computer tool call action nor a web search tool call action nor a local shell tool call action nor a mcp approval request action") + return nil, fmt.Errorf("responses tool message action struct is empty") } func (action *ResponsesToolMessageActionStruct) UnmarshalJSON(data []byte) error { - var computerToolCallAction ResponsesComputerToolCallAction - if err := Unmarshal(data, &computerToolCallAction); err == nil { + // First, peek at the type field to determine which variant to unmarshal + var typeStruct struct { + Type string `json:"type"` + } + if err := Unmarshal(data, &typeStruct); err != nil { + return fmt.Errorf("failed to peek at type field: %w", err) + } + + // Based on the type, unmarshal into the appropriate variant + switch typeStruct.Type { + case "click", "double_click", "drag", "keypress", "move", "screenshot", "scroll", "type", "wait", "zoom": + var computerToolCallAction ResponsesComputerToolCallAction + if err := Unmarshal(data, &computerToolCallAction); err != nil { + return fmt.Errorf("failed to unmarshal computer tool call action: %w", err) + } action.ResponsesComputerToolCallAction = &computerToolCallAction return nil - } - var webSearchToolCallAction ResponsesWebSearchToolCallAction - if err := Unmarshal(data, &webSearchToolCallAction); err == nil { + + case "search", "open_page", "find": + var webSearchToolCallAction ResponsesWebSearchToolCallAction + if err := Unmarshal(data, &webSearchToolCallAction); err != nil { + return fmt.Errorf("failed to unmarshal web search tool call action: %w", err) + } action.ResponsesWebSearchToolCallAction = &webSearchToolCallAction return nil - } - var localShellToolCallAction ResponsesLocalShellToolCallAction - if err := Unmarshal(data, &localShellToolCallAction); err == nil { + + case "exec": + var 
localShellToolCallAction ResponsesLocalShellToolCallAction + if err := Unmarshal(data, &localShellToolCallAction); err != nil { + return fmt.Errorf("failed to unmarshal local shell tool call action: %w", err) + } action.ResponsesLocalShellToolCallAction = &localShellToolCallAction return nil - } - var mcpApprovalRequestAction ResponsesMCPApprovalRequestAction - if err := Unmarshal(data, &mcpApprovalRequestAction); err == nil { + + case "mcp_approval_request": + var mcpApprovalRequestAction ResponsesMCPApprovalRequestAction + if err := Unmarshal(data, &mcpApprovalRequestAction); err != nil { + return fmt.Errorf("failed to unmarshal mcp approval request action: %w", err) + } action.ResponsesMCPApprovalRequestAction = &mcpApprovalRequestAction return nil + + default: + return fmt.Errorf("unknown action type: %s", typeStruct.Type) } - return fmt.Errorf("responses tool message action struct is neither a computer tool call action nor a web search tool call action nor a local shell tool call action nor a mcp approval request action") } type ResponsesToolMessageOutputStruct struct { @@ -658,6 +692,7 @@ type ResponsesWebSearchToolCallAction struct { Type string `json:"type"` // "search" | "open_page" | "find" URL *string `json:"url,omitempty"` // Common URL field (OpenPage, Find) Query *string `json:"query,omitempty"` + Queries []string `json:"queries,omitempty"` Sources []ResponsesWebSearchToolCallActionSearchSource `json:"sources,omitempty"` Pattern *string `json:"pattern,omitempty"` } @@ -666,6 +701,11 @@ type ResponsesWebSearchToolCallAction struct { type ResponsesWebSearchToolCallActionSearchSource struct { Type string `json:"type"` // always "url" URL string `json:"url"` + + // Anthropic specific fields + Title *string `json:"title,omitempty"` + EncryptedContent *string `json:"encrypted_content,omitempty"` + PageAge *string `json:"page_age,omitempty"` } // ----------------------------------------------------------------------------- @@ -1061,6 +1101,296 @@ type ResponsesTool struct { *ResponsesToolWebSearchPreview } +// MarshalJSON implements custom JSON marshaling for ResponsesTool +// It merges common fields with the appropriate embedded struct based on type +func (t ResponsesTool) MarshalJSON() ([]byte, error) { + // Start with common fields + result := map[string]interface{}{ + "type": t.Type, + } + + if t.Name != nil { + result["name"] = t.Name + } + if t.Description != nil { + result["description"] = t.Description + } + if t.CacheControl != nil { + result["cache_control"] = t.CacheControl + } + + // Based on type, marshal the appropriate embedded struct + switch t.Type { + case ResponsesToolTypeFunction: + if t.ResponsesToolFunction != nil { + bytes, err := Marshal(t.ResponsesToolFunction) + if err != nil { + return nil, err + } + var funcFields map[string]interface{} + if err := Unmarshal(bytes, &funcFields); err != nil { + return nil, err + } + for k, v := range funcFields { + result[k] = v + } + } + + case ResponsesToolTypeFileSearch: + if t.ResponsesToolFileSearch != nil { + bytes, err := Marshal(t.ResponsesToolFileSearch) + if err != nil { + return nil, err + } + var fileSearchFields map[string]interface{} + if err := Unmarshal(bytes, &fileSearchFields); err != nil { + return nil, err + } + for k, v := range fileSearchFields { + result[k] = v + } + } + + case ResponsesToolTypeComputerUsePreview: + if t.ResponsesToolComputerUsePreview != nil { + bytes, err := Marshal(t.ResponsesToolComputerUsePreview) + if err != nil { + return nil, err + } + var computerFields 
map[string]interface{} + if err := Unmarshal(bytes, &computerFields); err != nil { + return nil, err + } + for k, v := range computerFields { + result[k] = v + } + } + + case ResponsesToolTypeWebSearch: + if t.ResponsesToolWebSearch != nil { + bytes, err := Marshal(t.ResponsesToolWebSearch) + if err != nil { + return nil, err + } + var webSearchFields map[string]interface{} + if err := Unmarshal(bytes, &webSearchFields); err != nil { + return nil, err + } + for k, v := range webSearchFields { + result[k] = v + } + } + + case ResponsesToolTypeMCP: + if t.ResponsesToolMCP != nil { + bytes, err := Marshal(t.ResponsesToolMCP) + if err != nil { + return nil, err + } + var mcpFields map[string]interface{} + if err := Unmarshal(bytes, &mcpFields); err != nil { + return nil, err + } + for k, v := range mcpFields { + result[k] = v + } + } + + case ResponsesToolTypeCodeInterpreter: + if t.ResponsesToolCodeInterpreter != nil { + bytes, err := Marshal(t.ResponsesToolCodeInterpreter) + if err != nil { + return nil, err + } + var codeInterpreterFields map[string]interface{} + if err := Unmarshal(bytes, &codeInterpreterFields); err != nil { + return nil, err + } + for k, v := range codeInterpreterFields { + result[k] = v + } + } + + case ResponsesToolTypeImageGeneration: + if t.ResponsesToolImageGeneration != nil { + bytes, err := Marshal(t.ResponsesToolImageGeneration) + if err != nil { + return nil, err + } + var imageGenFields map[string]interface{} + if err := Unmarshal(bytes, &imageGenFields); err != nil { + return nil, err + } + for k, v := range imageGenFields { + result[k] = v + } + } + + case ResponsesToolTypeLocalShell: + if t.ResponsesToolLocalShell != nil { + bytes, err := Marshal(t.ResponsesToolLocalShell) + if err != nil { + return nil, err + } + var localShellFields map[string]interface{} + if err := Unmarshal(bytes, &localShellFields); err != nil { + return nil, err + } + for k, v := range localShellFields { + result[k] = v + } + } + + case ResponsesToolTypeCustom: + if t.ResponsesToolCustom != nil { + bytes, err := Marshal(t.ResponsesToolCustom) + if err != nil { + return nil, err + } + var customFields map[string]interface{} + if err := Unmarshal(bytes, &customFields); err != nil { + return nil, err + } + for k, v := range customFields { + result[k] = v + } + } + + case ResponsesToolTypeWebSearchPreview: + if t.ResponsesToolWebSearchPreview != nil { + bytes, err := Marshal(t.ResponsesToolWebSearchPreview) + if err != nil { + return nil, err + } + var webSearchPreviewFields map[string]interface{} + if err := Unmarshal(bytes, &webSearchPreviewFields); err != nil { + return nil, err + } + for k, v := range webSearchPreviewFields { + result[k] = v + } + } + } + + return Marshal(result) +} + +// UnmarshalJSON implements custom JSON unmarshaling for ResponsesTool +// It unmarshals common fields first, then the appropriate embedded struct based on type +func (t *ResponsesTool) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to inspect the type + var raw map[string]interface{} + if err := Unmarshal(data, &raw); err != nil { + return err + } + + // Extract type field + typeValue, ok := raw["type"] + if !ok { + return fmt.Errorf("missing required 'type' field in ResponsesTool") + } + + typeStr, ok := typeValue.(string) + if !ok { + return fmt.Errorf("'type' field must be a string") + } + t.Type = ResponsesToolType(typeStr) + + // Unmarshal common fields + if name, ok := raw["name"].(string); ok { + t.Name = &name + } + if description, ok := raw["description"].(string); ok { + 
t.Description = &description + } + if cacheControl, ok := raw["cache_control"]; ok { + bytes, err := Marshal(cacheControl) + if err != nil { + return err + } + var cc CacheControl + if err := Unmarshal(bytes, &cc); err != nil { + return err + } + t.CacheControl = &cc + } + + // Based on type, unmarshal into the appropriate embedded struct + switch t.Type { + case ResponsesToolTypeFunction: + var funcTool ResponsesToolFunction + if err := Unmarshal(data, &funcTool); err != nil { + return err + } + t.ResponsesToolFunction = &funcTool + + case ResponsesToolTypeFileSearch: + var fileSearchTool ResponsesToolFileSearch + if err := Unmarshal(data, &fileSearchTool); err != nil { + return err + } + t.ResponsesToolFileSearch = &fileSearchTool + + case ResponsesToolTypeComputerUsePreview: + var computerTool ResponsesToolComputerUsePreview + if err := Unmarshal(data, &computerTool); err != nil { + return err + } + t.ResponsesToolComputerUsePreview = &computerTool + + case ResponsesToolTypeWebSearch: + var webSearchTool ResponsesToolWebSearch + if err := Unmarshal(data, &webSearchTool); err != nil { + return err + } + t.ResponsesToolWebSearch = &webSearchTool + + case ResponsesToolTypeMCP: + var mcpTool ResponsesToolMCP + if err := Unmarshal(data, &mcpTool); err != nil { + return err + } + t.ResponsesToolMCP = &mcpTool + + case ResponsesToolTypeCodeInterpreter: + var codeInterpreterTool ResponsesToolCodeInterpreter + if err := Unmarshal(data, &codeInterpreterTool); err != nil { + return err + } + t.ResponsesToolCodeInterpreter = &codeInterpreterTool + + case ResponsesToolTypeImageGeneration: + var imageGenTool ResponsesToolImageGeneration + if err := Unmarshal(data, &imageGenTool); err != nil { + return err + } + t.ResponsesToolImageGeneration = &imageGenTool + + case ResponsesToolTypeLocalShell: + var localShellTool ResponsesToolLocalShell + if err := Unmarshal(data, &localShellTool); err != nil { + return err + } + t.ResponsesToolLocalShell = &localShellTool + + case ResponsesToolTypeCustom: + var customTool ResponsesToolCustom + if err := Unmarshal(data, &customTool); err != nil { + return err + } + t.ResponsesToolCustom = &customTool + + case ResponsesToolTypeWebSearchPreview: + var webSearchPreviewTool ResponsesToolWebSearchPreview + if err := Unmarshal(data, &webSearchPreviewTool); err != nil { + return err + } + t.ResponsesToolWebSearchPreview = &webSearchPreviewTool + } + + return nil +} + // ResponsesToolFunction represents a tool function type ResponsesToolFunction struct { Parameters *ToolFunctionParameters `json:"parameters,omitempty"` // A JSON schema object describing the parameters @@ -1218,11 +1548,15 @@ type ResponsesToolWebSearch struct { Filters *ResponsesToolWebSearchFilters `json:"filters,omitempty"` // Filters for the search SearchContextSize *string `json:"search_context_size,omitempty"` // "low" | "medium" | "high" UserLocation *ResponsesToolWebSearchUserLocation `json:"user_location,omitempty"` // The approximate location of the user + + // Anthropic only + MaxUses *int `json:"max_uses,omitempty"` // Maximum number of uses for the search } // ResponsesToolWebSearchFilters represents filters for web search type ResponsesToolWebSearchFilters struct { - AllowedDomains []string `json:"allowed_domains"` // Allowed domains for the search + AllowedDomains []string `json:"allowed_domains,omitempty"` // Allowed domains for the search + BlockedDomains []string `json:"blocked_domains,omitempty"` // Blocked domains for the search, only used in anthropic } // 
ResponsesToolWebSearchUserLocation - The approximate location of the user @@ -1403,7 +1737,9 @@ const ( ResponsesStreamResponseTypeFileSearchCallSearching ResponsesStreamResponseType = "response.file_search_call.searching" ResponsesStreamResponseTypeFileSearchCallResultsAdded ResponsesStreamResponseType = "response.file_search_call.results.added" ResponsesStreamResponseTypeFileSearchCallResultsCompleted ResponsesStreamResponseType = "response.file_search_call.results.completed" + ResponsesStreamResponseTypeWebSearchCallInProgress ResponsesStreamResponseType = "response.web_search_call.in_progress" ResponsesStreamResponseTypeWebSearchCallSearching ResponsesStreamResponseType = "response.web_search_call.searching" + ResponsesStreamResponseTypeWebSearchCallCompleted ResponsesStreamResponseType = "response.web_search_call.completed" ResponsesStreamResponseTypeWebSearchCallResultsAdded ResponsesStreamResponseType = "response.web_search_call.results.added" ResponsesStreamResponseTypeWebSearchCallResultsCompleted ResponsesStreamResponseType = "response.web_search_call.results.completed" diff --git a/core/schemas/utils.go b/core/schemas/utils.go index 9bfe96bd6..70de2ca21 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -639,23 +639,23 @@ func DeepCopyChatMessage(original ChatMessage) ChatMessage { for i, annotation := range original.ChatAssistantMessage.Annotations { copyAnnotation := ChatAssistantMessageAnnotation{ Type: annotation.Type, - Citation: ChatAssistantMessageAnnotationCitation{ - StartIndex: annotation.Citation.StartIndex, - EndIndex: annotation.Citation.EndIndex, - Title: annotation.Citation.Title, + URLCitation: ChatAssistantMessageAnnotationCitation{ + StartIndex: annotation.URLCitation.StartIndex, + EndIndex: annotation.URLCitation.EndIndex, + Title: annotation.URLCitation.Title, }, } - if annotation.Citation.URL != nil { - copyURL := *annotation.Citation.URL - copyAnnotation.Citation.URL = &copyURL + if annotation.URLCitation.URL != nil { + copyURL := *annotation.URLCitation.URL + copyAnnotation.URLCitation.URL = &copyURL } - if annotation.Citation.Sources != nil { - copySources := *annotation.Citation.Sources - copyAnnotation.Citation.Sources = &copySources + if annotation.URLCitation.Sources != nil { + copySources := *annotation.URLCitation.Sources + copyAnnotation.URLCitation.Sources = &copySources } - if annotation.Citation.Type != nil { - copyType := *annotation.Citation.Type - copyAnnotation.Citation.Type = &copyType + if annotation.URLCitation.Type != nil { + copyType := *annotation.URLCitation.Type + copyAnnotation.URLCitation.Type = &copyType } copy.ChatAssistantMessage.Annotations[i] = copyAnnotation } diff --git a/tests/integrations/python/config.yml b/tests/integrations/python/config.yml index 2d2658cd7..1df0ca0ad 100644 --- a/tests/integrations/python/config.yml +++ b/tests/integrations/python/config.yml @@ -183,6 +183,7 @@ provider_scenarios: multiple_tool_calls: true end2end_tool_calling: true automatic_function_calling: true + web_search: true image_url: true image_base64: true file_input: true @@ -194,6 +195,7 @@ provider_scenarios: embeddings: true thinking: true prompt_caching: false + citations: false list_models: true responses: true responses_image: true @@ -244,6 +246,7 @@ provider_scenarios: multiple_tool_calls: true end2end_tool_calling: true automatic_function_calling: true + web_search: true image_url: true image_base64: true file_input: true @@ -256,6 +259,7 @@ provider_scenarios: embeddings: false thinking: true prompt_caching: true + citations: true
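+ # citations gates the Anthropic-style document citation tests (test_33 through test_37)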
list_models: true responses: true responses_image: true @@ -296,6 +300,7 @@ provider_scenarios: embeddings: true thinking: true prompt_caching: false + citations: false list_models: true responses: true responses_image: true @@ -377,6 +382,7 @@ provider_scenarios: embeddings: true thinking: true prompt_caching: true + citations: false list_models: true responses: true responses_image: true @@ -415,6 +421,7 @@ provider_scenarios: embeddings: true thinking: false prompt_caching: false + citations: false list_models: false responses: true responses_image: true @@ -449,6 +456,7 @@ scenario_capabilities: multiple_tool_calls: "tools" end2end_tool_calling: "tools" automatic_function_calling: "tools" + web_search: "chat" image_url: "vision" image_base64: "vision" file_input: "file" @@ -461,6 +469,7 @@ scenario_capabilities: embeddings: "embeddings" thinking: "thinking" prompt_caching: "chat" + citations: "chat" list_models: "chat" langchain_structured_output: "chat" # LangChain structured output uses chat capability pydantic_structured_output: "chat" # Structured output uses chat capability diff --git a/tests/integrations/python/tests/test_anthropic.py b/tests/integrations/python/tests/test_anthropic.py index 3dc1a5567..686bab06a 100644 --- a/tests/integrations/python/tests/test_anthropic.py +++ b/tests/integrations/python/tests/test_anthropic.py @@ -85,6 +85,12 @@ extract_tool_calls, get_api_key, mock_tool_response, + # Citation utilities + CITATION_TEXT_DOCUMENT, + CITATION_MULTI_DOCUMENT_SET, + assert_valid_anthropic_citation, + collect_anthropic_streaming_citations, + create_anthropic_document, ) from .utils.config_loader import get_config, get_model from .utils.parametrize import ( @@ -1773,6 +1779,1037 @@ def test_32_document_text_input(self, anthropic_client, test_config, provider, m assert any(word in content for word in document_keywords), \ f"Response should reference document features. Got: {content}" + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("citations")) + def test_33_citations_pdf(self, anthropic_client, test_config, provider, model): + """Test Case 33: PDF document with page_location citations""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for citations scenario") + + print(f"\n=== Testing PDF Citations (page_location) for provider {provider} ===") + + # Create PDF document using helper + document = create_anthropic_document( + content=FILE_DATA_BASE64, + doc_type="pdf", + title="Test PDF Document" + ) + + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "What does this PDF document say? Please cite your sources." 
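+ # Asking explicitly for sources encourages citation-bearing output for the attached document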
+ }, + document + ] + }] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + max_tokens=500 + ) + + # Validate basic response + assert_valid_chat_response(response) + assert len(response.content) > 0 + + # Check for citations using helper + has_citations = False + citation_count = 0 + for block in response.content: + if hasattr(block, "citations") and block.citations: + has_citations = True + for citation in block.citations: + citation_count += 1 + # Use common validator + assert_valid_anthropic_citation( + citation, + expected_type="page_location", + document_index=0 + ) + print(f"✓ Citation {citation_count}: pages {citation.start_page_number}-{citation.end_page_number}, " + f"text: '{citation.cited_text[:50]}...'") + + assert has_citations, "Response should contain citations for PDF document" + print(f"✓ PDF citations test passed - Found {citation_count} citations") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("citations")) + def test_34_citations_text(self, anthropic_client, test_config, provider, model): + """Test Case 34: Plain text document with char_location citations""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for citations scenario") + + print(f"\n=== Testing Text Citations (char_location) for provider {provider} ===") + + # Create text document using helper + document = create_anthropic_document( + content=CITATION_TEXT_DOCUMENT, + doc_type="text", + title="Theory of Relativity Overview" + ) + + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "When was General Relativity published and what does it deal with? Please cite your sources." + }, + document + ] + }] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + max_tokens=500 + ) + + # Validate basic response + assert_valid_chat_response(response) + assert len(response.content) > 0 + + # Check for citations using helper + has_citations = False + citation_count = 0 + for block in response.content: + if hasattr(block, "citations") and block.citations: + has_citations = True + for citation in block.citations: + citation_count += 1 + # Use common validator + assert_valid_anthropic_citation( + citation, + expected_type="char_location", + document_index=0 + ) + print(f"✓ Citation {citation_count}: chars {citation.start_char_index}-{citation.end_char_index}, " + f"text: '{citation.cited_text[:50]}...'") + + assert has_citations, "Response should contain citations for text document" + print(f"✓ Text citations test passed - Found {citation_count} citations") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("citations")) + def test_35_citations_multi_document(self, anthropic_client, test_config, provider, model): + """Test Case 35: Multiple documents with citations (document_index validation)""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for citations scenario") + + print(f"\n=== Testing Multi-Document Citations for provider {provider} ===") + + # Create multiple documents using helper + documents = [] + for idx, doc_info in enumerate(CITATION_MULTI_DOCUMENT_SET): + doc = create_anthropic_document( + content=doc_info["content"], + doc_type="text", + title=doc_info["title"] + ) + documents.append(doc) + + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": 
"Summarize what each document says. Please cite your sources from each document." + }, + *documents + ] + }] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + max_tokens=600 + ) + + # Validate basic response + assert_valid_chat_response(response) + assert len(response.content) > 0 + + # Check for citations from multiple documents + has_citations = False + citations_by_doc = {0: 0, 1: 0} # Track citations per document + total_citations = 0 + + for block in response.content: + if hasattr(block, "citations") and block.citations: + has_citations = True + for citation in block.citations: + total_citations += 1 + doc_idx = citation.document_index if hasattr(citation, "document_index") else 0 + + # Validate citation + assert_valid_anthropic_citation( + citation, + expected_type="char_location", + document_index=doc_idx + ) + + # Track which document this citation is from + if doc_idx in citations_by_doc: + citations_by_doc[doc_idx] += 1 + + doc_title = citation.document_title if hasattr(citation, "document_title") else "Unknown" + print(f"✓ Citation from doc[{doc_idx}] ({doc_title}): " + f"chars {citation.start_char_index}-{citation.end_char_index}, " + f"text: '{citation.cited_text[:40]}...'") + + assert has_citations, "Response should contain citations" + + # Report statistics + print(f"\n✓ Multi-document citations test passed:") + print(f" - Total citations: {total_citations}") + for doc_idx, count in citations_by_doc.items(): + doc_title = CITATION_MULTI_DOCUMENT_SET[doc_idx]["title"] + print(f" - Document {doc_idx} ({doc_title}): {count} citations") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("citations")) + def test_36_citations_streaming(self, anthropic_client, test_config, provider, model): + """Test Case 36: Text citations with streaming (citations_delta)""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for citations scenario") + + print(f"\n=== Testing Streaming Citations (char_location) for provider {provider} ===") + + # Create text document using helper + document = create_anthropic_document( + content=CITATION_TEXT_DOCUMENT, + doc_type="text", + title="Machine Learning Introduction" + ) + + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "Explain the key concepts from this document. Please cite your sources." 
+ }, + document + ] + }] + + stream = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + max_tokens=500, + stream=True + ) + + # Collect streaming content and citations using helper + complete_text, citations, chunk_count = collect_anthropic_streaming_citations(stream) + + # Validate results + assert chunk_count > 0, "Should receive at least one chunk" + assert len(complete_text) > 0, "Should receive text content" + assert len(citations) > 0, "Should collect at least one citation from stream" + + # Validate each citation + for idx, citation in enumerate(citations, 1): + # Use common validator + assert_valid_anthropic_citation( + citation, + expected_type="char_location", + document_index=0 + ) + print(f"✓ Citation {idx}: chars {citation.start_char_index}-{citation.end_char_index}, " + f"text: '{citation.cited_text[:50]}...'") + + print(f"✓ Streaming citations test passed - {len(citations)} citations in {chunk_count} chunks") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("citations")) + def test_37_citations_streaming_pdf(self, anthropic_client, test_config, provider, model): + """Test Case 37: PDF citations with streaming (page_location + citations_delta)""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for citations scenario") + + print(f"\n=== Testing Streaming PDF Citations (page_location) for provider {provider} ===") + + # Create PDF document using helper + document = create_anthropic_document( + content=FILE_DATA_BASE64, + doc_type="pdf", + title="Test PDF Document" + ) + + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "What does this PDF say? Please cite your sources." + }, + document + ] + }] + + stream = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + max_tokens=500, + stream=True + ) + + # Collect streaming content and citations using helper + complete_text, citations, chunk_count = collect_anthropic_streaming_citations(stream) + + # Validate results + assert chunk_count > 0, "Should receive at least one chunk" + assert len(complete_text) > 0, "Should receive text content" + assert len(citations) > 0, "Should collect at least one citation from stream" + + # Validate each citation - should be page_location for PDF + for idx, citation in enumerate(citations, 1): + # Use common validator + assert_valid_anthropic_citation( + citation, + expected_type="page_location", + document_index=0 + ) + print(f"✓ Citation {idx}: pages {citation.start_page_number}-{citation.end_page_number}, " + f"text: '{citation.cited_text[:50]}...'") + + print(f"✓ Streaming PDF citations test passed - {len(citations)} citations in {chunk_count} chunks") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search")) + def test_38_web_search_non_streaming(self, anthropic_client, test_config, provider, model): + """Test Case 38: Web search tool (non-streaming)""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search (Non-Streaming) for provider {provider} ===") + + # Create web search tool + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5 + } + + messages = [ + { + "role": "user", + "content": "What is a positive news story from today?" 
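+ # A time-sensitive question the model cannot answer from memory, nudging it to invoke web_search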
+            }
+        ]
+
+        response = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            messages=messages,
+            tools=[web_search_tool],
+            max_tokens=2048
+        )
+
+        # Validate basic response
+        assert response is not None, "Response should not be None"
+        assert hasattr(response, "content"), "Response should have content"
+        assert len(response.content) > 0, "Content should not be empty"
+
+        # Check for web search tool use
+        has_web_search = False
+        has_search_results = False
+        has_citations = False
+        search_query = None
+
+        for block in response.content:
+            if hasattr(block, "type"):
+                # Check for server_tool_use with web_search
+                if block.type == "server_tool_use" and hasattr(block, "name") and block.name == "web_search":
+                    has_web_search = True
+                    if hasattr(block, "input") and "query" in block.input:
+                        search_query = block.input["query"]
+                        print(f"✓ Found web search with query: {search_query}")
+
+                # Check for web_search_tool_result
+                elif block.type == "web_search_tool_result":
+                    has_search_results = True
+                    if hasattr(block, "content") and block.content:
+                        result_count = len(block.content)
+                        print(f"✓ Found {result_count} search results")
+
+                        # Log first few results
+                        for i, result in enumerate(block.content[:3]):
+                            if hasattr(result, "url") and hasattr(result, "title"):
+                                print(f"  Result {i+1}: {result.title}")
+
+                # Check for text with citations
+                elif block.type == "text":
+                    if hasattr(block, "citations") and block.citations:
+                        has_citations = True
+                        citation_count = len(block.citations)
+                        print(f"✓ Found {citation_count} citations in response")
+
+                        # Validate citation structure
+                        for citation in block.citations[:3]:
+                            assert hasattr(citation, "type"), "Citation should have type"
+                            assert hasattr(citation, "url"), "Citation should have URL"
+                            assert hasattr(citation, "title"), "Citation should have title"
+                            assert hasattr(citation, "cited_text"), "Citation should have cited_text"
+                            print(f"  Citation: {citation.title}")
+
+        # Validate that web search was performed
+        assert has_web_search, "Response should contain web_search tool use"
+        assert has_search_results, "Response should contain web search results"
+        assert search_query is not None, "Web search should have a query"
+
+        print(f"✓ Web search (non-streaming) test passed!")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search"))
+    def test_39_web_search_streaming(self, anthropic_client, test_config, provider, model):
+        """Test Case 39: Web search tool (streaming)"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search (Streaming) for provider {provider} ===")
+
+        # Create web search tool with user location
+        web_search_tool = {
+            "type": "web_search_20250305",
+            "name": "web_search",
+            "max_uses": 5,
+            "user_location": {
+                "type": "approximate",
+                "city": "New York",
+                "region": "New York",
+                "country": "US",
+                "timezone": "America/New_York"
+            }
+        }
+
+        messages = [
+            {
+                "role": "user",
+                "content": "What was a positive news story from today?"
+            }
+        ]
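+
+        # Expected event flow for this stream, inferred from the handlers below
+        # (see also test_44's event-order checks):
+        #   content_block_start (server_tool_use: web_search)
+        #   content_block_start (web_search_tool_result with result blocks)
+        #   content_block_delta (text_delta / citations_delta interleaved)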
+
+        stream = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            messages=messages,
+            tools=[web_search_tool],
+            max_tokens=2048,
+            stream=True
+        )
+
+        # Collect streaming events
+        text_parts = []
+        search_results = []
+        citations = []
+        chunk_count = 0
+        has_server_tool_use = False
+        has_search_tool_result = False
+
+        for event in stream:
+            chunk_count += 1
+
+            if hasattr(event, "type"):
+                event_type = event.type
+
+                # Handle content_block_start for tool use
+                if event_type == "content_block_start":
+                    if hasattr(event, "content_block") and event.content_block:
+                        block = event.content_block
+
+                        # Check for server_tool_use
+                        if hasattr(block, "type") and block.type == "server_tool_use":
+                            if hasattr(block, "name") and block.name == "web_search":
+                                has_server_tool_use = True
+                                print(f"✓ Web search tool use started (block id: {block.id if hasattr(block, 'id') else 'unknown'})")
+
+                        # Check for web_search_tool_result
+                        elif hasattr(block, "type") and block.type == "web_search_tool_result":
+                            has_search_tool_result = True
+                            if hasattr(block, "content") and block.content:
+                                result_count = len(block.content)
+                                print(f"✓ Received {result_count} search results")
+
+                                # Collect search results
+                                for result in block.content:
+                                    if hasattr(result, "url") and hasattr(result, "title"):
+                                        search_results.append({
+                                            "url": result.url,
+                                            "title": result.title
+                                        })
+
+                # Handle content_block_delta for text and citations
+                elif event_type == "content_block_delta":
+                    if hasattr(event, "delta") and event.delta:
+                        delta = event.delta
+
+                        # Check for text_delta
+                        if hasattr(delta, "type") and delta.type == "text_delta":
+                            if hasattr(delta, "text"):
+                                text_parts.append(delta.text)
+
+                        # Check for citations_delta
+                        elif hasattr(delta, "type") and delta.type == "citations_delta":
+                            if hasattr(delta, "citation"):
+                                citation = delta.citation
+                                citations.append(citation)
+
+                                if hasattr(citation, "title"):
+                                    print(f"  Received citation: {citation.title}")
+
+            # Safety check
+            if chunk_count > 5000:
+                break
+
+        # Combine collected content
+        complete_text = "".join(text_parts)
+
+        # Validate results
+        assert chunk_count > 0, "Should receive at least one chunk"
+        assert has_server_tool_use, "Should detect web search tool use in streaming"
+        assert has_search_tool_result, "Should receive search results in streaming"
+        assert len(search_results) > 0, "Should collect search results from stream"
+        assert len(complete_text) > 0, "Should receive text content about the news story"
+
+        print("✓ Streaming validation:")
+        print(f"  - Chunks received: {chunk_count}")
+        print(f"  - Search results: {len(search_results)}")
+        print(f"  - Citations: {len(citations)}")
+        print(f"  - Text length: {len(complete_text)} characters")
+        print(f"  - First 150 chars: {complete_text[:150]}...")
+
+        # Log a few search results
+        if len(search_results) > 0:
+            print("✓ Search results:")
+            for i, result in enumerate(search_results[:3]):
+                print(f"  {i+1}. 
{result['title']}") + + print("✓ Web search (streaming) test passed!") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search")) + def test_40_web_search_allowed_domains(self, anthropic_client, test_config, provider, model): + """Test Case 40: Web search with allowed_domains filter""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search with Allowed Domains for provider {provider} ===") + + # Create web search tool with allowed domains + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "allowed_domains": ["en.wikipedia.org", "britannica.com"], + "max_uses": 5 + } + + messages = [ + { + "role": "user", + "content": "Who was Albert Einstein? Please search for this information." + } + ] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + tools=[web_search_tool], + max_tokens=2048 + ) + + # Validate basic response + assert response is not None, "Response should not be None" + assert hasattr(response, "content"), "Response should have content" + assert len(response.content) > 0, "Content should not be empty" + + # Collect search results + search_results = [] + for block in response.content: + if hasattr(block, "type") and block.type == "web_search_tool_result": + if hasattr(block, "content") and block.content: + for result in block.content: + if hasattr(result, "url") and hasattr(result, "title"): + search_results.append(result) + print(f"✓ Found result: {result.title} - {result.url}") + + # Validate domain filtering + from .utils.common import validate_domain_filter + if len(search_results) > 0: + validate_domain_filter(search_results, allowed=["wikipedia.org", "britannica.com"]) + print(f"✓ All {len(search_results)} results respect allowed_domains filter") + + print(f"✓ Web search with allowed_domains test passed!") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search")) + def test_41_web_search_blocked_domains(self, anthropic_client, test_config, provider, model): + """Test Case 41: Web search with blocked_domains filter""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + # skip for openai + if provider == "openai": + pytest.skip("OpenAI does not support blocked_domains filter") + + print(f"\n=== Testing Web Search with Blocked Domains for provider {provider} ===") + + # Create web search tool with blocked domains + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "blocked_domains": ["reddit.com", "twitter.com", "x.com"], + "max_uses": 5 + } + + messages = [ + { + "role": "user", + "content": "What are recent developments in artificial intelligence?" 
+ } + ] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + tools=[web_search_tool], + max_tokens=2048 + ) + + # Validate basic response + assert response is not None, "Response should not be None" + assert hasattr(response, "content"), "Response should have content" + + # Collect search results + search_results = [] + for block in response.content: + if hasattr(block, "type") and block.type == "web_search_tool_result": + if hasattr(block, "content") and block.content: + for result in block.content: + if hasattr(result, "url"): + search_results.append(result) + print(f"✓ Found result: {result.url}") + + # Validate domain filtering + from .utils.common import validate_domain_filter + if len(search_results) > 0: + validate_domain_filter(search_results, blocked=["reddit.com", "twitter.com", "x.com"]) + print(f"✓ All {len(search_results)} results respect blocked_domains filter") + + print(f"✓ Web search with blocked_domains test passed!") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search")) + def test_42_web_search_multi_turn(self, anthropic_client, test_config, provider, model): + """Test Case 42: Web search in multi-turn conversation""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search Multi-Turn Conversation for provider {provider} ===") + + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5 + } + + # First turn: Ask about a topic + messages = [ + { + "role": "user", + "content": "What is quantum computing?" + } + ] + + response1 = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + tools=[web_search_tool], + max_tokens=2048 + ) + + assert response1 is not None, "First response should not be None" + print(f"✓ First turn completed") + + # Add assistant response to conversation + messages.append({ + "role": "assistant", + "content": serialize_anthropic_content(response1.content) + }) + + # Second turn: Follow-up question + messages.append({ + "role": "user", + "content": "How is it different from classical computing?" 
+ }) + + response2 = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + tools=[web_search_tool], + max_tokens=2048 + ) + + assert response2 is not None, "Second response should not be None" + assert hasattr(response2, "content"), "Second response should have content" + assert len(response2.content) > 0, "Second response content should not be empty" + + # Validate that context was maintained + has_text_response = False + for block in response2.content: + if hasattr(block, "type") and block.type == "text": + if hasattr(block, "text") and len(block.text) > 0: + has_text_response = True + print(f"✓ Second turn response (first 150 chars): {block.text[:150]}...") + + assert has_text_response, "Second turn should have text response" + print(f"✓ Multi-turn web search conversation test passed!") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search")) + def test_43_web_search_citation_validation(self, anthropic_client, test_config, provider, model): + """Test Case 43: Validate web search citation structure""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search Citation Validation for provider {provider} ===") + + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5 + } + + messages = [ + { + "role": "user", + "content": "What is the capital of France?" + } + ] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + tools=[web_search_tool], + max_tokens=2048 + ) + + # Find citations in response + citations_found = [] + for block in response.content: + if hasattr(block, "type") and block.type == "text": + if hasattr(block, "citations") and block.citations: + for citation in block.citations: + citations_found.append(citation) + + # Validate citation structure + from .utils.common import assert_valid_web_search_citation + if len(citations_found) > 0: + print(f"✓ Found {len(citations_found)} citations") + for i, citation in enumerate(citations_found[:3]): + assert_valid_web_search_citation(citation, sdk_type="anthropic") + print(f" Citation {i+1}: {citation.title}") + print(f" URL: {citation.url}") + print(f" Cited text (first 50 chars): {citation.cited_text[:50] if citation.cited_text else 'N/A'}...") + print(f"✓ All citations have valid structure") + else: + print(f"⚠ No citations found (may be acceptable)") + + print(f"✓ Citation validation test passed!") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search")) + def test_44_web_search_streaming_event_order(self, anthropic_client, test_config, provider, model): + """Test Case 44: Validate web search streaming event sequence""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search Streaming Event Order for provider {provider} ===") + + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 3 + } + + messages = [ + { + "role": "user", + "content": "What is the Eiffel Tower?" 
+            }
+        ]
+
+        stream = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            messages=messages,
+            tools=[web_search_tool],
+            max_tokens=2048,
+            stream=True
+        )
+
+        # Track event sequence
+        event_sequence = []
+
+        for event in stream:
+            if hasattr(event, "type"):
+                event_type = event.type
+                event_sequence.append(event_type)
+
+                # Log key events
+                if event_type == "content_block_start":
+                    if hasattr(event, "content_block"):
+                        block_type = getattr(event.content_block, "type", "unknown")
+                        print(f"✓ Event: content_block_start ({block_type})")
+                elif event_type == "content_block_stop":
+                    print(f"✓ Event: content_block_stop")
+                elif event_type == "content_block_delta":
+                    if hasattr(event, "delta") and hasattr(event.delta, "type"):
+                        delta_type = event.delta.type
+                        if delta_type == "input_json_delta":
+                            print(f"✓ Event: content_block_delta (input_json_delta)")
+
+        # Validate expected event types are present
+        assert "message_start" in event_sequence, "Should have message_start event"
+        assert "content_block_start" in event_sequence, "Should have content_block_start events"
+        assert "content_block_stop" in event_sequence, "Should have content_block_stop events"
+        assert "message_stop" in event_sequence, "Should have message_stop event"
+
+        print(f"✓ Received {len(event_sequence)} total events")
+        print(f"✓ Event sequence validation passed!")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search"))
+    def test_45_web_search_with_prompt_caching(self, anthropic_client, test_config, provider, model):
+        """Test Case 45: Web search with prompt caching"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search with Prompt Caching for provider {provider} ===")
+
+        web_search_tool = {
+            "type": "web_search_20250305",
+            "name": "web_search",
+            "max_uses": 3
+        }
+
+        # First request; the cache breakpoint is added on the follow-up turn below
+        messages = [
+            {
+                "role": "user",
+                "content": "What is the current population of Tokyo?"
+            }
+        ]
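+
+        # Cache behavior is observed through usage counters on each response:
+        # cache_creation_input_tokens when a prefix is written, and
+        # cache_read_input_tokens when a later request reuses it.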
+
+        response1 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            messages=messages,
+            tools=[web_search_tool],
+            max_tokens=1500
+        )
+
+        assert response1 is not None, "First response should not be None"
+
+        # Check if cache was written
+        if hasattr(response1, "usage"):
+            cache_write_tokens = getattr(response1.usage, "cache_creation_input_tokens", 0)
+            print(f"✓ First request - cache_creation_input_tokens: {cache_write_tokens}")
+
+        # Add assistant response with cache control
+        messages.append({
+            "role": "assistant",
+            "content": serialize_anthropic_content(response1.content)
+        })
+
+        messages.append({
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What about its GDP?",
+                    "cache_control": {"type": "ephemeral"}
+                }
+            ]
+        })
+
+        # Second request should benefit from caching
+        response2 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            messages=messages,
+            tools=[web_search_tool],
+            max_tokens=1500
+        )
+
+        assert response2 is not None, "Second response should not be None"
+
+        # Check if cache was read
+        if hasattr(response2, "usage"):
+            cache_read_tokens = getattr(response2.usage, "cache_read_input_tokens", 0)
+            print(f"✓ Second request - cache_read_input_tokens: {cache_read_tokens}")
+
+            if cache_read_tokens > 0:
+                print(f"✓ Successfully read {cache_read_tokens} tokens from cache")
+
+        print(f"✓ Prompt caching test passed!")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search"))
+    def test_47_web_search_error_handling(self, anthropic_client, test_config, provider, model):
+        """Test Case 47: Web search error code handling"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search Error Handling for provider {provider} ===")
+
+        web_search_tool = {
+            "type": "web_search_20250305",
+            "name": "web_search",
+            "max_uses": 5
+        }
+
+        # Build a repetitive query that might trigger a query_too_long error,
+        # capped so the request body itself stays within reasonable limits
+        very_long_query = "What is " + ("the meaning of life and the universe " * 50)
+
+        messages = [
+            {
+                "role": "user",
+                "content": very_long_query[:1000]
+            }
+        ]
+
+        try:
+            response = anthropic_client.messages.create(
+                model=format_provider_model(provider, model),
+                messages=messages,
+                tools=[web_search_tool],
+                max_tokens=2048
+            )
+
+            # Check response structure
+            assert response is not None, "Response should not be None"
+            assert hasattr(response, "content"), "Response should have content"
+
+            # Look for any error structures in the response
+            has_error = False
+            for block in response.content:
+                if hasattr(block, "type") and block.type == "web_search_tool_result":
+                    if hasattr(block, "content") and isinstance(block.content, dict):
+                        if "error_code" in block.content:
+                            has_error = True
+                            error_code = block.content["error_code"]
+                            print(f"✓ Found error code: {error_code}")
+
+            if not has_error:
+                print(f"✓ Request handled successfully (no errors triggered)")
+
+        except Exception as e:
+            # Some errors might be raised as exceptions
+            print(f"✓ Exception caught (expected for error scenarios): {type(e).__name__}")
+
+        print(f"✓ Error handling test passed!")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search"))
+    def test_48_web_search_no_results_graceful(self, anthropic_client, test_config, provider, model):
+        """Test Case 48: Web search with query that may return no results"""
+        
if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search No Results Handling for provider {provider} ===") + + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 3 + } + + # Use a very specific/nonsensical query + messages = [ + { + "role": "user", + "content": "Find information about xyzabc123nonexistent456topic789" + } + ] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + tools=[web_search_tool], + max_tokens=2048 + ) + + # Validate graceful handling + assert response is not None, "Response should not be None" + assert hasattr(response, "content"), "Response should have content" + assert len(response.content) > 0, "Content should not be empty" + + # Check for search attempt + has_search_attempt = False + has_response_text = False + + for block in response.content: + if hasattr(block, "type"): + if block.type == "server_tool_use" and hasattr(block, "name") and block.name == "web_search": + has_search_attempt = True + print(f"✓ Web search was attempted") + elif block.type == "text" and hasattr(block, "text"): + has_response_text = True + print(f"✓ Response text present (first 100 chars): {block.text[:100]}...") + + assert has_search_attempt, "Should attempt web search" + assert has_response_text, "Should provide text response even with no/few results" + + print(f"✓ No results graceful handling test passed!") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("web_search")) + def test_49_web_search_sources_validation(self, anthropic_client, test_config, provider, model): + """Test Case 49: Comprehensive web search sources validation""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search Sources Validation for provider {provider} ===") + + web_search_tool = { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5 + } + + messages = [ + { + "role": "user", + "content": "What are the main programming languages used for web development?" 
+ } + ] + + response = anthropic_client.messages.create( + model=format_provider_model(provider, model), + messages=messages, + tools=[web_search_tool], + max_tokens=2048 + ) + + # Collect all search sources + all_sources = [] + for block in response.content: + if hasattr(block, "type") and block.type == "web_search_tool_result": + if hasattr(block, "content") and block.content: + for result in block.content: + if hasattr(result, "type") and result.type == "web_search_result": + all_sources.append(result) + + # Validate sources using helper + from .utils.common import assert_web_search_sources_valid + if len(all_sources) > 0: + assert_web_search_sources_valid(all_sources) + print(f"✓ Found and validated {len(all_sources)} search sources") + + # Log details of first few sources + for i, source in enumerate(all_sources[:3]): + print(f" Source {i+1}:") + print(f" URL: {source.url}") + print(f" Title: {source.title if hasattr(source, 'title') else 'N/A'}") + if hasattr(source, "page_age"): + print(f" Page age: {source.page_age}") + if hasattr(source, "encrypted_content"): + print(f" Encrypted content: Present") + else: + print(f"⚠ No search sources found (may indicate no search was performed)") + + print(f"✓ Sources validation test passed!") + # Additional helper functions specific to Anthropic def serialize_anthropic_content(content_blocks: List[Any]) -> List[Dict[str, Any]]: diff --git a/tests/integrations/python/tests/test_openai.py b/tests/integrations/python/tests/test_openai.py index 342f1e15a..5b0f9dc8c 100644 --- a/tests/integrations/python/tests/test_openai.py +++ b/tests/integrations/python/tests/test_openai.py @@ -145,6 +145,8 @@ get_provider_voices, mock_tool_response, skip_if_no_api_key, + # Citation utilities + assert_valid_openai_annotation, ) from .utils.config_loader import get_config, get_model from .utils.parametrize import ( @@ -2832,3 +2834,416 @@ def test_51c_input_tokens_long_text(self, provider, model, vk_enabled): f"Long text should have >100 tokens, got {response.input_tokens}" ) + # ========================================================================= + # WEB SEARCH TOOL TEST CASES + # ========================================================================= + + @pytest.mark.parametrize("provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("web_search")) + def test_52_web_search_non_streaming(self, provider, model, vk_enabled): + """Test Case 52: Web search tool (non-streaming) using Responses API""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for web_search scenario") + + print(f"\n=== Testing Web Search (Non-Streaming) for provider {provider} ===") + + client = get_provider_openai_client(provider, vk_enabled=vk_enabled) + + # Use Responses API with web search tool + response = client.responses.create( + model=format_provider_model(provider, model), + tools=[{"type": "web_search"}], + input="What is the current weather in New York City today?", + max_output_tokens=1200, + ) + + # Validate basic response + assert response is not None, "Response should not be None" + assert hasattr(response, "output"), "Response should have output" + assert response.output is not None, "Output should not be None" + assert len(response.output) > 0, "Output should not be empty" + + # Check for web_search_call in output + has_web_search_call = False + has_message_output = False + has_citations = False + search_status = None + output_text = "" + + for output_item in response.output: + # Check for 
web_search_call + if hasattr(output_item, "type") and output_item.type == "web_search_call": + has_web_search_call = True + if hasattr(output_item, "status"): + search_status = output_item.status + print(f"✓ Found web_search_call with status: {search_status}") + + # Check for search action details + if hasattr(output_item, "action"): + action = output_item.action + if hasattr(action, "query"): + print(f"✓ Search query: {action.query}") + if hasattr(action, "sources") and action.sources: + print(f"✓ Found {len(action.sources)} sources") + + # Check for message output with content + elif hasattr(output_item, "type") and output_item.type == "message": + has_message_output = True + if hasattr(output_item, "content") and output_item.content: + for content_block in output_item.content: + if hasattr(content_block, "type") and content_block.type == "output_text": + if hasattr(content_block, "text"): + output_text = content_block.text + print(f"✓ Found text output (first 150 chars): {output_text[:150]}...") + + # Check for annotations (citations) from web search + if hasattr(content_block, "annotations") and content_block.annotations: + has_citations = True + citation_count = len(content_block.annotations) + print(f"✓ Found {citation_count} citations") + + # Validate citation structure using helper + for i, annotation in enumerate(content_block.annotations[:3]): + assert_valid_openai_annotation(annotation, expected_type="url_citation") + if hasattr(annotation, "url"): + print(f" Citation {i+1}: {annotation.url}") + + # Validate web search was performed + assert has_web_search_call, "Response should contain web_search_call" + assert search_status == "completed", f"Web search should be completed, got status: {search_status}" + assert has_message_output, "Response should contain message output" + assert len(output_text) > 0, "Message should have text content" + + # Validate content mentions weather + text_lower = output_text.lower() + weather_keywords = ["weather", "temperature", "forecast", "rain", "snow", "wind", "sunny", "cloudy", "degrees", + "cold", "hot", "warm", "cool", "chilly", "blustery", "storm", "clear", "humid", "dry"] + assert any(keyword in text_lower for keyword in weather_keywords), \ + f"Response should mention weather-related information. Got: {output_text[:300]}..." 
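+
+        # For reference, the output items checked above have roughly this shape
+        # (field names taken from the assertions in this test, not a full schema):
+        #   web_search_call: {"status": "completed", "action": {"query": ..., "sources": [...]}}
+        #   message: {"content": [{"type": "output_text", "text": ..., "annotations": [...]}]}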
+
+        # Validate usage information
+        if hasattr(response, "usage"):
+            print(f"✓ Token usage - Input: {response.usage.input_tokens}, Output: {response.usage.output_tokens}")
+
+        print(f"✓ Web search (non-streaming) test passed!")
+
+    @pytest.mark.parametrize("provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("web_search"))
+    def test_53_web_search_streaming(self, provider, model, vk_enabled):
+        """Test Case 53: Web search tool (streaming) using Responses API"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search (Streaming) for provider {provider} ===")
+
+        client = get_provider_openai_client(provider, vk_enabled=vk_enabled)
+
+        # Use Responses API with web search tool and user location
+        stream = client.responses.create(
+            model=format_provider_model(provider, model),
+            tools=[{
+                "type": "web_search",
+                "user_location": {
+                    "type": "approximate",
+                    "country": "US",
+                    "city": "New York",
+                    "region": "New York",
+                    "timezone": "America/New_York"
+                }
+            }],
+            input="What's the weather in NYC today?",
+            include=["web_search_call.action.sources"],
+            max_output_tokens=1200,
+            stream=True
+        )
+
+        # Collect streaming events
+        text_parts = []
+        chunk_count = 0
+        has_web_search_call = False
+        has_message_output = False
+        citations = []
+        search_queries = []
+
+        for chunk in stream:
+            chunk_count += 1
+
+            if hasattr(chunk, "type"):
+                chunk_type = chunk.type
+
+                # Handle response.output_item.added event
+                if chunk_type == "response.output_item.added":
+                    if hasattr(chunk, "item"):
+                        item = chunk.item
+                        # Check for web_search_call
+                        if hasattr(item, "type") and item.type == "web_search_call":
+                            has_web_search_call = True
+                            print(f"✓ Web search call started (id: {item.id if hasattr(item, 'id') else 'unknown'})")
+
+                        # Check for message output
+                        elif hasattr(item, "type") and item.type == "message":
+                            has_message_output = True
+
+                # Handle response.output_item.done event for completed items
+                elif chunk_type == "response.output_item.done":
+                    if hasattr(chunk, "item"):
+                        item = chunk.item
+
+                        # Check web_search_call completion with action details
+                        if hasattr(item, "type") and item.type == "web_search_call":
+                            if hasattr(item, "action"):
+                                action = item.action
+                                if hasattr(action, "query"):
+                                    search_queries.append(action.query)
+                                    print(f"✓ Search query: {action.query}")
+                                if hasattr(action, "sources") and action.sources:
+                                    print(f"✓ Found {len(action.sources)} sources")
+
+                # Handle response.output_text.delta for streaming text
+                elif chunk_type == "response.output_text.delta":
+                    if hasattr(chunk, "delta"):
+                        text_parts.append(chunk.delta)
+
+                # Handle response.output_text.annotation.added for citations
+                elif chunk_type == "response.output_text.annotation.added":
+                    if hasattr(chunk, "annotation"):
+                        annotation = chunk.annotation
+                        citations.append(annotation)
+
+                        # Validate citation using helper
+                        assert_valid_openai_annotation(annotation, expected_type="url_citation")
+
+                        if hasattr(annotation, "url") and hasattr(annotation, "title"):
+                            print(f"  Citation received: {annotation.title}")
+
+            # Safety check
+            if chunk_count > 5000:
+                break
+
+        # Combine collected text
+        complete_text = "".join(text_parts)
+
+        # Validate results
+        assert chunk_count > 0, "Should receive at least one chunk"
+        assert has_web_search_call, "Should detect web search call in streaming"
+        assert has_message_output, "Should detect message output in streaming"
+        assert len(complete_text) > 0, "Should receive text content"
+
+        # Validate the text mentions weather
+        text_lower = complete_text.lower()
+        weather_keywords = ["weather", "temperature", "forecast", "rain", "snow", "wind", "sunny", "cloudy", "degrees",
+                            "cold", "hot", "warm", "cool", "chilly", "blustery", "storm", "clear", "humid", "dry"]
+        assert any(keyword in text_lower for keyword in weather_keywords), \
+            f"Response should mention weather-related information. Got: {complete_text[:200]}..."
+
+        print(f"✓ Streaming validation:")
+        print(f"  - Chunks received: {chunk_count}")
+        print(f"  - Search queries: {len(search_queries)}")
+        print(f"  - Citations: {len(citations)}")
+        print(f"  - Text length: {len(complete_text)} characters")
+        print(f"  - First 150 chars: {complete_text[:150]}...")
+
+        # Validate all citations using helper
+        if len(citations) > 0:
+            for citation in citations:
+                assert_valid_openai_annotation(citation, expected_type="url_citation")
+
+        print(f"✓ Web search (streaming) test passed!")
+
+    @pytest.mark.parametrize("provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("web_search"))
+    def test_54_web_search_annotation_conversion(self, provider, model, vk_enabled):
+        """Test Case 54: Validate Anthropic citations convert to OpenAI annotations correctly"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search Annotation Conversion for provider {provider} ===")
+
+        client = get_provider_openai_client(provider, vk_enabled=vk_enabled)
+
+        response = client.responses.create(
+            model=format_provider_model(provider, model),
+            tools=[{"type": "web_search"}],
+            input="What is the speed of light in a vacuum? Use the web search tool.",
+            include=["web_search_call.action.sources"],
+            max_output_tokens=1500,
+        )
+
+        # Validate basic response
+        assert response is not None, "Response should not be None"
+        assert hasattr(response, "output"), "Response should have output"
+
+        # Collect and validate annotations
+        annotations_found = []
+        for output_item in response.output:
+            if hasattr(output_item, "type") and output_item.type == "message":
+                if hasattr(output_item, "content") and output_item.content:
+                    for content_block in output_item.content:
+                        if hasattr(content_block, "type") and content_block.type == "output_text":
+                            if hasattr(content_block, "annotations") and content_block.annotations:
+                                for annotation in content_block.annotations:
+                                    annotations_found.append(annotation)
+
+        # Validate annotation structure
+        if len(annotations_found) > 0:
+            print(f"✓ Found {len(annotations_found)} annotations")
+            for i, annotation in enumerate(annotations_found[:3]):
+                assert_valid_openai_annotation(annotation, expected_type="url_citation")
+                print(f"  Annotation {i+1}:")
+                print(f"    Type: {annotation.type}")
+                print(f"    URL: {annotation.url if hasattr(annotation, 'url') else 'N/A'}")
+                if hasattr(annotation, "title"):
+                    print(f"    Title: {annotation.title}")
+                # Check for encrypted_index preservation
+                if hasattr(annotation, "encrypted_index"):
+                    print(f"    Encrypted index present: ✓")
+
+            print(f"✓ All annotations have valid url_citation structure")
+        else:
+            print(f"⚠ No annotations found")
+
+        print(f"✓ Annotation conversion test passed!")
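+
+    # For reference, a converted url_citation annotation as validated above carries
+    # roughly {"type": "url_citation", "url": ..., "title": ..., "start_index": ..., "end_index": ...},
+    # with encrypted_index preserved when the upstream citation included one.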
+
+    @pytest.mark.parametrize("provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("web_search"))
+    def test_55_web_search_user_location(self, provider, model, vk_enabled):
+        """Test Case 55: Web search with user location for localized results"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search with User Location for provider {provider} ===")
+
+        client = get_provider_openai_client(provider, vk_enabled=vk_enabled)
+
+        # Test with specific location
+        response = client.responses.create(
+            model=format_provider_model(provider, model),
+            tools=[{
+                "type": "web_search",
+                "user_location": {
+                    "type": "approximate",
+                    "city": "San Francisco",
+                    "region": "California",
+                    "country": "US",
+                    "timezone": "America/Los_Angeles"
+                }
+            }],
+            input="What is the weather like today?",
+            max_output_tokens=1200,
+        )
+
+        # Validate basic response
+        assert response is not None, "Response should not be None"
+        assert hasattr(response, "output"), "Response should have output"
+        assert len(response.output) > 0, "Output should not be empty"
+
+        # Check for web_search_call with status
+        has_web_search = False
+        has_message = False
+
+        for output_item in response.output:
+            if hasattr(output_item, "type"):
+                if output_item.type == "web_search_call":
+                    has_web_search = True
+                    print(f"✓ Web search executed")
+                elif output_item.type == "message":
+                    has_message = True
+
+        assert has_web_search, "Should perform web search"
+        assert has_message, "Should have message response"
+
+        print(f"✓ User location test passed!")
+
+    @pytest.mark.parametrize("provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("web_search"))
+    def test_56_web_search_wildcard_domains(self, provider, model, vk_enabled):
+        """Test Case 56: Web search with wildcard domain patterns"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search with Wildcard Domains for provider {provider} ===")
+
+        client = get_provider_openai_client(provider, vk_enabled=vk_enabled)
+
+        # Use wildcard domain patterns
+        response = client.responses.create(
+            model=format_provider_model(provider, model),
+            tools=[{
+                "type": "web_search",
+                "allowed_domains": ["wikipedia.org/*", "*.edu"]
+            }],
+            input="What is machine learning? Use the web search tool.",
+            include=["web_search_call.action.sources"],
+            max_output_tokens=1500,
+        )
+
+        # Validate basic response
+        assert response is not None, "Response should not be None"
+        assert hasattr(response, "output"), "Response should have output"
+
+        # Collect search sources
+        search_sources = []
+        for output_item in response.output:
+            if hasattr(output_item, "type") and output_item.type == "web_search_call":
+                if hasattr(output_item, "action") and hasattr(output_item.action, "sources"):
+                    if output_item.action.sources:
+                        search_sources.extend(output_item.action.sources)
+
+        if len(search_sources) > 0:
+            print(f"✓ Found {len(search_sources)} search sources")
+            for i, source in enumerate(search_sources[:3]):
+                if hasattr(source, "url"):
+                    print(f"  Source {i+1}: {source.url}")
+
+        print(f"✓ Wildcard domains test passed!")
+
+    @pytest.mark.parametrize("provider,model,vk_enabled", get_cross_provider_params_with_vk_for_scenario("web_search"))
+    def test_57_web_search_multi_turn_openai(self, provider, model, vk_enabled):
+        """Test Case 57: Web search in multi-turn conversation (OpenAI SDK)"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for web_search scenario")
+
+        print(f"\n=== Testing Web Search Multi-Turn (OpenAI SDK) for provider {provider} ===")
+
+        client = get_provider_openai_client(provider, vk_enabled=vk_enabled)
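+
+        # Multi-turn pattern used below: feed every output item from the first
+        # response (including web_search_call items) back as input, then append
+        # the follow-up user message so the second turn keeps full context.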
{"role": "user", "content": "What is renewable energy use web search tool?"} + ] + + response1 = client.responses.create( + model=format_provider_model(provider, model), + tools=[{"type": "web_search"}], + input=input_messages, + max_output_tokens=1500, + ) + + assert response1 is not None, "First response should not be None" + assert hasattr(response1, "output"), "First response should have output" + + # Collect first turn output for context + print(f"✓ First turn completed with {len(response1.output)} output items") + + # Second turn with follow-up + # Add each output item from the first response + for output_item in response1.output: + input_messages.append(output_item) + input_messages.append({"role": "user", "content": "What are the main types of renewable energy?"}) + + response2 = client.responses.create( + model=format_provider_model(provider, model), + tools=[{"type": "web_search"}], + input=input_messages, + max_output_tokens=1500, + ) + + assert response2 is not None, "Second response should not be None" + assert hasattr(response2, "output"), "Second response should have output" + assert len(response2.output) > 0, "Second response should have content" + + # Validate second turn has message response + has_message = False + for output_item in response2.output: + if hasattr(output_item, "type") and output_item.type == "message": + has_message = True + + assert has_message, "Second turn should have message response" + print(f"✓ Second turn completed with {len(response2.output)} output items") + print(f"✓ Multi-turn conversation test passed!") + diff --git a/tests/integrations/python/tests/utils/common.py b/tests/integrations/python/tests/utils/common.py index e5a3f65f0..d825f6242 100644 --- a/tests/integrations/python/tests/utils/common.py +++ b/tests/integrations/python/tests/utils/common.py @@ -2615,4 +2615,406 @@ def assert_valid_input_tokens_response(response: Any, library: str): assert hasattr(response, "input_tokens"), "Response should have input_tokens attribute" assert isinstance(response.input_tokens, int), "input_tokens should be an integer" assert response.input_tokens > 0, f"input_tokens should be positive, got {response.input_tokens}" + + +# ========================================================================= +# CITATIONS TEST DATA AND UTILITIES +# ========================================================================= + +# Test document content for citations +CITATION_TEXT_DOCUMENT = """The Theory of Relativity was developed by Albert Einstein in the early 20th century. +It consists of two parts: Special Relativity published in 1905, and General Relativity published in 1915. + +Special Relativity deals with objects moving at constant velocities and introduced the famous equation E=mc². +General Relativity extends this to accelerating objects and provides a new understanding of gravity. + +Einstein's work revolutionized our understanding of space, time, and gravity, and its predictions have been +confirmed by numerous experiments and observations over the past century.""" + +# Multiple documents for testing document_index +CITATION_MULTI_DOCUMENT_SET = [ + { + "title": "Physics Document", + "content": """Quantum mechanics is a fundamental theory in physics that describes the behavior of matter and energy at the atomic and subatomic level. 
+It was developed in the early 20th century by physicists including Max Planck, Albert Einstein, Niels Bohr, and Werner Heisenberg.""" + }, + { + "title": "Chemistry Document", + "content": """The periodic table organizes chemical elements by their atomic number, electron configuration, and recurring chemical properties. +It was first published by Dmitri Mendeleev in 1869 and has become a fundamental tool in chemistry.""" + } +] + + +def create_anthropic_document( + content: str, + doc_type: str, + title: str = "Test Document", + citations_enabled: bool = True +) -> Dict[str, Any]: + """ + Create a properly formatted document block for Anthropic API with citations. + + Args: + content: Document content (text or base64) + doc_type: Document type - "text", "pdf", or "base64" + title: Document title + citations_enabled: Whether to enable citations + + Returns: + Formatted document block dict + """ + document = { + "type": "document", + "title": title, + "citations": {"enabled": citations_enabled} + } + + if doc_type == "text": + document["source"] = { + "type": "text", + "media_type": "text/plain", + "data": content + } + elif doc_type == "pdf" or doc_type == "base64": + document["source"] = { + "type": "base64", + "media_type": "application/pdf", + "data": content + } + else: + raise ValueError(f"Unsupported doc_type: {doc_type}. Use 'text', 'pdf', or 'base64'") + + return document + + +def validate_citation_indices(citation: Any, citation_type: str) -> None: + """ + Validate citation indices based on type. + + Args: + citation: Citation object to validate + citation_type: Expected citation type (char_location, page_location) + """ + if citation_type == "char_location": + # Character indices: 0-indexed, exclusive end + assert hasattr(citation, "start_char_index"), "char_location should have start_char_index" + assert hasattr(citation, "end_char_index"), "char_location should have end_char_index" + assert citation.start_char_index >= 0, "start_char_index should be >= 0" + assert citation.end_char_index > citation.start_char_index, \ + f"end_char_index ({citation.end_char_index}) should be > start_char_index ({citation.start_char_index})" + + elif citation_type == "page_location": + # Page numbers: 1-indexed, exclusive end + assert hasattr(citation, "start_page_number"), "page_location should have start_page_number" + assert hasattr(citation, "end_page_number"), "page_location should have end_page_number" + assert citation.start_page_number >= 1, "start_page_number should be >= 1 (1-indexed)" + assert citation.end_page_number > citation.start_page_number, \ + f"end_page_number ({citation.end_page_number}) should be > start_page_number ({citation.start_page_number})" + + +def assert_valid_anthropic_citation( + citation: Any, + expected_type: str, + document_index: int = 0 +) -> None: + """ + Assert that an Anthropic citation is valid and matches expected structure. 
+ + Args: + citation: Citation object from Anthropic response + expected_type: Expected citation type (char_location, page_location) + document_index: Expected document index (0-indexed) + """ + # Check basic structure + assert hasattr(citation, "type"), "Citation should have type field" + assert citation.type == expected_type, \ + f"Citation type should be {expected_type}, got {citation.type}" + + # Check required fields + assert hasattr(citation, "cited_text"), "Citation should have cited_text" + assert isinstance(citation.cited_text, str), "cited_text should be a string" + assert len(citation.cited_text) > 0, "cited_text should not be empty" + + # Check document reference + assert hasattr(citation, "document_index"), "Citation should have document_index" + assert citation.document_index == document_index, \ + f"document_index should be {document_index}, got {citation.document_index}" + + # Check document title (optional but common) + if hasattr(citation, "document_title"): + assert isinstance(citation.document_title, str), "document_title should be a string" + + # Validate type-specific indices + validate_citation_indices(citation, expected_type) + + +def assert_valid_openai_annotation( + annotation: Any, + expected_type: str +) -> None: + """ + Assert that an OpenAI annotation is valid and matches expected structure. + + Args: + annotation: Annotation object from OpenAI Responses API + expected_type: Expected annotation type (file_citation, url_citation, etc.) + """ + if isinstance(annotation, dict): + ann_type = annotation.get("type") + assert ann_type == expected_type, f"Annotation type should be {expected_type}, got {ann_type}" + getter = annotation.get + has = annotation.__contains__ + else: + assert hasattr(annotation, "type"), "Annotation should have type field" + assert annotation.type == expected_type, f"Annotation type should be {expected_type}, got {annotation.type}" + getter = lambda k: getattr(annotation, k, None) + has = lambda k: hasattr(annotation, k) + + # Validate based on type + if expected_type == "file_citation": + if has("file_id"): + assert isinstance(getter("file_id"), str), "file_id should be a string" + if has("filename"): + assert isinstance(getter("filename"), str), "filename should be a string" + if has("index"): + assert isinstance(getter("index"), int), "index should be an integer" + assert getter("index") >= 0, "index should be >= 0" + + elif expected_type == "url_citation": + # url_citation: url, title, start_index, end_index + if has("url"): + assert isinstance(getter("url"), str), "url should be a string" + if has("title"): + assert isinstance(getter("title"), str), "title should be a string" + if has("start_index") and has("end_index"): + assert isinstance(getter("start_index"), int), "start_index should be an integer" + assert isinstance(getter("end_index"), int), "end_index should be an integer" + assert getter("end_index") > getter("start_index"), "end_index should be > start_index" + + + elif expected_type == "container_file_citation": + # container_file_citation: container_id, file_id, filename, start_index, end_index + if has("container_id"): + assert isinstance(getter("container_id"), str), "container_id should be a string" + if has("file_id"): + assert isinstance(getter("file_id"), str), "file_id should be a string" + if has("filename"): + assert isinstance(getter("filename"), str), "filename should be a string" + + + elif expected_type == "file_path": + if has("file_id"): + assert isinstance(getter("file_id"), str), "file_id should be a 
string" + if has("index"): + assert isinstance(getter("index"), int), "index should be an integer" + assert getter("index") >= 0, "index should be >= 0" + + # Check for char_location (Anthropic native type that may come through) + elif expected_type == "char_location": + if has("start_char_index"): + assert isinstance(getter("start_char_index"), int), "start_char_index should be an integer" + if has("end_char_index"): + assert isinstance(getter("end_char_index"), int), "end_char_index should be an integer" + + +def collect_anthropic_streaming_citations( + stream, + timeout: int = 30 +) -> tuple[str, list, int]: + """ + Collect text content and citations from an Anthropic streaming response. + + Args: + stream: Anthropic streaming response iterator + timeout: Maximum time to collect (seconds) + + Returns: + Tuple of (complete_text, citations_list, chunk_count) + """ + import time + start_time = time.time() + + text_parts = [] + citations = [] + chunk_count = 0 + + for event in stream: + chunk_count += 1 + + # Check timeout + if time.time() - start_time > timeout: + break + + if hasattr(event, "type"): + event_type = event.type + + # Handle content_block_delta events + if event_type == "content_block_delta": + if hasattr(event, "delta") and event.delta: + # Check for text delta + if hasattr(event.delta, "type"): + if event.delta.type == "text_delta": + if hasattr(event.delta, "text"): + text_parts.append(str(event.delta.text)) + + # Check for citations delta + elif event.delta.type == "citations_delta": + if hasattr(event.delta, "citation"): + citations.append(event.delta.citation) + + # Safety check + if chunk_count > 2000: + break + + complete_text = "".join(text_parts) + return complete_text, citations, chunk_count + + +def collect_openai_streaming_annotations( + stream, + timeout: int = 30 +) -> tuple[str, list, int]: + """ + Collect text content and annotations from OpenAI Responses API streaming. + + Args: + stream: OpenAI Responses API streaming response iterator + timeout: Maximum time to collect (seconds) + + Returns: + Tuple of (complete_text, annotations_list, chunk_count) + """ + import time + start_time = time.time() + + text_parts = [] + annotations = [] + chunk_count = 0 + + for chunk in stream: + chunk_count += 1 + + # Check timeout + if time.time() - start_time > timeout: + break + + if hasattr(chunk, "type"): + chunk_type = chunk.type + + # Handle text delta + if chunk_type == "response.output_text.delta": + if hasattr(chunk, "delta"): + text_parts.append(chunk.delta) + + # Handle annotation added + elif chunk_type == "response.output_text.annotation.added": + if hasattr(chunk, "annotation"): + annotations.append(chunk.annotation) + + # Safety check + if chunk_count > 5000: + break + complete_text = "".join(text_parts) + return complete_text, annotations, chunk_count + + +# ============================================================================ +# WEB SEARCH VALIDATION HELPERS +# ============================================================================ + +def assert_valid_web_search_citation(citation, sdk_type="anthropic"): + """ + Validate web search citation structure. 
+
+    Args:
+        citation: Citation object to validate
+        sdk_type: Either "anthropic" or "openai"
+    """
+    if sdk_type == "anthropic":
+        assert hasattr(citation, "type"), "Citation should have type"
+        assert citation.type == "web_search_result_location", f"Expected web_search_result_location, got {citation.type}"
+        assert hasattr(citation, "url") and citation.url, "Citation should have non-empty URL"
+        assert hasattr(citation, "title") and citation.title, "Citation should have non-empty title"
+        assert hasattr(citation, "encrypted_index"), "Citation should have encrypted_index"
+        assert hasattr(citation, "cited_text"), "Citation should have cited_text"
+        if citation.cited_text:
+            assert len(citation.cited_text) <= 150, f"cited_text should be <= 150 chars, got {len(citation.cited_text)}"
+    elif sdk_type == "openai":
+        assert hasattr(citation, "type"), "Annotation should have type"
+        assert citation.type == "url_citation", f"Expected url_citation, got {citation.type}"
+        assert hasattr(citation, "url") and citation.url, "Annotation should have non-empty URL"
+        assert hasattr(citation, "title"), "Annotation should have title"
+    else:
+        raise ValueError(f"Unknown sdk_type: {sdk_type}")
+
+
+def assert_web_search_sources_valid(sources):
+    """
+    Validate web search sources structure.
+
+    Args:
+        sources: List of source objects to validate
+    """
+    assert sources is not None, "Sources should not be None"
+    assert len(sources) > 0, "Sources should not be empty"
+
+    for i, source in enumerate(sources):
+        assert hasattr(source, "url"), f"Source {i} should have url"
+        assert source.url, f"Source {i} url should not be empty"
+        assert hasattr(source, "title"), f"Source {i} should have title"
+        # encrypted_content and page_age are optional
+
+
+def extract_domain(url: str) -> str:
+    """
+    Extract domain from URL for validation.
+
+    Args:
+        url: Full URL string
+
+    Returns:
+        Domain string (e.g., "en.wikipedia.org")
+    """
+    from urllib.parse import urlparse
+    parsed = urlparse(url)
+    return parsed.netloc.lower()
+
+
+def validate_domain_filter(sources, allowed=None, blocked=None):
+    """
+    Validate sources respect domain filters.
+
+    Args:
+        sources: List of source objects with url attribute
+        allowed: List of allowed domain patterns (optional)
+        blocked: List of blocked domain patterns (optional)
+    """
+    for source in sources:
+        if not hasattr(source, "url"):
+            continue
+
+        domain = extract_domain(source.url)
+
+        if allowed:
+            # Check if domain matches any allowed pattern, either exactly
+            # or as a subdomain (example.com also matches docs.example.com)
+            matches_allowed = False
+            for allowed_pattern in allowed:
+                if domain == allowed_pattern or domain.endswith('.' + allowed_pattern):
+                    matches_allowed = True
+                    break
+
+            assert matches_allowed, f"Domain {domain} not in allowed domains {allowed}"
+
+        if blocked:
+            # Check if domain matches any blocked pattern
+            for blocked_pattern in blocked:
+                is_blocked = (domain == blocked_pattern or
+                              domain.endswith('.' + blocked_pattern))
+                assert not is_blocked, f"Domain {domain} should not match blocked pattern {blocked_pattern}"
\ No newline at end of file
diff --git a/tests/integrations/typescript/src/utils/common.ts b/tests/integrations/typescript/src/utils/common.ts
index 186e676b3..63a37f58b 100644
--- a/tests/integrations/typescript/src/utils/common.ts
+++ b/tests/integrations/typescript/src/utils/common.ts
@@ -518,7 +518,7 @@ export function assertValidChatResponse(response: unknown): void {
   // OpenAI-style response
   if (obj.choices) {
     expect(Array.isArray(obj.choices)).toBe(true)
-    expect(obj.choices.length).toBeGreaterThan(0)
+    expect((obj.choices as unknown[]).length).toBeGreaterThan(0)
 
     const choice = (obj.choices as Array<Record<string, unknown>>)[0]
     expect(choice.message).toBeDefined()
@@ -947,3 +947,278 @@ export function createBatchInlineRequests(
     },
   }))
 }
+
+// ============================================================================
+// Citations Test Data and Utilities
+// ============================================================================
+
+// Test document content for citations
+export const CITATION_TEXT_DOCUMENT = `The Theory of Relativity was developed by Albert Einstein in the early 20th century.
+It consists of two parts: Special Relativity published in 1905, and General Relativity published in 1915.
+
+Special Relativity deals with objects moving at constant velocities and introduced the famous equation E=mc².
+General Relativity extends this to accelerating objects and provides a new understanding of gravity.
+
+Einstein's work revolutionized our understanding of space, time, and gravity, and its predictions have been
+confirmed by numerous experiments and observations over the past century.`
+
+// Multiple documents for testing document_index
+export const CITATION_MULTI_DOCUMENT_SET = [
+  {
+    title: 'Physics Document',
+    content: `Quantum mechanics is a fundamental theory in physics that describes the behavior of matter and energy at the atomic and subatomic level.
+It was developed in the early 20th century by physicists including Max Planck, Albert Einstein, Niels Bohr, and Werner Heisenberg.`,
+  },
+  {
+    title: 'Chemistry Document',
+    content: `The periodic table organizes chemical elements by their atomic number, electron configuration, and recurring chemical properties.
+It was first published by Dmitri Mendeleev in 1869 and has become a fundamental tool in chemistry.`,
+  },
+]
+
+// Citation types
+export interface CharLocationCitation {
+  type: 'char_location'
+  cited_text: string
+  document_index: number
+  document_title?: string
+  start_char_index: number
+  end_char_index: number
+}
+
+export interface PageLocationCitation {
+  type: 'page_location'
+  cited_text: string
+  document_index: number
+  document_title?: string
+  start_page_number: number
+  end_page_number: number
+}
+
+export interface WebSearchCitation {
+  type: 'web_search_result_location'
+  url: string
+  title: string
+  encrypted_index: string
+  cited_text?: string
+}
+
+export type AnthropicCitation = CharLocationCitation | PageLocationCitation | WebSearchCitation
+
+// Document block interface
+export interface AnthropicDocument {
+  type: 'document'
+  title: string
+  source: {
+    type: 'text' | 'base64'
+    media_type: string
+    data: string
+  }
+  citations: {
+    enabled: boolean
+  }
+}
+
+/**
+ * Create a properly formatted document block for Anthropic API with citations.
+ */ +export function createAnthropicDocument( + content: string, + docType: 'text' | 'pdf' | 'base64', + title: string = 'Test Document', + citationsEnabled: boolean = true +): AnthropicDocument { + const document: AnthropicDocument = { + type: 'document', + title, + source: { + type: docType === 'text' ? 'text' : 'base64', + media_type: docType === 'text' ? 'text/plain' : 'application/pdf', + data: content, + }, + citations: { + enabled: citationsEnabled, + }, + } + + return document +} + +/** + * Validate citation indices based on type. + */ +export function validateCitationIndices( + citation: AnthropicCitation, + citationType: 'char_location' | 'page_location' | 'web_search_result_location' +): void { + if (citationType === 'char_location') { + const charCitation = citation as CharLocationCitation + expect(charCitation.start_char_index).toBeDefined() + expect(charCitation.end_char_index).toBeDefined() + expect(charCitation.start_char_index).toBeGreaterThanOrEqual(0) + expect(charCitation.end_char_index).toBeGreaterThan(charCitation.start_char_index) + } else if (citationType === 'page_location') { + const pageCitation = citation as PageLocationCitation + expect(pageCitation.start_page_number).toBeDefined() + expect(pageCitation.end_page_number).toBeDefined() + expect(pageCitation.start_page_number).toBeGreaterThanOrEqual(1) + expect(pageCitation.end_page_number).toBeGreaterThan(pageCitation.start_page_number) + } else if (citationType === 'web_search_result_location') { + const webCitation = citation as WebSearchCitation + expect(webCitation.url).toBeDefined() + expect(webCitation.title).toBeDefined() + expect(webCitation.encrypted_index).toBeDefined() + } +} + +/** + * Assert that an Anthropic citation is valid and matches expected structure. + */ +export function assertValidAnthropicCitation( + citation: AnthropicCitation, + expectedType: 'char_location' | 'page_location' | 'web_search_result_location', + documentIndex: number = 0 +): void { + // Check basic structure + expect(citation.type).toBeDefined() + expect(citation.type).toBe(expectedType) + + // Check required fields + expect(citation.cited_text).toBeDefined() + expect(typeof citation.cited_text).toBe('string') + + if (expectedType !== 'web_search_result_location') { + expect(citation.cited_text?.length ?? 0).toBeGreaterThan(0) + + // Check document reference + expect((citation as CharLocationCitation | PageLocationCitation).document_index).toBeDefined() + expect((citation as CharLocationCitation | PageLocationCitation).document_index).toBe(documentIndex) + } + + // Validate type-specific indices + validateCitationIndices(citation, expectedType) +} + +/** + * Collect text content and citations from an Anthropic streaming response. + */ +export async function collectAnthropicStreamingCitations( + stream: AsyncIterable<unknown> +): Promise<{ content: string; citations: AnthropicCitation[]; chunkCount: number }> { + let content = '' + const citations: AnthropicCitation[] = [] + let chunkCount = 0 + + for await (const event of stream) { + chunkCount++ + const eventObj = event as Record<string, unknown> + + if (eventObj.type === 'content_block_delta') { + const delta = eventObj.delta as Record<string, unknown> + + if (delta.type === 'text_delta' && delta.text) { + content += String(delta.text) + } else if (delta.type === 'citations_delta' && delta.citation) { + citations.push(delta.citation as AnthropicCitation) + } + } + } + + return { content, citations, chunkCount } +} + +/** + * Validate web search citation structure. 
+ */ +export function assertValidWebSearchCitation( + citation: unknown, + sdkType: 'anthropic' | 'openai' = 'anthropic' +): void { + const citationObj = citation as Record<string, unknown> + + if (sdkType === 'anthropic') { + expect(citationObj.type).toBeDefined() + expect(citationObj.type).toBe('web_search_result_location') + expect(citationObj.url).toBeDefined() + expect(typeof citationObj.url).toBe('string') + expect(citationObj.title).toBeDefined() + expect(typeof citationObj.title).toBe('string') + expect(citationObj.encrypted_index).toBeDefined() + + if (citationObj.cited_text) { + expect(typeof citationObj.cited_text).toBe('string') + expect((citationObj.cited_text as string).length).toBeLessThanOrEqual(150) + } + } else { + // OpenAI format (url_citation) + expect(citationObj.type).toBeDefined() + expect(citationObj.type).toBe('url_citation') + expect(citationObj.url).toBeDefined() + expect(typeof citationObj.url).toBe('string') + } +} + + +/** + * Assert that an OpenAI annotation is valid and matches expected structure. + */ +export function assertValidOpenAIAnnotation( + annotation: unknown, + expectedType: 'file_citation' | 'url_citation' | 'container_file_citation' | 'file_path' | 'char_location' = 'url_citation' +): void { + const annotationObj = annotation as Record<string, unknown> + + expect(annotationObj.type).toBeDefined() + expect(annotationObj.type).toBe(expectedType) + + // Validate based on type + if (expectedType === 'file_citation') { + if (annotationObj.file_id) { + expect(typeof annotationObj.file_id).toBe('string') + } + if (annotationObj.filename) { + expect(typeof annotationObj.filename).toBe('string') + } + if (annotationObj.index !== undefined) { + expect(typeof annotationObj.index).toBe('number') + expect(annotationObj.index as number).toBeGreaterThanOrEqual(0) + } + } else if (expectedType === 'url_citation') { + if (annotationObj.url) { + expect(typeof annotationObj.url).toBe('string') + } + if (annotationObj.title) { + expect(typeof annotationObj.title).toBe('string') + } + if (annotationObj.start_index !== undefined && annotationObj.end_index !== undefined) { + expect(typeof annotationObj.start_index).toBe('number') + expect(typeof annotationObj.end_index).toBe('number') + expect(annotationObj.end_index as number).toBeGreaterThan(annotationObj.start_index as number) + } + } else if (expectedType === 'container_file_citation') { + if (annotationObj.container_id) { + expect(typeof annotationObj.container_id).toBe('string') + } + if (annotationObj.file_id) { + expect(typeof annotationObj.file_id).toBe('string') + } + if (annotationObj.filename) { + expect(typeof annotationObj.filename).toBe('string') + } + } else if (expectedType === 'file_path') { + if (annotationObj.file_id) { + expect(typeof annotationObj.file_id).toBe('string') + } + if (annotationObj.index !== undefined) { + expect(typeof annotationObj.index).toBe('number') + expect(annotationObj.index as number).toBeGreaterThanOrEqual(0) + } + } else if (expectedType === 'char_location') { + if (annotationObj.start_char_index !== undefined) { + expect(typeof annotationObj.start_char_index).toBe('number') + } + if (annotationObj.end_char_index !== undefined) { + expect(typeof annotationObj.end_char_index).toBe('number') + } + } +} diff --git a/tests/integrations/typescript/tests/test-anthropic.test.ts b/tests/integrations/typescript/tests/test-anthropic.test.ts index 611b4e00c..d601ebd6b 100644 --- a/tests/integrations/typescript/tests/test-anthropic.test.ts +++ b/tests/integrations/typescript/tests/test-anthropic.test.ts @@ -67,8 
+67,14 @@ import { } from '../src/utils/config-loader' import { + assertValidAnthropicCitation, BASE64_IMAGE, CALCULATOR_TOOL, + CITATION_MULTI_DOCUMENT_SET, + CITATION_TEXT_DOCUMENT, + collectAnthropicStreamingCitations, + createAnthropicDocument, + FILE_DATA_BASE64, getApiKey, hasApiKey, IMAGE_URL, @@ -79,6 +85,7 @@ import { SINGLE_TOOL_CALL_MESSAGES, STREAMING_CHAT_MESSAGES, WEATHER_TOOL, + type AnthropicCitation, type ChatMessage, type ToolDefinition, } from '../src/utils/common' @@ -1625,4 +1632,563 @@ describe('Anthropic SDK Integration Tests', () => { } }) }) + + // ============================================================================ + // Citations Tests + // ============================================================================ + + describe('Citations - PDF Document', () => { + it('should return PDF citations with page_location', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'file') + + console.log('\n=== Testing PDF Citations (page_location) ===') + + // Create PDF document using helper + const document = createAnthropicDocument( + FILE_DATA_BASE64, + 'pdf', + 'Test PDF Document' + ) + + try { + const response = await client.messages.create({ + model, + max_tokens: 500, + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'What does this PDF document say? Please cite your sources.', + }, + document as never, + ], + }, + ], + } as never) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + expect(response.content.length).toBeGreaterThan(0) + + // Check for citations + let hasCitations = false + let citationCount = 0 + + for (const block of response.content) { + if ((block as unknown as { citations?: AnthropicCitation[] }).citations) { + hasCitations = true + const citations = (block as unknown as { citations: AnthropicCitation[] }).citations + + for (const citation of citations) { + citationCount++ + assertValidAnthropicCitation(citation, 'page_location', 0) + + const pageCitation = citation as { start_page_number: number; end_page_number: number; cited_text: string } + console.log( + `✓ Citation ${citationCount}: pages ${pageCitation.start_page_number}-${pageCitation.end_page_number}, ` + + `text: '${pageCitation.cited_text.substring(0, 50)}...'` + ) + } + } + } + + expect(hasCitations).toBe(true) + console.log(`✓ PDF citations test passed - Found ${citationCount} citations`) + } catch (error) { + console.log(`⚠️ PDF citations test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Citations - Text Document', () => { + it('should return text citations with char_location', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'file') + + console.log('\n=== Testing Text Citations (char_location) ===') + + // Create text document using helper + const document = createAnthropicDocument( + CITATION_TEXT_DOCUMENT, + 'text', + 'Theory of Relativity Overview' + ) + + try { + const response = await client.messages.create({ + model, + max_tokens: 500, + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'When was General Relativity published and what does it deal with? 
Please cite your sources.', + }, + document as never, + ], + }, + ], + } as never) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + expect(response.content.length).toBeGreaterThan(0) + + // Check for citations + let hasCitations = false + let citationCount = 0 + + for (const block of response.content) { + if ((block as unknown as { citations?: AnthropicCitation[] }).citations) { + hasCitations = true + const citations = (block as unknown as { citations: AnthropicCitation[] }).citations + + for (const citation of citations) { + citationCount++ + assertValidAnthropicCitation(citation, 'char_location', 0) + + const charCitation = citation as { start_char_index: number; end_char_index: number; cited_text: string } + console.log( + `✓ Citation ${citationCount}: chars ${charCitation.start_char_index}-${charCitation.end_char_index}, ` + + `text: '${charCitation.cited_text.substring(0, 50)}...'` + ) + } + } + } + + expect(hasCitations).toBe(true) + console.log(`✓ Text citations test passed - Found ${citationCount} citations`) + } catch (error) { + console.log(`⚠️ Text citations test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Citations - Multi Document', () => { + it('should return citations from multiple documents with document_index validation', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'file') + + console.log('\n=== Testing Multi-Document Citations ===') + + // Create multiple documents using helper + const documents = CITATION_MULTI_DOCUMENT_SET.map((docInfo) => + createAnthropicDocument(docInfo.content, 'text', docInfo.title) + ) + + try { + const response = await client.messages.create({ + model, + max_tokens: 600, + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'Summarize what each document says. 
Please cite your sources from each document.', + }, + ...(documents as never[]), + ], + }, + ], + } as never) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + expect(response.content.length).toBeGreaterThan(0) + + // Check for citations from multiple documents + let hasCitations = false + const citationsByDoc: Record<number, number> = { 0: 0, 1: 0 } + let totalCitations = 0 + + for (const block of response.content) { + if ((block as unknown as { citations?: AnthropicCitation[] }).citations) { + hasCitations = true + const citations = (block as unknown as { citations: AnthropicCitation[] }).citations + + for (const citation of citations) { + totalCitations++ + const docIdx = (citation as { document_index: number }).document_index || 0 + + // Validate citation + assertValidAnthropicCitation(citation, 'char_location', docIdx) + + // Track which document this citation is from + if (docIdx in citationsByDoc) { + citationsByDoc[docIdx]++ + } + + const charCitation = citation as { document_index: number; document_title?: string; start_char_index: number; end_char_index: number; cited_text: string } + const docTitle = charCitation.document_title || 'Unknown' + console.log( + `✓ Citation from doc[${docIdx}] (${docTitle}): ` + + `chars ${charCitation.start_char_index}-${charCitation.end_char_index}, ` + + `text: '${charCitation.cited_text.substring(0, 40)}...'` + ) + } + } + } + + expect(hasCitations).toBe(true) + + // Report statistics + console.log(`\n✓ Multi-document citations test passed:`) + console.log(` - Total citations: ${totalCitations}`) + for (const [docIdx, count] of Object.entries(citationsByDoc)) { + const docTitle = CITATION_MULTI_DOCUMENT_SET[Number(docIdx)].title + console.log(` - Document ${docIdx} (${docTitle}): ${count} citations`) + } + } catch (error) { + console.log(`⚠️ Multi-document citations test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Citations - Streaming Text', () => { + it('should stream text citations with citations_delta', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'file') + + console.log('\n=== Testing Streaming Citations (char_location) ===') + + // Create text document using helper + const document = createAnthropicDocument( + CITATION_TEXT_DOCUMENT, + 'text', + 'Theory of Relativity Overview' + ) + + try { + const stream = client.messages.stream({ + model, + max_tokens: 500, + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'Explain the key concepts from this document. 
Please cite your sources.', + }, + document as never, + ], + }, + ], + } as never) + + // Collect streaming content and citations using helper + const { content, citations, chunkCount } = await collectAnthropicStreamingCitations(stream) + + // Validate results + expect(chunkCount).toBeGreaterThan(0) + expect(content.length).toBeGreaterThan(0) + expect(citations.length).toBeGreaterThan(0) + + // Validate each citation + citations.forEach((citation, idx) => { + assertValidAnthropicCitation(citation, 'char_location', 0) + + const charCitation = citation as { start_char_index: number; end_char_index: number; cited_text: string } + console.log( + `✓ Citation ${idx + 1}: chars ${charCitation.start_char_index}-${charCitation.end_char_index}, ` + + `text: '${charCitation.cited_text.substring(0, 50)}...'` + ) + }) + + console.log(`✓ Streaming citations test passed - ${citations.length} citations in ${chunkCount} chunks`) + } catch (error) { + console.log(`⚠️ Streaming citations test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Citations - Streaming PDF', () => { + it('should stream PDF citations with page_location', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'file') + + console.log('\n=== Testing Streaming PDF Citations (page_location) ===') + + // Create PDF document using helper + const document = createAnthropicDocument( + FILE_DATA_BASE64, + 'pdf', + 'Test PDF Document' + ) + + try { + const stream = client.messages.stream({ + model, + max_tokens: 500, + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'What does this PDF say? Please cite your sources.', + }, + document as never, + ], + }, + ], + } as never) + + // Collect streaming content and citations using helper + const { content, citations, chunkCount } = await collectAnthropicStreamingCitations(stream) + + // Validate results + expect(chunkCount).toBeGreaterThan(0) + expect(content.length).toBeGreaterThan(0) + expect(citations.length).toBeGreaterThan(0) + + // Validate each citation - should be page_location for PDF + citations.forEach((citation, idx) => { + assertValidAnthropicCitation(citation, 'page_location', 0) + + const pageCitation = citation as { start_page_number: number; end_page_number: number; cited_text: string } + console.log( + `✓ Citation ${idx + 1}: pages ${pageCitation.start_page_number}-${pageCitation.end_page_number}, ` + + `text: '${pageCitation.cited_text.substring(0, 50)}...'` + ) + }) + + console.log(`✓ Streaming PDF citations test passed - ${citations.length} citations in ${chunkCount} chunks`) + } catch (error) { + console.log(`⚠️ Streaming PDF citations test skipped: ${error instanceof Error ? 
error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Web Search Tests + // ============================================================================ + + describe('Web Search - Non Streaming', () => { + it('should perform web search and return citations', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + console.log('\n=== Testing Web Search (Non-Streaming) ===') + + // Create web search tool + const webSearchTool = { + type: 'web_search_20250305', + name: 'web_search', + max_uses: 5, + } + + try { + const response = await client.messages.create({ + model, + max_tokens: 2048, + messages: [ + { + role: 'user', + content: 'What is a positive news story from today?', + }, + ], + tools: [webSearchTool] as never[], + } as never) + + // Validate basic response + expect(response).toBeDefined() + expect(response.content).toBeDefined() + expect(response.content.length).toBeGreaterThan(0) + + // Check for web search tool use + let hasWebSearch = false + let hasSearchResults = false + let hasCitations = false + let searchQuery: string | null = null + + for (const block of response.content) { + const blockObj = block as unknown as Record<string, unknown> + + if (blockObj.type === 'server_tool_use' && (blockObj as { name?: string }).name === 'web_search') { + hasWebSearch = true + const input = (blockObj as { input?: Record<string, unknown> }).input + if (input && input.query) { + searchQuery = String(input.query) + console.log(`✓ Found web search with query: ${searchQuery}`) + } + } else if (blockObj.type === 'web_search_tool_result') { + hasSearchResults = true + const content = (blockObj as { content?: unknown[] }).content + if (content && Array.isArray(content)) { + console.log(`✓ Found ${content.length} search results`) + + // Log first few results + content.slice(0, 3).forEach((result, i) => { + const resultObj = result as { url?: string; title?: string } + if (resultObj.url && resultObj.title) { + console.log(` Result ${i + 1}: ${resultObj.title}`) + } + }) + } + } else if (blockObj.type === 'text') { + const citations = (blockObj as { citations?: unknown[] }).citations + if (citations && citations.length > 0) { + hasCitations = true + console.log(`✓ Found ${citations.length} citations in response`) + + // Validate citation structure + citations.slice(0, 3).forEach((citation) => { + const citationObj = citation as Record<string, unknown> + expect(citationObj.type).toBeDefined() + expect(citationObj.url).toBeDefined() + expect(citationObj.title).toBeDefined() + expect(citationObj.cited_text).toBeDefined() + console.log(` Citation: ${citationObj.title}`) + }) + } + } + } + + // Validate that web search was performed + expect(hasWebSearch).toBe(true) + expect(hasSearchResults).toBe(true) + expect(searchQuery).not.toBeNull() + + console.log('✓ Web search (non-streaming) test passed!') + } catch (error) { + console.log(`⚠️ Web search non-streaming test skipped: ${error instanceof Error ? 
error.message : 'Unknown error'}`) + } + }) + }) + + describe('Web Search - Streaming', () => { + it('should stream web search results', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + console.log('\n=== Testing Web Search (Streaming) ===') + + // Create web search tool with user location + const webSearchTool = { + type: 'web_search_20250305', + name: 'web_search', + max_uses: 5, + user_location: { + type: 'approximate', + city: 'New York', + region: 'New York', + country: 'US', + timezone: 'America/New_York', + }, + } + + try { + const stream = client.messages.stream({ + model, + max_tokens: 2048, + messages: [ + { + role: 'user', + content: 'What are the latest advancements in renewable energy?', + }, + ], + tools: [webSearchTool] as never[], + } as never) + + let hasWebSearch = false + let hasSearchResults = false + let hasTextContent = false + let chunkCount = 0 + let searchQuery: string | null = null + const searchResults: unknown[] = [] + + for await (const event of stream) { + chunkCount++ + const eventObj = event as unknown as Record<string, unknown> + + // Check for web search tool use in content block start + if (eventObj.type === 'content_block_start') { + const contentBlock = eventObj.content_block as Record<string, unknown> + if (contentBlock.type === 'server_tool_use' && (contentBlock as { name?: string }).name === 'web_search') { + hasWebSearch = true + console.log('✓ Web search tool invoked') + } + } + + // Check for web search input delta + if (eventObj.type === 'content_block_delta') { + const delta = eventObj.delta as Record<string, unknown> + if (delta.type === 'input_json_delta' && delta.partial_json) { + try { + const parsed = JSON.parse(String(delta.partial_json)) + if (parsed.query && !searchQuery) { + searchQuery = parsed.query + console.log(`✓ Search query: ${searchQuery}`) + } + } catch { + // Partial JSON may not be complete yet + } + } + } + + // Check for web search results + if (eventObj.type === 'content_block_start') { + const contentBlock = eventObj.content_block as Record<string, unknown> + if (contentBlock.type === 'web_search_tool_result') { + hasSearchResults = true + const content = (contentBlock as { content?: unknown[] }).content + if (content && Array.isArray(content)) { + searchResults.push(...content) + } + } + } + + // Check for text content delta + if (eventObj.type === 'content_block_delta') { + const delta = eventObj.delta as Record<string, unknown> + if (delta.type === 'text_delta') { + hasTextContent = true + } + } + } + + // Validate results + expect(chunkCount).toBeGreaterThan(0) + expect(hasWebSearch).toBe(true) + expect(hasSearchResults).toBe(true) + expect(hasTextContent).toBe(true) + + if (searchResults.length > 0) { + console.log(`✓ Received ${searchResults.length} search results`) + searchResults.slice(0, 3).forEach((result, i) => { + const resultObj = result as { url?: string; title?: string } + if (resultObj.url && resultObj.title) { + console.log(` Result ${i + 1}: ${resultObj.title}`) + } + }) + } + + console.log(`✓ Web search (streaming) test passed! (${chunkCount} chunks)`) + } catch (error) { + console.log(`⚠️ Web search streaming test skipped: ${error instanceof Error ? 
error.message : 'Unknown error'}`) + } + }) + }) }) diff --git a/tests/integrations/typescript/tests/test-openai.test.ts b/tests/integrations/typescript/tests/test-openai.test.ts index 72bd7caba..6decfee42 100644 --- a/tests/integrations/typescript/tests/test-openai.test.ts +++ b/tests/integrations/typescript/tests/test-openai.test.ts @@ -72,6 +72,7 @@ import { } from '../src/utils/config-loader' import { + assertValidOpenAIAnnotation, CALCULATOR_TOOL, EMBEDDINGS_MULTIPLE_TEXTS, EMBEDDINGS_SIMILAR_TEXTS, @@ -1773,4 +1774,269 @@ describe('OpenAI SDK Integration Tests', () => { } ) }) + + // ============================================================================ + // Web Search Tests (Responses API) + // ============================================================================ + + describe('Web Search - Annotation Conversion', () => { + const testCases = getCrossProviderParamsWithVkForScenario('web_search') + + it.each(testCases)( + 'should convert citations to annotations - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for web_search') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + console.log(`\n=== Testing Web Search Annotation Conversion for provider ${provider} ===`) + + try { + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise<unknown> } }).responses + + const response = await responses.create({ + model: formatProviderModel(provider, model), + tools: [{ type: 'web_search' }], + input: 'What is quantum computing? Use the web search tool.', + max_output_tokens: 1200, + }) as { output?: Array<{ type?: string; content?: Array<{ type?: string; text?: string; annotations?: unknown[] }> }> } + + // Validate basic response + expect(response).toBeDefined() + expect(response.output).toBeDefined() + expect(response.output!.length).toBeGreaterThan(0) + + // Check for annotations in message content + let hasAnnotations = false + const annotations: unknown[] = [] + + for (const outputItem of response.output || []) { + if (outputItem.type === 'message' && outputItem.content) { + for (const contentItem of outputItem.content) { + if (contentItem.type === 'text' && contentItem.annotations) { + hasAnnotations = true + annotations.push(...contentItem.annotations) + } + } + } + } + + if (hasAnnotations) { + console.log(`✓ Found ${annotations.length} annotations`) + + // Validate annotation structure + annotations.slice(0, 3).forEach((annotation) => { + assertValidOpenAIAnnotation(annotation, 'url_citation') + const annotationObj = annotation as { url?: string; title?: string } + if (annotationObj.title) { + console.log(` Annotation: ${annotationObj.title}`) + } + }) + } else { + console.log('⚠ No annotations found') + } + + console.log('✓ Annotation conversion test passed!') + } catch (error) { + console.log(`⚠️ Web search annotation conversion test skipped: ${error instanceof Error ? 
error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Web Search - User Location', () => { + const testCases = getCrossProviderParamsWithVkForScenario('web_search') + + it.each(testCases)( + 'should use user location for localized results - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for web_search') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + console.log(`\n=== Testing Web Search with User Location for provider ${provider} ===`) + + try { + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise<unknown> } }).responses + + const response = await responses.create({ + model: formatProviderModel(provider, model), + tools: [{ + type: 'web_search', + user_location: { + type: 'approximate', + city: 'San Francisco', + region: 'California', + country: 'US', + timezone: 'America/Los_Angeles', + }, + }], + input: 'What is the weather like today?', + max_output_tokens: 1200, + }) as { output?: Array<{ type?: string }> } + + // Validate basic response + expect(response).toBeDefined() + expect(response.output).toBeDefined() + expect(response.output!.length).toBeGreaterThan(0) + + // Check for web_search_call with status + let hasWebSearch = false + let hasMessage = false + + for (const outputItem of response.output || []) { + if (outputItem.type === 'web_search_call') { + hasWebSearch = true + console.log('✓ Web search executed') + } else if (outputItem.type === 'message') { + hasMessage = true + } + } + + expect(hasWebSearch).toBe(true) + expect(hasMessage).toBe(true) + + console.log('✓ User location test passed!') + } catch (error) { + console.log(`⚠️ Web search user location test skipped: ${error instanceof Error ? 
error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Web Search - Wildcard Domains', () => { + const testCases = getCrossProviderParamsWithVkForScenario('web_search') + + it.each(testCases)( + 'should filter results with wildcard domain patterns - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for web_search') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + console.log(`\n=== Testing Web Search with Wildcard Domains for provider ${provider} ===`) + + try { + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise<unknown> } }).responses + + const response = await responses.create({ + model: formatProviderModel(provider, model), + tools: [{ + type: 'web_search', + allowed_domains: ['wikipedia.org/*', '*.edu'], + }], + input: 'What is machine learning? Use the web search tool.', + include: ['web_search_call.action.sources'], + max_output_tokens: 1500, + }) as { output?: Array<{ type?: string; action?: { sources?: unknown[] } }> } + + // Validate basic response + expect(response).toBeDefined() + expect(response.output).toBeDefined() + + // Collect search sources + const searchSources: unknown[] = [] + for (const outputItem of response.output || []) { + if (outputItem.type === 'web_search_call' && outputItem.action?.sources) { + searchSources.push(...outputItem.action.sources) + } + } + + if (searchSources.length > 0) { + console.log(`✓ Found ${searchSources.length} search sources`) + searchSources.slice(0, 3).forEach((source, i) => { + const sourceObj = source as { url?: string } + if (sourceObj.url) { + console.log(` Source ${i + 1}: ${sourceObj.url}`) + } + }) + } + + console.log('✓ Wildcard domains test passed!') + } catch (error) { + console.log(`⚠️ Web search wildcard domains test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Web Search - Multi Turn', () => { + const testCases = getCrossProviderParamsWithVkForScenario('web_search') + + it.each(testCases)( + 'should handle multi-turn conversation with web search - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for web_search') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + console.log(`\n=== Testing Web Search Multi-Turn (OpenAI SDK) for provider ${provider} ===`) + + try { + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise<unknown> } }).responses + + // First turn + const inputMessages: unknown[] = [ + { role: 'user', content: 'What is renewable energy? Use the web search tool.' }, + ] + + const response1 = await responses.create({ + model: formatProviderModel(provider, model), + tools: [{ type: 'web_search' }], + input: inputMessages, + max_output_tokens: 1500, + }) as { output?: unknown[] } + + expect(response1).toBeDefined() + expect(response1.output).toBeDefined() + + console.log(`✓ First turn completed with ${response1.output!.length} output items`) + + // Second turn with follow-up + // Add each output item from the first response + for (const outputItem of response1.output || []) { + inputMessages.push(outputItem) + } + inputMessages.push({ role: 'user', content: 'What are the main types of renewable energy?' 
}) + + const response2 = await responses.create({ + model: formatProviderModel(provider, model), + tools: [{ type: 'web_search' }], + input: inputMessages, + max_output_tokens: 1500, + }) as { output?: Array<{ type?: string }> } + + expect(response2).toBeDefined() + expect(response2.output).toBeDefined() + expect(response2.output!.length).toBeGreaterThan(0) + + // Validate second turn has message response + let hasMessage = false + for (const outputItem of response2.output || []) { + if (outputItem.type === 'message') { + hasMessage = true + } + } + + expect(hasMessage).toBe(true) + console.log(`✓ Second turn completed with ${response2.output!.length} output items`) + console.log('✓ Multi-turn conversation test passed!') + } catch (error) { + console.log(`⚠️ Web search multi-turn test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) }) diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index ba706f7c2..ce71a1b20 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -80,7 +80,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string) []RouteConfig { return resp.ExtraFields.RawResponse, nil } } - return anthropic.ToAnthropicResponsesResponse(resp), nil + return anthropic.ToAnthropicResponsesResponse(ctx, resp), nil }, ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err)