diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index 58a0ef7d5..f8558118c 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -40,6 +40,7 @@ type TestScenarios struct { TranscriptionStream bool // Streaming speech-to-text functionality Embedding bool // Embedding functionality Reasoning bool // Reasoning/thinking functionality via Responses API + PromptCaching bool // Prompt caching functionality ListModels bool // List available models functionality BatchCreate bool // Batch API create functionality BatchList bool // Batch API list functionality @@ -605,6 +606,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageBase64: true, MultipleImages: true, CompleteEnd2End: true, + PromptCaching: true, SpeechSynthesis: false, // Not supported SpeechSynthesisStream: false, // Not supported Transcription: false, // Not supported @@ -638,6 +640,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageBase64: true, MultipleImages: true, CompleteEnd2End: true, + PromptCaching: true, SpeechSynthesis: false, // Not supported SpeechSynthesisStream: false, // Not supported Transcription: false, // Not supported diff --git a/core/internal/testutil/prompt_caching.go b/core/internal/testutil/prompt_caching.go index 733ef0e3d..00e02097c 100644 --- a/core/internal/testutil/prompt_caching.go +++ b/core/internal/testutil/prompt_caching.go @@ -262,6 +262,9 @@ func GetPromptCachingTools() []schemas.ChatTool { Required: []string{}, }, }, + CacheControl: &schemas.CacheControl{ + Type: schemas.CacheControlTypeEphemeral, + }, }, } } @@ -271,16 +274,14 @@ func GetPromptCachingTools() []schemas.ChatTool { // by making multiple requests with the same long prefix and tools, and verifying // that cached tokens increase in subsequent requests. 
func RunPromptCachingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { - // Only run for OpenAI provider as prompt caching is OpenAI-specific - if testConfig.Provider != schemas.OpenAI { - t.Logf("Prompt caching test skipped for provider %s (OpenAI-specific feature)", testConfig.Provider) - return - } - if !testConfig.Scenarios.SimpleChat { t.Logf("Prompt caching test requires SimpleChat support") return } + if !testConfig.Scenarios.PromptCaching { + t.Logf("Prompt caching test not supported for provider %s", testConfig.Provider) + return + } t.Run("PromptCaching", func(t *testing.T) { if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { @@ -291,7 +292,15 @@ func RunPromptCachingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con systemMessage := schemas.ChatMessage{ Role: schemas.ChatMessageRoleSystem, Content: &schemas.ChatMessageContent{ - ContentStr: bifrost.Ptr(longSharedPrefix), + ContentBlocks: []schemas.ChatContentBlock{ + { + Type: schemas.ChatContentBlockTypeText, + Text: bifrost.Ptr(longSharedPrefix), + CacheControl: &schemas.CacheControl{ + Type: schemas.CacheControlTypeEphemeral, + }, + }, + }, }, } diff --git a/core/internal/testutil/test_retry_framework.go b/core/internal/testutil/test_retry_framework.go index 2d4d5750b..58ce9e3d0 100644 --- a/core/internal/testutil/test_retry_framework.go +++ b/core/internal/testutil/test_retry_framework.go @@ -829,7 +829,9 @@ func StreamingRetryConfig() TestRetryConfig { }, OnRetry: func(attempt int, reason string, t *testing.T) { // reason already contains ❌ prefix from retry logic - t.Logf("🔄 Retrying streaming test (attempt %d): %s", attempt, reason) + // attempt represents the current failed attempt number + // Log with attempt+1 to show the next attempt that will run + t.Logf("🔄 Retrying streaming test (attempt %d): %s", attempt+1, reason) }, OnFinalFail: func(attempts int, finalErr error, t *testing.T) { // finalErr already contains ❌ prefix from retry logic @@ -2148,6 +2150,13 @@ func WithResponsesStreamValidationRetry( for attempt := 1; attempt <= config.MaxAttempts; attempt++ { context.AttemptNumber = attempt + // Log attempt start (especially for retries) + if attempt > 1 { + t.Logf("🔄 Starting responses stream retry attempt %d/%d for %s", attempt, config.MaxAttempts, context.ScenarioName) + } else { + t.Logf("🔄 Starting responses stream test attempt %d/%d for %s", attempt, config.MaxAttempts, context.ScenarioName) + } + // Execute the operation to get the stream responseChannel, err := operation() @@ -2193,13 +2202,20 @@ func WithResponsesStreamValidationRetry( } if shouldRetry { + // Log the error and upcoming retry + if attempt > 1 { + t.Logf("❌ Responses stream request failed on attempt %d/%d for %s: %s", attempt, config.MaxAttempts, context.ScenarioName, retryReason) + } + if config.OnRetry != nil { + // Pass current failed attempt number config.OnRetry(attempt, retryReason, t) } else { t.Logf("🔄 Retrying responses stream request (attempt %d/%d) for %s: %s", attempt+1, config.MaxAttempts, context.ScenarioName, retryReason) } delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + t.Logf("⏳ Waiting %v before retry...", delay) time.Sleep(delay) continue } @@ -2225,12 +2241,17 @@ func WithResponsesStreamValidationRetry( if responseChannel == nil { if attempt < config.MaxAttempts { retryReason := "❌ response channel is nil" + t.Logf("❌ Responses stream response channel is nil on attempt %d/%d for %s", attempt, config.MaxAttempts, context.ScenarioName) 
if config.OnRetry != nil { + // Pass current failed attempt number config.OnRetry(attempt, retryReason, t) + } else { + t.Logf("🔄 Retrying responses stream request (attempt %d/%d) for %s: %s", attempt+1, config.MaxAttempts, context.ScenarioName, retryReason) } delay := calculateRetryDelay(attempt-1, config.BaseDelay, config.MaxDelay) + t.Logf("⏳ Waiting %v before retry...", delay) time.Sleep(delay) - continue + continue // CRITICAL: Must continue to retry, not return } return ResponsesStreamValidationResult{ Passed: false, @@ -2293,6 +2314,7 @@ func WithResponsesStreamValidationRetry( } if config.OnRetry != nil { + // Pass current failed attempt number config.OnRetry(attempt, retryReason, t) } else { t.Logf("🔄 Retrying responses stream validation (attempt %d/%d) for %s: %s", attempt+1, config.MaxAttempts, context.ScenarioName, retryReason) diff --git a/core/providers/anthropic/anthropic_test.go b/core/providers/anthropic/anthropic_test.go index 615434ff7..e1822dd0f 100644 --- a/core/providers/anthropic/anthropic_test.go +++ b/core/providers/anthropic/anthropic_test.go @@ -29,8 +29,9 @@ func TestAnthropic(t *testing.T) { {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, {Provider: schemas.Anthropic, Model: "claude-sonnet-4-20250514"}, }, - VisionModel: "claude-3-7-sonnet-20250219", // Same model supports vision - ReasoningModel: "claude-opus-4-5", + VisionModel: "claude-3-7-sonnet-20250219", // Same model supports vision + ReasoningModel: "claude-opus-4-5", + PromptCachingModel: "claude-sonnet-4-20250514", Scenarios: testutil.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, @@ -47,6 +48,7 @@ func TestAnthropic(t *testing.T) { CompleteEnd2End: true, Embedding: false, Reasoning: true, + PromptCaching: true, ListModels: true, BatchCreate: true, BatchList: true, diff --git a/core/providers/anthropic/chat.go b/core/providers/anthropic/chat.go index 6567fbfc7..b0a19cdd2 100644 --- a/core/providers/anthropic/chat.go +++ b/core/providers/anthropic/chat.go @@ -62,6 +62,10 @@ func ToAnthropicChatRequest(bifrostReq *schemas.BifrostChatRequest) (*AnthropicM } } + if tool.CacheControl != nil { + anthropicTool.CacheControl = tool.CacheControl + } + tools = append(tools, anthropicTool) } anthropicReq.Tools = tools @@ -143,8 +147,9 @@ func ToAnthropicChatRequest(bifrostReq *schemas.BifrostChatRequest) (*AnthropicM for _, block := range msg.Content.ContentBlocks { if block.Text != nil { blocks = append(blocks, AnthropicContentBlock{ - Type: "text", - Text: block.Text, + Type: AnthropicContentBlockTypeText, + Text: block.Text, + CacheControl: block.CacheControl, }) } } @@ -177,8 +182,9 @@ func ToAnthropicChatRequest(bifrostReq *schemas.BifrostChatRequest) (*AnthropicM for _, block := range toolMsg.Content.ContentBlocks { if block.Text != nil { blocks = append(blocks, AnthropicContentBlock{ - Type: "text", - Text: block.Text, + Type: AnthropicContentBlockTypeText, + Text: block.Text, + CacheControl: block.CacheControl, }) } else if block.ImageURLStruct != nil { blocks = append(blocks, ConvertToAnthropicImageBlock(block)) @@ -233,8 +239,9 @@ func ToAnthropicChatRequest(bifrostReq *schemas.BifrostChatRequest) (*AnthropicM for _, block := range msg.Content.ContentBlocks { if block.Text != nil { content = append(content, AnthropicContentBlock{ - Type: AnthropicContentBlockTypeText, - Text: block.Text, + Type: AnthropicContentBlockTypeText, + Text: block.Text, + CacheControl: block.CacheControl, }) } else if block.ImageURLStruct != nil { content = append(content, 
ConvertToAnthropicImageBlock(block)) @@ -310,63 +317,54 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse() *schemas.Bifro // Process content and tool calls if response.Content != nil { - if len(response.Content) == 1 && response.Content[0].Type == AnthropicContentBlockTypeText { - contentStr = response.Content[0].Text - } else { - for _, c := range response.Content { - switch c.Type { - case AnthropicContentBlockTypeText: - if c.Text != nil { - contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeText, - Text: c.Text, - }) + for _, c := range response.Content { + switch c.Type { + case AnthropicContentBlockTypeText: + if c.Text != nil { + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: c.Text, + }) + } + case AnthropicContentBlockTypeToolUse: + if c.ID != nil && c.Name != nil { + function := schemas.ChatAssistantMessageToolCallFunction{ + Name: c.Name, } - case AnthropicContentBlockTypeToolUse: - if c.ID != nil && c.Name != nil { - function := schemas.ChatAssistantMessageToolCallFunction{ - Name: c.Name, - } - // Marshal the input to JSON string - if c.Input != nil { - args, err := json.Marshal(c.Input) - if err != nil { - function.Arguments = fmt.Sprintf("%v", c.Input) - } else { - function.Arguments = string(args) - } + // Marshal the input to JSON string + if c.Input != nil { + args, err := json.Marshal(c.Input) + if err != nil { + function.Arguments = fmt.Sprintf("%v", c.Input) } else { - function.Arguments = "{}" + function.Arguments = string(args) } - - toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ - Index: uint16(len(toolCalls)), - Type: schemas.Ptr(string(schemas.ChatToolTypeFunction)), - ID: c.ID, - Function: function, - }) + } else { + function.Arguments = "{}" } - case AnthropicContentBlockTypeThinking: - reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{ - Index: len(reasoningDetails), - Type: schemas.BifrostReasoningDetailsTypeText, - Text: c.Thinking, - Signature: c.Signature, + + toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: schemas.Ptr(string(schemas.ChatToolTypeFunction)), + ID: c.ID, + Function: function, }) - if c.Thinking != nil { - reasoningText += *c.Thinking + "\n" - } + } + case AnthropicContentBlockTypeThinking: + reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{ + Index: len(reasoningDetails), + Type: schemas.BifrostReasoningDetailsTypeText, + Text: c.Thinking, + Signature: c.Signature, + }) + if c.Thinking != nil { + reasoningText += *c.Thinking + "\n" } } } } - if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText { - contentStr = contentBlocks[0].Text - contentBlocks = nil - } - // Create a single choice with the collected content // Create message content messageContent := schemas.ChatMessageContent{ diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index ec6516819..6b20e19e4 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -1340,36 +1340,8 @@ func (request *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx context.Co // Convert messages directly to ChatMessage format var bifrostMessages []schemas.ResponsesMessage - // Handle system message - convert Anthropic system field to first message with role "system" - if request.System != nil { - var systemText string - if 
request.System.ContentStr != nil { - systemText = *request.System.ContentStr - } else if request.System.ContentBlocks != nil { - // Combine text blocks from system content - var textParts []string - for _, block := range request.System.ContentBlocks { - if block.Text != nil { - textParts = append(textParts, *block.Text) - } - } - systemText = strings.Join(textParts, "\n") - } - - if systemText != "" { - systemMsg := schemas.ResponsesMessage{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), - Content: &schemas.ResponsesMessageContent{ - ContentStr: &systemText, - }, - } - bifrostMessages = append(bifrostMessages, systemMsg) - } - } - // Convert regular messages using the new conversion method - convertedMessages := ConvertAnthropicMessagesToBifrostMessages(request.Messages, nil, false, provider == schemas.Bedrock) + convertedMessages := ConvertAnthropicMessagesToBifrostMessages(request.Messages, request.System, false, provider == schemas.Bedrock) bifrostMessages = append(bifrostMessages, convertedMessages...) // Convert tools if present @@ -2091,30 +2063,37 @@ func ConvertBifrostMessagesToAnthropicMessages(bifrostMessages []schemas.Respons // Helper function to convert Anthropic system content to Bifrost messages func convertAnthropicSystemToBifrostMessages(systemContent *AnthropicContent) []schemas.ResponsesMessage { - var systemText string - if systemContent.ContentStr != nil { - systemText = *systemContent.ContentStr + var bifrostMessages []schemas.ResponsesMessage + + if systemContent.ContentStr != nil && *systemContent.ContentStr != "" { + bifrostMessages = append(bifrostMessages, schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentStr: systemContent.ContentStr, + }, + }) } else if systemContent.ContentBlocks != nil { - // Combine text blocks from system content - var textParts []string + contentBlocks := []schemas.ResponsesMessageContentBlock{} for _, block := range systemContent.ContentBlocks { - if block.Text != nil { - textParts = append(textParts, *block.Text) + if block.Text != nil { // System messages will only have text content + contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: block.Text, + CacheControl: block.CacheControl, + }) } } - systemText = strings.Join(textParts, "\n") + if len(contentBlocks) > 0 { + bifrostMessages = append(bifrostMessages, schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: contentBlocks, + }, + }) + } } - if systemText != "" { - return []schemas.ResponsesMessage{{ - Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), - Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), - Content: &schemas.ResponsesMessageContent{ - ContentStr: &systemText, - }, - }} - } - return []schemas.ResponsesMessage{} + return bifrostMessages } // Helper function to convert a single Anthropic message to Bifrost messages @@ -2192,7 +2171,13 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: role, Content: &schemas.ResponsesMessageContent{ - ContentStr: block.Text, + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: block.Text, + CacheControl: 
block.CacheControl, + }, + }, }, } bifrostMessages = append(bifrostMessages, bifrostMsg) @@ -2275,8 +2260,9 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant blockType = schemas.ResponsesInputMessageContentBlockTypeText } toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ - Type: blockType, - Text: contentBlock.Text, + Type: blockType, + Text: contentBlock.Text, + CacheControl: contentBlock.CacheControl, }) } case AnthropicContentBlockTypeImage: @@ -2388,8 +2374,9 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC Content: &schemas.ResponsesMessageContent{ ContentBlocks: []schemas.ResponsesMessageContentBlock{ { - Type: schemas.ResponsesOutputMessageContentTypeText, - Text: block.Text, + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: block.Text, + CacheControl: block.CacheControl, }, }, }, @@ -2400,7 +2387,13 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: role, Content: &schemas.ResponsesMessageContent{ - ContentStr: block.Text, + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: block.Text, + CacheControl: block.CacheControl, + }, + }, }, } } @@ -2499,8 +2492,9 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC blockType = schemas.ResponsesInputMessageContentBlockTypeText } toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ - Type: blockType, - Text: contentBlock.Text, + Type: blockType, + Text: contentBlock.Text, + CacheControl: contentBlock.CacheControl, }) } case AnthropicContentBlockTypeImage: @@ -2563,8 +2557,9 @@ func convertAnthropicContentBlocksToResponsesMessages(contentBlocks []AnthropicC blockType = schemas.ResponsesInputMessageContentBlockTypeText } toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ - Type: blockType, - Text: contentBlock.Text, + Type: blockType, + Text: contentBlock.Text, + CacheControl: contentBlock.CacheControl, }) } } @@ -3084,6 +3079,10 @@ func convertAnthropicToolToBifrost(tool *AnthropicTool) *schemas.ResponsesTool { } } + if tool.CacheControl != nil { + bifrostTool.CacheControl = tool.CacheControl + } + return bifrostTool } @@ -3260,6 +3259,10 @@ func convertBifrostToolToAnthropic(tool *schemas.ResponsesTool) *AnthropicTool { anthropicTool.InputSchema = tool.ResponsesToolFunction.Parameters } + if tool.CacheControl != nil { + anthropicTool.CacheControl = tool.CacheControl + } + return anthropicTool } @@ -3324,8 +3327,9 @@ func convertContentBlockToAnthropic(block schemas.ResponsesMessageContentBlock) case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: if block.Text != nil { return &AnthropicContentBlock{ - Type: AnthropicContentBlockTypeText, - Text: block.Text, + Type: AnthropicContentBlockTypeText, + Text: block.Text, + CacheControl: block.CacheControl, } } case schemas.ResponsesInputMessageContentBlockTypeImage: @@ -3336,6 +3340,7 @@ func convertContentBlockToAnthropic(block schemas.ResponsesMessageContentBlock) ImageURLStruct: &schemas.ChatInputImage{ URL: *block.ResponsesInputMessageContentBlockImage.ImageURL, }, + CacheControl: block.CacheControl, } anthropicBlock := ConvertToAnthropicImageBlock(chatBlock) return &anthropicBlock @@ -3374,6 +3379,7 @@ func (block AnthropicContentBlock) 
toBifrostResponsesImageBlock() schemas.Respon ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ ImageURL: schemas.Ptr(getImageURLFromBlock(block)), }, + CacheControl: block.CacheControl, } } diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 1011cc0fc..a59878e46 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -215,18 +215,19 @@ const ( // AnthropicContentBlock represents content in Anthropic message format type AnthropicContentBlock struct { - Type AnthropicContentBlockType `json:"type"` // "text", "image", "tool_use", "tool_result", "thinking" - Text *string `json:"text,omitempty"` // For text content - Thinking *string `json:"thinking,omitempty"` // For thinking content - Signature *string `json:"signature,omitempty"` // For signature content - Data *string `json:"data,omitempty"` // For data content (encrypted data for redacted thinking, signature does not come with this) - ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content - ID *string `json:"id,omitempty"` // For tool_use content - Name *string `json:"name,omitempty"` // For tool_use content - Input any `json:"input,omitempty"` // For tool_use content - ServerName *string `json:"server_name,omitempty"` // For mcp_tool_use content - Content *AnthropicContent `json:"content,omitempty"` // For tool_result content - Source *AnthropicImageSource `json:"source,omitempty"` // For image content + Type AnthropicContentBlockType `json:"type"` // "text", "image", "tool_use", "tool_result", "thinking" + Text *string `json:"text,omitempty"` // For text content + Thinking *string `json:"thinking,omitempty"` // For thinking content + Signature *string `json:"signature,omitempty"` // For signature content + Data *string `json:"data,omitempty"` // For data content (encrypted data for redacted thinking, signature does not come with this) + ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content + ID *string `json:"id,omitempty"` // For tool_use content + Name *string `json:"name,omitempty"` // For tool_use content + Input any `json:"input,omitempty"` // For tool_use content + ServerName *string `json:"server_name,omitempty"` // For mcp_tool_use content + Content *AnthropicContent `json:"content,omitempty"` // For tool_result content + Source *AnthropicImageSource `json:"source,omitempty"` // For image content + CacheControl *schemas.CacheControl `json:"cache_control,omitempty"` // Cache control for this block } // AnthropicImageSource represents image source in Anthropic format @@ -288,10 +289,11 @@ type AnthropicToolWebSearch struct { // AnthropicTool represents a tool in Anthropic format type AnthropicTool struct { - Name string `json:"name"` - Type *AnthropicToolType `json:"type,omitempty"` - Description *string `json:"description,omitempty"` - InputSchema *schemas.ToolFunctionParameters `json:"input_schema,omitempty"` + Name string `json:"name"` + Type *AnthropicToolType `json:"type,omitempty"` + Description *string `json:"description,omitempty"` + InputSchema *schemas.ToolFunctionParameters `json:"input_schema,omitempty"` + CacheControl *schemas.CacheControl `json:"cache_control,omitempty"` *AnthropicToolComputerUse *AnthropicToolWebSearch diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index ee116c9b3..f7925d4c9 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -95,8 +95,9 @@ func 
ConvertBifrostFinishReasonToAnthropic(bifrostReason string) AnthropicStopRe // Uses the same pattern as the original buildAnthropicImageSourceMap function func ConvertToAnthropicImageBlock(block schemas.ChatContentBlock) AnthropicContentBlock { imageBlock := AnthropicContentBlock{ - Type: "image", - Source: &AnthropicImageSource{}, + Type: AnthropicContentBlockTypeImage, + CacheControl: block.CacheControl, + Source: &AnthropicImageSource{}, } if block.ImageURLStruct == nil { diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index d8cc3d798..474a19bc7 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -70,10 +70,11 @@ func TestBedrock(t *testing.T) { {Provider: schemas.Bedrock, Model: "claude-4-sonnet"}, {Provider: schemas.Bedrock, Model: "claude-4.5-sonnet"}, }, - EmbeddingModel: "cohere.embed-v4:0", - ReasoningModel: "claude-4.5-sonnet", - BatchExtraParams: batchExtraParams, - FileExtraParams: fileExtraParams, + EmbeddingModel: "cohere.embed-v4:0", + ReasoningModel: "claude-4.5-sonnet", + PromptCachingModel: "claude-4.5-sonnet", + BatchExtraParams: batchExtraParams, + FileExtraParams: fileExtraParams, Scenarios: testutil.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, @@ -91,6 +92,7 @@ func TestBedrock(t *testing.T) { Embedding: true, ListModels: true, Reasoning: true, + PromptCaching: true, BatchCreate: true, BatchList: true, BatchRetrieve: true, diff --git a/core/providers/bedrock/chat.go b/core/providers/bedrock/chat.go index 714a2f357..c279a694b 100644 --- a/core/providers/bedrock/chat.go +++ b/core/providers/bedrock/chat.go @@ -60,87 +60,74 @@ func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Conte var reasoningText string if response.Output.Message != nil { - if len(response.Output.Message.Content) == 1 && response.Output.Message.Content[0].Text != nil { - contentStr = response.Output.Message.Content[0].Text - } else { - // Check if this is a single tool use for structured output (response_format) - // If there's only one tool use and no other content, treat it as structured output - if structuredOutputToolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok { - if len(response.Output.Message.Content) > 0 && response.Output.Message.Content[0].ToolUse != nil && structuredOutputToolName == response.Output.Message.Content[0].ToolUse.Name { - toolUse := response.Output.Message.Content[0].ToolUse - // Marshal the tool input to JSON string and use as content - if toolUse.Input != nil { - if argBytes, err := sonic.Marshal(toolUse.Input); err == nil { + for _, contentBlock := range response.Output.Message.Content { + // Handle text content + if contentBlock.Text != nil && *contentBlock.Text != "" { + chatContentBlock := schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: contentBlock.Text, + } + contentBlocks = append(contentBlocks, chatContentBlock) + } + + if contentBlock.ToolUse != nil { + // Check if this is the structured output tool + if structuredOutputToolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok && contentBlock.ToolUse.Name == structuredOutputToolName { + // This is structured output - set contentStr and skip adding to toolCalls + if contentBlock.ToolUse.Input != nil { + if argBytes, err := sonic.Marshal(contentBlock.ToolUse.Input); err == nil { jsonStr := string(argBytes) contentStr = &jsonStr } else { - jsonStr := fmt.Sprintf("%v", 
toolUse.Input) + jsonStr := fmt.Sprintf("%v", contentBlock.ToolUse.Input) contentStr = &jsonStr } } + continue // Skip adding to toolCalls } - } else { - for _, contentBlock := range response.Output.Message.Content { - // Handle text content - if contentBlock.Text != nil && *contentBlock.Text != "" { - contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeText, - Text: contentBlock.Text, - }) + + // Regular tool call processing + var arguments string + if contentBlock.ToolUse.Input != nil { + if argBytes, err := sonic.Marshal(contentBlock.ToolUse.Input); err == nil { + arguments = string(argBytes) + } else { + arguments = fmt.Sprintf("%v", contentBlock.ToolUse.Input) } + } else { + arguments = "{}" + } - // Handle tool use - if contentBlock.ToolUse != nil { - // Marshal the tool input to JSON string - var arguments string - if contentBlock.ToolUse.Input != nil { - if argBytes, err := sonic.Marshal(contentBlock.ToolUse.Input); err == nil { - arguments = string(argBytes) - } else { - arguments = fmt.Sprintf("%v", contentBlock.ToolUse.Input) - } - } else { - arguments = "{}" - } + toolUseID := contentBlock.ToolUse.ToolUseID + toolUseName := contentBlock.ToolUse.Name - // Create copies of the values to avoid range loop variable capture - toolUseID := contentBlock.ToolUse.ToolUseID - toolUseName := contentBlock.ToolUse.Name - - toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ - Index: uint16(len(toolCalls)), - Type: schemas.Ptr("function"), - ID: &toolUseID, - Function: schemas.ChatAssistantMessageToolCallFunction{ - Name: &toolUseName, - Arguments: arguments, - }, - }) - } + toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: schemas.Ptr("function"), + ID: &toolUseID, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &toolUseName, + Arguments: arguments, + }, + }) + } - // Handle reasoning content - if contentBlock.ReasoningContent != nil { - if contentBlock.ReasoningContent.ReasoningText == nil { - continue - } - reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{ - Index: len(reasoningDetails), - Type: schemas.BifrostReasoningDetailsTypeText, - Text: schemas.Ptr(contentBlock.ReasoningContent.ReasoningText.Text), - Signature: contentBlock.ReasoningContent.ReasoningText.Signature, - }) - reasoningText += contentBlock.ReasoningContent.ReasoningText.Text + "\n" - } + // Handle reasoning content + if contentBlock.ReasoningContent != nil { + if contentBlock.ReasoningContent.ReasoningText == nil { + continue } + reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{ + Index: len(reasoningDetails), + Type: schemas.BifrostReasoningDetailsTypeText, + Text: schemas.Ptr(contentBlock.ReasoningContent.ReasoningText.Text), + Signature: contentBlock.ReasoningContent.ReasoningText.Signature, + }) + reasoningText += contentBlock.ReasoningContent.ReasoningText.Text + "\n" } } } - if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText { - contentStr = contentBlocks[0].Text - contentBlocks = nil - } - // Create the message content messageContent := schemas.ChatMessageContent{ ContentStr: contentStr, diff --git a/core/providers/bedrock/responses.go b/core/providers/bedrock/responses.go index a7cca77d4..bb237b695 100644 --- a/core/providers/bedrock/responses.go +++ b/core/providers/bedrock/responses.go @@ -1280,6 +1280,13 @@ func (request *BedrockConverseRequest) ToBifrostResponsesRequest(ctx *context.Co } 
bifrostReq.Params.Tools = append(bifrostReq.Params.Tools, bifrostTool) + } else if tool.CachePoint != nil && !schemas.IsNovaModel(bifrostReq.Model) { + // add cache control to last tool in tools array + if len(bifrostReq.Params.Tools) > 0 { + bifrostReq.Params.Tools[len(bifrostReq.Params.Tools)-1].CacheControl = &schemas.CacheControl{ + Type: schemas.CacheControlTypeEphemeral, + } + } } } } @@ -1561,6 +1568,14 @@ func ToBedrockResponsesRequest(ctx *context.Context, bifrostReq *schemas.Bifrost }, } bedrockTools = append(bedrockTools, bedrockTool) + + if tool.CacheControl != nil && !schemas.IsNovaModel(bifrostReq.Model) { + bedrockTools = append(bedrockTools, BedrockTool{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } } } @@ -1688,6 +1703,13 @@ func ToBedrockConverseResponse(bifrostResp *schemas.BifrostResponsesResponse) (* bedrockResp.Usage.InputTokens = bifrostResp.Usage.InputTokens bedrockResp.Usage.OutputTokens = bifrostResp.Usage.OutputTokens bedrockResp.Usage.TotalTokens = bifrostResp.Usage.TotalTokens + + if bifrostResp.Usage.InputTokensDetails != nil { + bedrockResp.Usage.CacheReadInputTokens = bifrostResp.Usage.InputTokensDetails.CachedTokens + } + if bifrostResp.Usage.OutputTokensDetails != nil { + bedrockResp.Usage.CacheWriteInputTokens = bifrostResp.Usage.OutputTokensDetails.CachedTokens + } } // Set metrics @@ -2357,8 +2379,8 @@ func ConvertBedrockMessagesToBifrostMessages(ctx *context.Context, bedrockMessag var bifrostMessages []schemas.ResponsesMessage // Convert system messages first - for _, sysMsg := range systemMessages { - systemBifrostMsgs := convertBedrockSystemMessageToBifrostMessages(&sysMsg) + systemBifrostMsgs := convertBedrockSystemMessageToBifrostMessages(systemMessages) + if len(systemBifrostMsgs) > 0 { bifrostMessages = append(bifrostMessages, systemBifrostMsgs...) 
} @@ -2388,6 +2410,13 @@ func convertBifrostMessageToBedrockSystemMessages(msg *schemas.ResponsesMessage) systemMessages = append(systemMessages, BedrockSystemMessage{ Text: block.Text, }) + if block.CacheControl != nil { + systemMessages = append(systemMessages, BedrockSystemMessage{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } } } } @@ -2418,24 +2447,40 @@ func convertBifrostMessageToBedrockMessage(msg *schemas.ResponsesMessage) *Bedro } // convertBedrockSystemMessageToBifrostMessages converts a Bedrock system message to Bifrost messages -func convertBedrockSystemMessageToBifrostMessages(sysMsg *BedrockSystemMessage) []schemas.ResponsesMessage { - if sysMsg.Text != nil { - systemRole := schemas.ResponsesInputMessageRoleSystem - msgType := schemas.ResponsesMessageTypeMessage - return []schemas.ResponsesMessage{{ - Type: &msgType, - Role: &systemRole, - Content: &schemas.ResponsesMessageContent{ - ContentBlocks: []schemas.ResponsesMessageContentBlock{ - { - Type: schemas.ResponsesInputMessageContentBlockTypeText, - Text: sysMsg.Text, +func convertBedrockSystemMessageToBifrostMessages(systemMessages []BedrockSystemMessage) []schemas.ResponsesMessage { + var bifrostMessages []schemas.ResponsesMessage + + for _, sysMsg := range systemMessages { + if sysMsg.CachePoint != nil { + // add it to last content block of last message + if len(bifrostMessages) > 0 { + lastMessage := &bifrostMessages[len(bifrostMessages)-1] + if lastMessage.Content != nil && len(lastMessage.Content.ContentBlocks) > 0 { + lastMessage.Content.ContentBlocks[len(lastMessage.Content.ContentBlocks)-1].CacheControl = &schemas.CacheControl{ + Type: schemas.CacheControlTypeEphemeral, + } + } + } + } + if sysMsg.Text != nil { + systemRole := schemas.ResponsesInputMessageRoleSystem + msgType := schemas.ResponsesMessageTypeMessage + bifrostMessages = append(bifrostMessages, schemas.ResponsesMessage{ + Type: &msgType, + Role: &systemRole, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: sysMsg.Text, + }, }, }, - }, - }} + }) + } + } - return []schemas.ResponsesMessage{} + return bifrostMessages } // Helper to convert Bedrock role to Bifrost role @@ -2608,6 +2653,16 @@ func convertSingleBedrockMessageToBifrostMessages(ctx *context.Context, msg *Bed } } outputMessages = append(outputMessages, resultMsg) + } else if block.CachePoint != nil { + // add cache control to last content block of last message + if len(outputMessages) > 0 { + lastMessage := &outputMessages[len(outputMessages)-1] + if lastMessage.Content != nil && len(lastMessage.Content.ContentBlocks) > 0 { + lastMessage.Content.ContentBlocks[len(lastMessage.Content.ContentBlocks)-1].CacheControl = &schemas.CacheControl{ + Type: schemas.CacheControlTypeEphemeral, + } + } + } } } @@ -2690,7 +2745,6 @@ func convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(content s for _, block := range content.ContentBlocks { bedrockBlock := BedrockContentBlock{} - switch block.Type { case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: bedrockBlock.Text = block.Text @@ -2717,6 +2771,14 @@ func convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(content s } blocks = append(blocks, bedrockBlock) + + if block.CacheControl != nil { + blocks = append(blocks, BedrockContentBlock{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + 
}) + } } diff --git a/core/providers/bedrock/types.go b/core/providers/bedrock/types.go index 77437c3ef..d24b505c8 100644 --- a/core/providers/bedrock/types.go +++ b/core/providers/bedrock/types.go @@ -156,6 +156,7 @@ type BedrockMessage struct { type BedrockSystemMessage struct { Text *string `json:"text,omitempty"` // Text system message GuardContent *BedrockGuardContent `json:"guardContent,omitempty"` // Guard content for guardrails + CachePoint *BedrockCachePoint `json:"cachePoint,omitempty"` // Cache point for the system message } // BedrockContentBlock represents a content block that can be text, image, document, toolUse, or toolResult @@ -183,6 +184,20 @@ type BedrockContentBlock struct { // For Tool Call Result content JSON interface{} `json:"json,omitempty"` + + // Cache point for the content block + CachePoint *BedrockCachePoint `json:"cachePoint,omitempty"` +} + +type BedrockCachePointType string + +const ( + BedrockCachePointTypeDefault BedrockCachePointType = "default" +) + +// BedrockCachePoint represents a cache checkpoint marker in Bedrock content blocks, system messages, and tools +type BedrockCachePoint struct { + Type BedrockCachePointType `json:"type"` } // BedrockImageSource represents image content @@ -267,7 +282,8 @@ type BedrockToolConfig struct { // BedrockTool represents a tool definition type BedrockTool struct { - ToolSpec *BedrockToolSpec `json:"toolSpec,omitempty"` // Tool specification + ToolSpec *BedrockToolSpec `json:"toolSpec,omitempty"` // Tool specification + CachePoint *BedrockCachePoint `json:"cachePoint,omitempty"` // Cache point for the tool } // BedrockToolSpec represents the specification of a tool @@ -626,7 +642,7 @@ type BedrockFileRetrieveRequest struct { Bucket string `json:"bucket"` Prefix string `json:"prefix"` S3Uri string `json:"s3Uri"` // Full S3 URI (s3://bucket/key) - ETag string `json:"etag"` // S3 ETag + ETag string `json:"etag"` // S3 ETag } // BedrockFileRetrieveResponse wraps S3 HeadObject response @@ -644,7 +660,7 @@ type BedrockFileDeleteRequest struct { Bucket string `json:"bucket"` Prefix string `json:"prefix"` S3Uri string `json:"s3Uri"` // Full S3 URI (s3://bucket/key) - ETag string `json:"etag"` // S3 ETag + ETag string `json:"etag"` // S3 ETag } // BedrockFileDeleteResponse wraps S3 DeleteObject response @@ -658,7 +674,7 @@ type BedrockFileContentRequest struct { Bucket string `json:"bucket"` Prefix string `json:"prefix,omitempty"` S3Uri string `json:"s3Uri"` // Full S3 URI (s3://bucket/key) - ETag string `json:"etag"` // S3 ETag + ETag string `json:"etag"` // S3 ETag } // BedrockFileContentResponse wraps S3 GetObject response diff --git a/core/providers/bedrock/utils.go b/core/providers/bedrock/utils.go index 2102269ff..6c506fbf1 100644 --- a/core/providers/bedrock/utils.go +++ b/core/providers/bedrock/utils.go @@ -27,7 +27,7 @@ func convertChatParameters(ctx *context.Context, bifrostReq *schemas.BifrostChat responseFormatTool := convertResponseFormatToTool(ctx, bifrostReq.Params) // Convert tool config - if toolConfig := convertToolConfig(bifrostReq.Params); toolConfig != nil { + if toolConfig := convertToolConfig(bifrostReq.Model, bifrostReq.Params); toolConfig != nil { bedrockReq.ToolConfig = toolConfig } @@ -204,11 +204,11 @@ func convertMessages(bifrostMessages []schemas.ChatMessage) ([]BedrockMessage, [ switch msg.Role { case schemas.ChatMessageRoleSystem: // Convert system message - systemMsg, err := convertSystemMessage(msg) + systemMsgs, err := convertSystemMessages(msg) if err != nil { return nil, nil, fmt.Errorf("failed to convert system message: %w", err) } - systemMessages = append(systemMessages, 
systemMsg) + systemMessages = append(systemMessages, systemMsgs...) case schemas.ChatMessageRoleUser, schemas.ChatMessageRoleAssistant: // Convert regular message @@ -234,29 +234,33 @@ func convertMessages(bifrostMessages []schemas.ChatMessage) ([]BedrockMessage, [ return messages, systemMessages, nil } -// convertSystemMessage converts a Bifrost system message to Bedrock format -func convertSystemMessage(msg schemas.ChatMessage) (BedrockSystemMessage, error) { - systemMsg := BedrockSystemMessage{} +// convertSystemMessages converts a Bifrost system message to Bedrock format +func convertSystemMessages(msg schemas.ChatMessage) ([]BedrockSystemMessage, error) { + systemMsgs := []BedrockSystemMessage{} // Convert content if msg.Content.ContentStr != nil { - systemMsg.Text = msg.Content.ContentStr + systemMsgs = append(systemMsgs, BedrockSystemMessage{ + Text: msg.Content.ContentStr, + }) } else if msg.Content.ContentBlocks != nil { - // For system messages, we only support text content - // Combine all text blocks into a single string - var textParts []string for _, block := range msg.Content.ContentBlocks { if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil { - textParts = append(textParts, *block.Text) + systemMsgs = append(systemMsgs, BedrockSystemMessage{ + Text: block.Text, + }) + if block.CacheControl != nil { + systemMsgs = append(systemMsgs, BedrockSystemMessage{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } } } - if len(textParts) > 0 { - combined := strings.Join(textParts, "\n") - systemMsg.Text = &combined - } } - return systemMsg, nil + return systemMsgs, nil } // convertMessage converts a Bifrost message to Bedrock format @@ -323,6 +327,14 @@ func convertToolMessage(msg schemas.ChatMessage) (BedrockMessage, error) { toolResultContent = append(toolResultContent, BedrockContentBlock{ Text: block.Text, }) + // Cache point must be in a separate block + if block.CacheControl != nil { + toolResultContent = append(toolResultContent, BedrockContentBlock{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } } case schemas.ChatContentBlockTypeImage: if block.ImageURLStruct != nil { @@ -333,6 +345,14 @@ func convertToolMessage(msg schemas.ChatMessage) (BedrockMessage, error) { toolResultContent = append(toolResultContent, BedrockContentBlock{ Image: imageSource, }) + // Cache point must be in a separate block + if block.CacheControl != nil { + toolResultContent = append(toolResultContent, BedrockContentBlock{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } } } } @@ -363,11 +383,11 @@ func convertContent(content schemas.ChatMessageContent) ([]BedrockContentBlock, } else if content.ContentBlocks != nil { // Multi-modal content for _, block := range content.ContentBlocks { - bedrockBlock, err := convertContentBlock(block) + bedrockBlocks, err := convertContentBlock(block) if err != nil { return nil, fmt.Errorf("failed to convert content block: %w", err) } - contentBlocks = append(contentBlocks, bedrockBlock) + contentBlocks = append(contentBlocks, bedrockBlocks...) 
} } @@ -375,32 +395,54 @@ func convertContent(content schemas.ChatMessageContent) ([]BedrockContentBlock, } // convertContentBlock converts a Bifrost content block to Bedrock format -func convertContentBlock(block schemas.ChatContentBlock) (BedrockContentBlock, error) { +func convertContentBlock(block schemas.ChatContentBlock) ([]BedrockContentBlock, error) { switch block.Type { case schemas.ChatContentBlockTypeText: - return BedrockContentBlock{ - Text: block.Text, - }, nil + blocks := []BedrockContentBlock{ + { + Text: block.Text, + }, + } + // Cache point must be in a separate block + if block.CacheControl != nil { + blocks = append(blocks, BedrockContentBlock{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } + return blocks, nil case schemas.ChatContentBlockTypeImage: if block.ImageURLStruct == nil { - return BedrockContentBlock{}, fmt.Errorf("image_url block missing image_url field") + return nil, fmt.Errorf("image_url block missing image_url field") } imageSource, err := convertImageToBedrockSource(block.ImageURLStruct.URL) if err != nil { - return BedrockContentBlock{}, fmt.Errorf("failed to convert image: %w", err) + return nil, fmt.Errorf("failed to convert image: %w", err) } - return BedrockContentBlock{ - Image: imageSource, - }, nil + blocks := []BedrockContentBlock{ + { + Image: imageSource, + }, + } + // Cache point must be in a separate block + if block.CacheControl != nil { + blocks = append(blocks, BedrockContentBlock{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } + return blocks, nil case schemas.ChatContentBlockTypeInputAudio: // Bedrock doesn't support audio input in Converse API - return BedrockContentBlock{}, fmt.Errorf("audio input not supported in Bedrock Converse API") + return nil, fmt.Errorf("audio input not supported in Bedrock Converse API") default: - return BedrockContentBlock{}, fmt.Errorf("unsupported content block type: %s", block.Type) + return nil, fmt.Errorf("unsupported content block type: %s", block.Type) } } @@ -572,7 +614,7 @@ func convertInferenceConfig(params *schemas.ChatParameters) *BedrockInferenceCon } // convertToolConfig converts Bifrost tools to Bedrock tool config -func convertToolConfig(params *schemas.ChatParameters) *BedrockToolConfig { +func convertToolConfig(model string, params *schemas.ChatParameters) *BedrockToolConfig { if len(params.Tools) == 0 { return nil } @@ -616,6 +658,14 @@ func convertToolConfig(params *schemas.ChatParameters) *BedrockToolConfig { }, } bedrockTools = append(bedrockTools, bedrockTool) + + if tool.CacheControl != nil && !schemas.IsNovaModel(model) { + bedrockTools = append(bedrockTools, BedrockTool{ + CachePoint: &BedrockCachePoint{ + Type: BedrockCachePointTypeDefault, + }, + }) + } } } diff --git a/core/providers/openai/responses_marshal_test.go b/core/providers/openai/responses_marshal_test.go index e7f08f291..59997584c 100644 --- a/core/providers/openai/responses_marshal_test.go +++ b/core/providers/openai/responses_marshal_test.go @@ -478,4 +478,3 @@ func TestOpenAIResponsesRequest_MarshalJSON_RoundTrip(t *testing.T) { } }) } - diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 2cab76756..3915eafba 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -79,24 +79,105 @@ type OpenAIChatAssistantMessage struct { // MarshalJSON implements custom JSON marshalling for OpenAIChatRequest. 
// It excludes the reasoning field and instead marshals reasoning_effort // with the value of Reasoning.Effort if not nil. +// It also removes cache_control from messages, their content blocks, and tools. func (r *OpenAIChatRequest) MarshalJSON() ([]byte, error) { if r == nil { return []byte("null"), nil } type Alias OpenAIChatRequest + // First pass: check if we need to modify any messages + needsCopy := false + for _, msg := range r.Messages { + if hasCacheControlInChatMessage(msg) { + needsCopy = true + break + } + } + + // Process messages if needed + var processedMessages []OpenAIMessage + if needsCopy { + processedMessages = make([]OpenAIMessage, len(r.Messages)) + for i, msg := range r.Messages { + if !hasCacheControlInChatMessage(msg) { + // No modification needed, use original + processedMessages[i] = msg + continue + } + + // Copy message + processedMessages[i] = msg + + // Strip CacheControl from content blocks if needed + if msg.Content != nil && msg.Content.ContentBlocks != nil { + contentCopy := *msg.Content + contentCopy.ContentBlocks = make([]schemas.ChatContentBlock, len(msg.Content.ContentBlocks)) + for j, block := range msg.Content.ContentBlocks { + if block.CacheControl != nil { + blockCopy := block + blockCopy.CacheControl = nil + contentCopy.ContentBlocks[j] = blockCopy + } else { + contentCopy.ContentBlocks[j] = block + } + } + processedMessages[i].Content = &contentCopy + } + } + } else { + processedMessages = r.Messages + } + + // Process tools if needed + var processedTools []schemas.ChatTool + if len(r.Tools) > 0 { + needsToolCopy := false + for _, tool := range r.Tools { + if tool.CacheControl != nil { + needsToolCopy = true + break + } + } + + if needsToolCopy { + processedTools = make([]schemas.ChatTool, len(r.Tools)) + for i, tool := range r.Tools { + if tool.CacheControl != nil { + toolCopy := tool + toolCopy.CacheControl = nil + processedTools[i] = toolCopy + } else { + processedTools[i] = tool + } + } + } else { + processedTools = r.Tools + } + } else { + processedTools = r.Tools + } + // Aux struct: // - Alias embeds all original fields + // - Messages shadows the embedded Messages field to use processed messages + // - Tools shadows the embedded Tools field to use processed tools // - Reasoning shadows the embedded ChatParameters.Reasoning // so that "reasoning" is not emitted // - ReasoningEffort is emitted as "reasoning_effort" aux := struct { *Alias + // Shadow the embedded "messages" field to use processed messages + Messages []OpenAIMessage `json:"messages"` + // Shadow the embedded "tools" field to use processed tools + Tools []schemas.ChatTool `json:"tools,omitempty"` // Shadow the embedded "reasoning" field and omit it Reasoning *schemas.ChatReasoning `json:"reasoning,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` }{ - Alias: (*Alias)(r), + Alias: (*Alias)(r), + Messages: processedMessages, + Tools: processedTools, } // DO NOT set aux.Reasoning → it stays nil and is omitted via omitempty, and also due to double reference to the same json field. @@ -169,17 +250,125 @@ func (r *OpenAIResponsesRequestInput) UnmarshalJSON(data []byte) error { return fmt.Errorf("openai responses request input is neither a string nor an array of responses messages") } -// MarshalJSON implements custom JSON marshalling for ResponsesRequestInput. 
func (r *OpenAIResponsesRequestInput) MarshalJSON() ([]byte, error) { if r.OpenAIResponsesRequestInputStr != nil { return sonic.Marshal(*r.OpenAIResponsesRequestInputStr) } if r.OpenAIResponsesRequestInputArray != nil { - return sonic.Marshal(r.OpenAIResponsesRequestInputArray) + // First pass: check if we need to modify anything + needsCopy := false + for _, msg := range r.OpenAIResponsesRequestInputArray { + if hasCacheControl(msg) { + needsCopy = true + break + } + } + + // If no CacheControl found anywhere, marshal as-is + if !needsCopy { + return sonic.Marshal(r.OpenAIResponsesRequestInputArray) + } + + // Only copy messages that have CacheControl + messagesCopy := make([]schemas.ResponsesMessage, len(r.OpenAIResponsesRequestInputArray)) + for i, msg := range r.OpenAIResponsesRequestInputArray { + if !hasCacheControl(msg) { + // No modification needed, use original + messagesCopy[i] = msg + continue + } + + // Copy only this message + messagesCopy[i] = msg + + // Strip CacheControl from content blocks if needed + if msg.Content != nil && msg.Content.ContentBlocks != nil { + contentCopy := *msg.Content + contentCopy.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, len(msg.Content.ContentBlocks)) + hasContentCache := false + for j, block := range msg.Content.ContentBlocks { + if block.CacheControl != nil { + hasContentCache = true + blockCopy := block + blockCopy.CacheControl = nil + contentCopy.ContentBlocks[j] = blockCopy + } else { + contentCopy.ContentBlocks[j] = block + } + } + if hasContentCache { + messagesCopy[i].Content = &contentCopy + } + } + + // Strip CacheControl from tool message output blocks if needed + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Output != nil { + if msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + hasToolCache := false + for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { + if block.CacheControl != nil { + hasToolCache = true + break + } + } + + if hasToolCache { + toolMsgCopy := *msg.ResponsesToolMessage + outputCopy := *msg.ResponsesToolMessage.Output + outputCopy.ResponsesFunctionToolCallOutputBlocks = make([]schemas.ResponsesMessageContentBlock, len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks)) + for j, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { + if block.CacheControl != nil { + blockCopy := block + blockCopy.CacheControl = nil + outputCopy.ResponsesFunctionToolCallOutputBlocks[j] = blockCopy + } else { + outputCopy.ResponsesFunctionToolCallOutputBlocks[j] = block + } + } + toolMsgCopy.Output = &outputCopy + messagesCopy[i].ResponsesToolMessage = &toolMsgCopy + } + } + } + } + return sonic.Marshal(messagesCopy) } return sonic.Marshal(nil) } +// Helper function to check if a chat message has any CacheControl fields +func hasCacheControlInChatMessage(msg OpenAIMessage) bool { + if msg.Content != nil && msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.CacheControl != nil { + return true + } + } + } + return false +} + +// Helper function to check if a responses message has any CacheControl fields +func hasCacheControl(msg schemas.ResponsesMessage) bool { + if msg.Content != nil && msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.CacheControl != nil { + return true + } + } + } + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Output != nil { + if 
msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { + if block.CacheControl != nil { + return true + } + } + } + return false +} + type OpenAIResponsesRequest struct { Model string `json:"model"` Input OpenAIResponsesRequestInput `json:"input"` @@ -202,6 +391,35 @@ func (r *OpenAIResponsesRequest) MarshalJSON() ([]byte, error) { return nil, err } + // Process tools if needed + var processedTools []schemas.ResponsesTool + if len(r.Tools) > 0 { + needsToolCopy := false + for _, tool := range r.Tools { + if tool.CacheControl != nil { + needsToolCopy = true + break + } + } + + if needsToolCopy { + processedTools = make([]schemas.ResponsesTool, len(r.Tools)) + for i, tool := range r.Tools { + if tool.CacheControl != nil { + toolCopy := tool + toolCopy.CacheControl = nil + processedTools[i] = toolCopy + } else { + processedTools[i] = tool + } + } + } else { + processedTools = r.Tools + } + } else { + processedTools = r.Tools + } + // Aux struct: // - Alias embeds all original fields // - Input shadows the embedded Input field and uses json.RawMessage to preserve custom marshaling @@ -213,9 +431,12 @@ func (r *OpenAIResponsesRequest) MarshalJSON() ([]byte, error) { Input json.RawMessage `json:"input"` // Shadow the embedded "reasoning" field to modify it Reasoning *schemas.ResponsesParametersReasoning `json:"reasoning,omitempty"` + // Shadow the embedded "tools" field to use processed tools + Tools []schemas.ResponsesTool `json:"tools,omitempty"` }{ Alias: (*Alias)(r), Input: json.RawMessage(inputBytes), + Tools: processedTools, } // Copy reasoning but set MaxTokens to nil diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index bfb70d5f7..9ec7c6193 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -243,9 +243,10 @@ const ( // ChatTool represents a tool definition. type ChatTool struct { - Type ChatToolType `json:"type"` - Function *ChatToolFunction `json:"function,omitempty"` // Function definition - Custom *ChatToolCustom `json:"custom,omitempty"` // Custom tool definition + Type ChatToolType `json:"type"` + Function *ChatToolFunction `json:"function,omitempty"` // Function definition + Custom *ChatToolCustom `json:"custom,omitempty"` // Custom tool definition + CacheControl *CacheControl `json:"cache_control,omitempty"` // Cache control for the tool } // ChatToolFunction represents a function definition. @@ -595,6 +596,20 @@ type ChatContentBlock struct { ImageURLStruct *ChatInputImage `json:"image_url,omitempty"` InputAudio *ChatInputAudio `json:"input_audio,omitempty"` File *ChatInputFile `json:"file,omitempty"` + + // Not in OpenAI's schema, but sent by some providers (e.g., Anthropic and Bedrock) + CacheControl *CacheControl `json:"cache_control,omitempty"` +} + +type CacheControlType string + +const ( + CacheControlTypeEphemeral CacheControlType = "ephemeral" +) + +type CacheControl struct { + Type CacheControlType `json:"type"` + TTL *string `json:"ttl,omitempty"` // "1m" | "1h" } // ChatInputImage represents image data in a message. 
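For reference, a minimal sketch of how a caller might use the CacheControl fields introduced above. The schema identifiers come from this diff; the module import path, the placeholder prompt text, and the surrounding wiring are illustrative assumptions, not part of the patch:

package main

import (
	"fmt"

	"github.com/maximhq/bifrost/core/schemas" // assumed module path for core/schemas
)

func main() {
	// Mark a long, stable system prefix as a cacheable checkpoint.
	// Placeholder text; a real prefix must meet the provider's minimum cacheable length.
	longPrefix := "...several thousand tokens of stable instructions..."
	systemMsg := schemas.ChatMessage{
		Role: schemas.ChatMessageRoleSystem,
		Content: &schemas.ChatMessageContent{
			ContentBlocks: []schemas.ChatContentBlock{{
				Type: schemas.ChatContentBlockTypeText,
				Text: &longPrefix,
				CacheControl: &schemas.CacheControl{Type: schemas.CacheControlTypeEphemeral},
			}},
		},
	}
	// Tool definitions carry the same marker via ChatTool.CacheControl.
	fmt.Println(systemMsg.Role)
}

Downstream, the hunks above translate this marker per provider: Anthropic serializes it as a cache_control field on the content block, Bedrock appends a separate cachePoint block (tool cache points are skipped for Nova models), and the OpenAI marshallers strip the field before sending.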
diff --git a/core/schemas/responses.go b/core/schemas/responses.go index 5b974fcaa..cf0d83582 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -412,6 +412,9 @@ type ResponsesMessageContentBlock struct { *ResponsesOutputMessageContentText // Normal text output from the model *ResponsesOutputMessageContentRefusal // Model refusal to answer + + // Not in OpenAI's schema, but sent by some providers (e.g., Anthropic and Bedrock) + CacheControl *CacheControl `json:"cache_control,omitempty"` } type ResponsesInputMessageContentBlockImage struct { @@ -1036,6 +1039,9 @@ type ResponsesTool struct { Name *string `json:"name,omitempty"` // Common name field (Function, Custom tools) Description *string `json:"description,omitempty"` // Common description field (Function, Custom tools) + // Not in OpenAI's schema, but sent by some providers (e.g., Anthropic and Bedrock) + CacheControl *CacheControl `json:"cache_control,omitempty"` + *ResponsesToolFunction *ResponsesToolFileSearch *ResponsesToolComputerUsePreview diff --git a/core/schemas/utils.go b/core/schemas/utils.go index fbad2c13a..66ec247f9 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -1039,6 +1039,11 @@ func deepCopyResponsesMessageContentBlock(original ResponsesMessageContentBlock) return copy } +// IsNovaModel checks if the model is a Nova model. +func IsNovaModel(model string) bool { + return strings.Contains(model, "nova") +} + // IsAnthropicModel checks if the model is an Anthropic model. func IsAnthropicModel(model string) bool { return strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") diff --git a/tests/integrations/config.yml b/tests/integrations/config.yml index 9fdb2cdd1..c75e1fa56 100644 --- a/tests/integrations/config.yml +++ b/tests/integrations/config.yml @@ -107,7 +107,7 @@ providers: - "gemini-2.0-flash-001" bedrock: - chat: "anthropic.claude-3-5-sonnet-20240620-v1:0" + chat: "us.anthropic.claude-sonnet-4-5-20250929-v1:0" vision: "anthropic.claude-3-5-sonnet-20240620-v1:0" tools: "anthropic.claude-3-5-sonnet-20240620-v1:0" streaming: "anthropic.claude-3-5-sonnet-20240620-v1:0" @@ -164,6 +164,7 @@ provider_scenarios: transcription_streaming: true embeddings: true thinking: true + prompt_caching: false list_models: true responses: true responses_image: true @@ -201,6 +202,7 @@ provider_scenarios: transcription_streaming: false embeddings: false thinking: true + prompt_caching: true list_models: true responses: true responses_image: true @@ -238,6 +240,7 @@ provider_scenarios: transcription_streaming: true embeddings: true thinking: true + prompt_caching: false list_models: true responses: true responses_image: true @@ -275,6 +278,7 @@ provider_scenarios: transcription_streaming: false embeddings: true thinking: true + prompt_caching: true list_models: true responses: true responses_image: true @@ -312,6 +316,7 @@ provider_scenarios: transcription_streaming: false embeddings: true thinking: false + prompt_caching: false list_models: false responses: true responses_image: true @@ -354,6 +359,7 @@ scenario_capabilities: transcription_streaming: "transcription" embeddings: "embeddings" thinking: "thinking" + prompt_caching: "chat" list_models: "chat" langchain_structured_output: "chat" # LangChain structured output uses chat capability pydantic_structured_output: "chat" # Structured output uses chat capability diff --git a/tests/integrations/tests/test_anthropic.py b/tests/integrations/tests/test_anthropic.py index c9e817c47..25b552a24 100644 --- 
--- a/tests/integrations/tests/test_anthropic.py
+++ b/tests/integrations/tests/test_anthropic.py
@@ -36,6 +36,9 @@
 25. Batch API - batch cancel
 26. Batch API - batch results
 27. Batch API - end-to-end workflow
+28. Prompt caching - system message checkpoint
+29. Prompt caching - messages checkpoint
+30. Prompt caching - tools checkpoint
 """

 import logging
@@ -60,6 +63,7 @@
     WEATHER_TOOL,
     CALCULATOR_TOOL,
     ALL_TOOLS,
+    PROMPT_CACHING_TOOLS,
     create_batch_jsonl_content,
     mock_tool_response,
     assert_valid_chat_response,
@@ -78,6 +82,7 @@
     # Anthropic-specific test data
     ANTHROPIC_THINKING_PROMPT,
     ANTHROPIC_THINKING_STREAMING_PROMPT,
+    PROMPT_CACHING_LARGE_CONTEXT,
     # Files API utilities
     assert_valid_file_response,
     assert_valid_file_list_response,
@@ -1413,6 +1418,195 @@ def test_27_batch_e2e(self, anthropic_client, test_config, provider, model):
             except Exception as e:
                 print(f"Cleanup info: Could not cancel batch: {e}")

+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("prompt_caching"))
+    def test_28_prompt_caching_system(self, anthropic_client, provider, model):
+        """Test Case 28: Prompt caching with system message checkpoint"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for prompt_caching scenario")
+
+        print(f"\n=== Testing System Message Caching for provider {provider} ===")
+        print("First request: Creating cache with system message checkpoint...")
+
+        system_messages = [
+            {
+                "type": "text",
+                "text": "You are an AI assistant tasked with analyzing legal documents."
+            },
+            {
+                "type": "text",
+                "text": PROMPT_CACHING_LARGE_CONTEXT,
+                "cache_control": {"type": "ephemeral"}
+            }
+        ]
+
+        # First request - should create cache
+        response1 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            system=system_messages,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What are the key elements of contract formation?"
+                }
+            ],
+            max_tokens=1024
+        )
+
+        # Validate first response
+        assert_valid_chat_response(response1)
+        assert hasattr(response1, "usage"), "Response should have usage information"
+        cache_write_tokens = validate_cache_write(response1.usage, "First request")
+
+        # Second request with same system - should hit cache
+        print("\nSecond request: Hitting cache with same system checkpoint...")
+        response2 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            system=system_messages,  # Same system messages with cache_control
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What is the purpose of a force majeure clause?"
+                }
+            ],
+            max_tokens=1024
+        )
+
+        # Validate second response
+        assert_valid_chat_response(response2)
+        cache_read_tokens = validate_cache_read(response2.usage, "Second request")
+
+        # Validate that cache read tokens are approximately equal to cache creation tokens
+        assert abs(cache_write_tokens - cache_read_tokens) < 100, \
+            f"Cache read tokens ({cache_read_tokens}) should be close to cache creation tokens ({cache_write_tokens})"
+
+        print(f"✓ System caching validated - Cache created: {cache_write_tokens} tokens, "
+              f"Cache read: {cache_read_tokens} tokens")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("prompt_caching"))
+    def test_29_prompt_caching_messages(self, anthropic_client, provider, model):
+        """Test Case 29: Prompt caching with messages checkpoint"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for prompt_caching scenario")
+
+        print(f"\n=== Testing Messages Caching for provider {provider} ===")
+        print("First request: Creating cache with messages checkpoint...")
+
+        # First request with cache control in user message
+        response1 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Here is a large legal document to analyze:"
+                        },
+                        {
+                            "type": "text",
+                            "text": PROMPT_CACHING_LARGE_CONTEXT,
+                            "cache_control": {"type": "ephemeral"}
+                        },
+                        {
+                            "type": "text",
+                            "text": "What are the main indemnification principles?"
+                        }
+                    ]
+                }
+            ],
+            max_tokens=1024
+        )
+
+        assert_valid_chat_response(response1)
+        assert hasattr(response1, "usage"), "Response should have usage information"
+        cache_write_tokens = validate_cache_write(response1.usage, "First request")
+
+        # Second request with same cached content
+        print("\nSecond request: Hitting cache with same messages checkpoint...")
+        response2 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Here is a large legal document to analyze:"
+                        },
+                        {
+                            "type": "text",
+                            "text": PROMPT_CACHING_LARGE_CONTEXT,
+                            "cache_control": {"type": "ephemeral"}
+                        },
+                        {
+                            "type": "text",
+                            "text": "Summarize the dispute resolution methods."
+                        }
+                    ]
+                }
+            ],
+            max_tokens=1024
+        )
+
+        assert_valid_chat_response(response2)
+        cache_read_tokens = validate_cache_read(response2.usage, "Second request")
+
+        # Validate that cache read tokens are approximately equal to cache creation tokens
+        assert abs(cache_write_tokens - cache_read_tokens) < 100, \
+            f"Cache read tokens ({cache_read_tokens}) should be close to cache creation tokens ({cache_write_tokens})"
+
+        print(f"✓ Messages caching validated - Cache created: {cache_write_tokens} tokens, "
+              f"Cache read: {cache_read_tokens} tokens")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("prompt_caching"))
+    def test_30_prompt_caching_tools(self, anthropic_client, provider, model):
+        """Test Case 30: Prompt caching with tools checkpoint (12 tools)"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for prompt_caching scenario")
+
+        print(f"\n=== Testing Tools Caching for provider {provider} ===")
+        print("First request: Creating cache with tools checkpoint...")
+
+        # Convert tools to Anthropic format with cache control
+        tools = convert_to_anthropic_tools(PROMPT_CACHING_TOOLS)
+        # Add cache control to the last tool so the whole tool block is cached
+        tools[-1]["cache_control"] = {"type": "ephemeral"}
+
+        # First request with tool cache control
+        response1 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            tools=tools,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather in Boston?"
+                }
+            ],
+            max_tokens=1024
+        )
+
+        assert hasattr(response1, "usage"), "Response should have usage information"
+        cache_write_tokens = validate_cache_write(response1.usage, "First request")
+
+        # Second request with same tools
+        print("\nSecond request: Hitting cache with same tools checkpoint...")
+        response2 = anthropic_client.messages.create(
+            model=format_provider_model(provider, model),
+            tools=tools,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Calculate 42 * 17"
+                }
+            ],
+            max_tokens=1024
+        )
+
+        cache_read_tokens = validate_cache_read(response2.usage, "Second request")
+
+        print(f"✓ Tools caching validated - Cache created: {cache_write_tokens} tokens, "
+              f"Cache read: {cache_read_tokens} tokens")
+
 # Additional helper functions specific to Anthropic

 def serialize_anthropic_content(content_blocks: List[Any]) -> List[Dict[str, Any]]:
@@ -1465,3 +1659,31 @@ def extract_anthropic_tool_calls(response: Any) -> List[Dict[str, Any]]:
             continue

     return tool_calls
+
+def validate_cache_write(usage: Any, operation: str) -> int:
+    """Validate cache write operation and return tokens written"""
+    print(f"{operation} usage - input_tokens: {usage.input_tokens}, "
+          f"cache_creation_input_tokens: {getattr(usage, 'cache_creation_input_tokens', 0)}, "
+          f"cache_read_input_tokens: {getattr(usage, 'cache_read_input_tokens', 0)}")
+
+    assert hasattr(usage, 'cache_creation_input_tokens'), \
+        f"{operation} should have cache_creation_input_tokens"
+    cache_write_tokens = usage.cache_creation_input_tokens
+    assert cache_write_tokens > 0, \
+        f"{operation} should create cache (got {cache_write_tokens} tokens)"
+
+    return cache_write_tokens
+
+def validate_cache_read(usage: Any, operation: str) -> int:
+    """Validate cache read operation and return tokens read"""
+    print(f"{operation} usage - input_tokens: {usage.input_tokens}, "
+          f"cache_creation_input_tokens: {getattr(usage, 'cache_creation_input_tokens', 0)}, "
+          f"cache_read_input_tokens: {getattr(usage, 'cache_read_input_tokens', 0)}")
+
+    assert hasattr(usage, 'cache_read_input_tokens'), \
+        f"{operation} should have cache_read_input_tokens"
+    cache_read_tokens = usage.cache_read_input_tokens
+    assert cache_read_tokens > 0, \
+        f"{operation} should read from cache (got {cache_read_tokens} tokens)"
+
+    return cache_read_tokens
diff --git a/tests/integrations/tests/test_bedrock.py b/tests/integrations/tests/test_bedrock.py
index 4df2d3167..368f068df 100644
--- a/tests/integrations/tests/test_bedrock.py
+++ b/tests/integrations/tests/test_bedrock.py
@@ -33,6 +33,11 @@
 18. Batch retrieve - Cross-provider
 19. Batch cancel - Cross-provider
 20. Batch end-to-end - Cross-provider
+
+Prompt Caching Tests:
+21. Prompt caching with system message checkpoint
+22. Prompt caching with messages checkpoint
+23. Prompt caching with tools checkpoint
 """

 import pytest
@@ -48,9 +53,11 @@
     BASE64_IMAGE,
     WEATHER_TOOL,
     CALCULATOR_TOOL,
+    PROMPT_CACHING_TOOLS,
     SIMPLE_CHAT_MESSAGES,
     MULTI_TURN_MESSAGES,
     MULTIPLE_TOOL_CALL_MESSAGES,
+    PROMPT_CACHING_LARGE_CONTEXT,
     mock_tool_response,
     assert_valid_chat_response,
     assert_has_tool_calls,
@@ -1611,3 +1618,204 @@ def test_20_batch_e2e(self, test_config, provider, model):
             if "not authorized" in str(e).lower() or "access denied" in str(e).lower():
                 pytest.skip(f"Batch API not authorized: {e}")
             raise
+
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("prompt_caching"))
+    def test_21_prompt_caching_system(self, bedrock_client, provider, model):
+        """Test Case 21: Prompt caching with system message checkpoint"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for prompt_caching scenario")
+
+        print(f"\n=== Testing System Message Caching for provider {provider} ===")
+        print("First request: Creating cache with system message checkpoint...")
+
+        system_with_cache = [
+            {"text": "You are an AI assistant tasked with analyzing legal documents."},
+            {"text": PROMPT_CACHING_LARGE_CONTEXT},
+            {"cachePoint": {"type": "default"}}  # Cache all preceding system content
+        ]
+
+        # First request - should create cache
+        response1 = bedrock_client.converse(
+            modelId=format_provider_model(provider, model),
+            system=system_with_cache,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [{"text": "What are the key elements of contract formation?"}]
+                }
+            ]
+        )
+
+        # Validate first response
+        assert response1 is not None
+        assert "usage" in response1
+        cache_write_tokens = validate_cache_write(response1["usage"], "First request")
+
+        # Second request with same system - should hit cache
+        print("\nSecond request: Hitting cache with same system checkpoint...")
+        response2 = bedrock_client.converse(
+            modelId=format_provider_model(provider, model),
+            system=system_with_cache,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [{"text": "Explain the purpose of force majeure clauses."}]
+                }
+            ]
+        )
+
+        cache_read_tokens = validate_cache_read(response2["usage"], "Second request")
+
+        # Validate that cache read tokens are approximately equal to cache write tokens
+        assert abs(cache_write_tokens - cache_read_tokens) < 100, \
+            f"Cache read tokens ({cache_read_tokens}) should be close to cache write tokens ({cache_write_tokens})"
+
+        print(f"✓ System caching validated - Cache created: {cache_write_tokens} tokens, "
+              f"Cache read: {cache_read_tokens} tokens")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("prompt_caching"))
+    def test_22_prompt_caching_messages(self, bedrock_client, provider, model):
+        """Test Case 22: Prompt caching with messages checkpoint"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for prompt_caching scenario")
+
+        print(f"\n=== Testing Messages Caching for provider {provider} ===")
+        print("First request: Creating cache with messages checkpoint...")
+
+        # First request with cache point in user message
+        response1 = bedrock_client.converse(
+            modelId=format_provider_model(provider, model),
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"text": "Here is a large legal document to analyze:"},
+                        {"text": PROMPT_CACHING_LARGE_CONTEXT},
+                        {"cachePoint": {"type": "default"}},  # Cache all preceding message content
+                        {"text": "What are the main indemnification principles?"}
+                    ]
+                }
+            ]
+        )
+
+        assert response1 is not None
+        assert "usage" in response1
+        cache_write_tokens = validate_cache_write(response1["usage"], "First request")
+
+        # Second request with same cached content
+        print("\nSecond request: Hitting cache with same messages checkpoint...")
+        response2 = bedrock_client.converse(
+            modelId=format_provider_model(provider, model),
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"text": "Here is a large legal document to analyze:"},
+                        {"text": PROMPT_CACHING_LARGE_CONTEXT},
+                        {"cachePoint": {"type": "default"}},
+                        {"text": "Summarize the dispute resolution methods."}
+                    ]
+                }
+            ]
+        )
+
+        cache_read_tokens = validate_cache_read(response2["usage"], "Second request")
+
+        # Validate that cache read tokens are approximately equal to cache write tokens
+        assert abs(cache_write_tokens - cache_read_tokens) < 100, \
+            f"Cache read tokens ({cache_read_tokens}) should be close to cache write tokens ({cache_write_tokens})"
+
+        print(f"✓ Messages caching validated - Cache created: {cache_write_tokens} tokens, "
+              f"Cache read: {cache_read_tokens} tokens")
+
+    @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("prompt_caching"))
+    def test_23_prompt_caching_tools(self, bedrock_client, provider, model):
+        """Test Case 23: Prompt caching with tools checkpoint (12 tools)"""
+        if provider == "_no_providers_" or model == "_no_model_":
+            pytest.skip("No providers configured for prompt_caching scenario")
+
+        print(f"\n=== Testing Tools Caching for provider {provider} ===")
+        print("First request: Creating cache with tools checkpoint...")
+
+        # Convert tools to Bedrock format (using 12 tools for a larger cache test)
+        bedrock_tools = []
+        for tool in PROMPT_CACHING_TOOLS:
+            bedrock_tools.append({
+                "toolSpec": {
+                    "name": tool["name"],
+                    "description": tool["description"],
+                    "inputSchema": {"json": tool["parameters"]}
+                }
+            })
+
+        # Add cache point after all tools
+        bedrock_tools.append({
+            "cachePoint": {"type": "default"}  # Cache all 12 tool definitions
+        })
+
+        # First request with tool cache point
+        tool_config = {
+            "tools": bedrock_tools,
+        }
+
+        response1 = bedrock_client.converse(
+            modelId=format_provider_model(provider, model),
+            toolConfig=tool_config,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [{"text": "What's the weather in Boston?"}]
+                }
+            ]
+        )
+
+        assert response1 is not None
+        assert "usage" in response1
+        cache_write_tokens = validate_cache_write(response1["usage"], "First request")
+
+        # Second request with same tools
+        print("\nSecond request: Hitting cache with same tools checkpoint...")
+        response2 = bedrock_client.converse(
+            modelId=format_provider_model(provider, model),
+            toolConfig=tool_config,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [{"text": "Calculate 42 * 17"}]
+                }
+            ]
+        )
+
+        cache_read_tokens = validate_cache_read(response2["usage"], "Second request")
+
+        # Validate that cache read tokens are approximately equal to cache write tokens
+        assert abs(cache_write_tokens - cache_read_tokens) < 100, \
+            f"Cache read tokens ({cache_read_tokens}) should be close to cache write tokens ({cache_write_tokens})"
+
+        print(f"✓ Tools caching validated - Cache created: {cache_write_tokens} tokens, "
+              f"Cache read: {cache_read_tokens} tokens")
+
+def validate_cache_write(usage: Dict[str, Any], operation: str) -> int:
+    """Validate cache write operation and return tokens written"""
+    print(f"{operation} usage - inputTokens: {usage.get('inputTokens', 0)}, "
+          f"cacheWriteInputTokens: {usage.get('cacheWriteInputTokens', 0)}, "
+          f"cacheReadInputTokens: {usage.get('cacheReadInputTokens', 0)}")
+
+    cache_write_tokens = usage.get('cacheWriteInputTokens', 0)
+    assert cache_write_tokens > 0, \
+        f"{operation} should write to cache (got {cache_write_tokens} tokens)"
+
+    return cache_write_tokens
+
+def validate_cache_read(usage: Dict[str, Any], operation: str) -> int:
+    """Validate cache read operation and return tokens read"""
+    print(f"{operation} usage - inputTokens: {usage.get('inputTokens', 0)}, "
+          f"cacheWriteInputTokens: {usage.get('cacheWriteInputTokens', 0)}, "
+          f"cacheReadInputTokens: {usage.get('cacheReadInputTokens', 0)}")
+
+    cache_read_tokens = usage.get('cacheReadInputTokens', 0)
+    assert cache_read_tokens > 0, \
+        f"{operation} should read from cache (got {cache_read_tokens} tokens)"
+
+    return cache_read_tokens
diff --git a/tests/integrations/tests/utils/common.py b/tests/integrations/tests/utils/common.py
index 2e0d0da32..927f2174c 100644
--- a/tests/integrations/tests/utils/common.py
+++ b/tests/integrations/tests/utils/common.py
@@ -85,6 +85,201 @@ class Config:
     },
 }

+# Tools for Prompt Caching Tests
+PROMPT_CACHING_TOOLS = [
+    {
+        "name": "get_weather",
+        "description": "Get the current weather for a location",
+        "parameters": {
+            "type": "object",
+            "required": ["location"],
+            "properties": {
+                "location": {
+                    "description": "The city and state, e.g. San Francisco, CA",
+                    "type": "string"
+                },
+                "unit": {
+                    "description": "The temperature unit",
+                    "enum": ["celsius", "fahrenheit"],
+                    "type": "string"
+                }
+            }
+        }
+    },
+    {
+        "name": "get_current_time",
+        "description": "Get the current local time for a given city",
+        "parameters": {
+            "type": "object",
+            "required": ["location"],
+            "properties": {
+                "location": {
+                    "description": "The city and country, e.g. London, UK",
+                    "type": "string"
+                }
+            }
+        }
+    },
+    {
+        "name": "unit_converter",
+        "description": "Convert a numeric value from one unit to another",
+        "parameters": {
+            "type": "object",
+            "required": ["value", "from_unit", "to_unit"],
+            "properties": {
+                "value": {
+                    "description": "The numeric value to convert",
+                    "type": "number"
+                },
+                "from_unit": {
+                    "description": "The source unit",
+                    "type": "string"
+                },
+                "to_unit": {
+                    "description": "The target unit",
+                    "type": "string"
+                }
+            }
+        }
+    },
+    {
+        "name": "get_exchange_rate",
+        "description": "Get the current exchange rate between two currencies",
+        "parameters": {
+            "type": "object",
+            "required": ["base_currency", "target_currency"],
+            "properties": {
+                "base_currency": {
+                    "description": "The base currency code, e.g. USD",
+                    "type": "string"
+                },
+                "target_currency": {
+                    "description": "The target currency code, e.g. EUR",
EUR", + "type": "string" + } + } + } + }, + { + "name": "translate_text", + "description": "Translate text from one language to another", + "parameters": { + "type": "object", + "required": ["text", "target_language"], + "properties": { + "text": { + "description": "The text to translate", + "type": "string" + }, + "target_language": { + "description": "The target language code, e.g. fr, es", + "type": "string" + } + } + } + }, + { + "name": "summarize_text", + "description": "Summarize a long piece of text into a concise form", + "parameters": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "description": "The text to summarize", + "type": "string" + }, + "max_length": { + "description": "Maximum length of the summary", + "type": "integer" + } + } + } + }, + { + "name": "detect_language", + "description": "Detect the language of a given text", + "parameters": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "description": "The text whose language should be detected", + "type": "string" + } + } + } + }, + { + "name": "extract_keywords", + "description": "Extract important keywords from a block of text", + "parameters": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "description": "The input text", + "type": "string" + }, + "max_keywords": { + "description": "Maximum number of keywords to return", + "type": "integer" + } + } + } + }, + { + "name": "sentiment_analysis", + "description": "Analyze the sentiment of a given text", + "parameters": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "description": "The text to analyze", + "type": "string" + } + } + } + }, + { + "name": "generate_uuid", + "description": "Generate a random UUID", + "parameters": { + "type": "object", + "properties": {} + } + }, + { + "name": "check_url_status", + "description": "Check if a URL is accessible and return its HTTP status", + "parameters": { + "type": "object", + "required": ["url"], + "properties": { + "url": { + "description": "The URL to check", + "type": "string" + } + } + } + }, + { + "name": "calculate", + "description": "Perform basic mathematical calculations", + "parameters": { + "type": "object", + "required": ["expression"], + "properties": { + "expression": { + "description": "Mathematical expression to evaluate, e.g. '2 + 2'", + "type": "string" + } + } + } + } +] + ALL_TOOLS = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL] # Embeddings Test Data @@ -223,6 +418,68 @@ class Config: } ] +# Prompt Caching Test Data +PROMPT_CACHING_LARGE_CONTEXT = """You are an AI assistant tasked with analyzing legal documents. +Here is a detailed legal framework for contract analysis: + +1. CONTRACT FORMATION: A contract requires offer, acceptance, and consideration. The offer must be +definite and communicated to the offeree. Acceptance must mirror the terms of the offer (mirror image +rule). Consideration is the bargained-for exchange that makes the contract legally binding. Both parties +must have the legal capacity to enter into a contract, and the contract's purpose must be legal. + +2. WARRANTIES: Express warranties are explicit promises made by the seller about the product or service. +Implied warranties include the warranty of merchantability (the product is fit for ordinary purposes) and +the warranty of fitness for a particular purpose (the product is suitable for a specific buyer's needs). 
+These warranties provide guarantees about product or service quality and can be the basis for breach of
+contract claims.
+
+3. LIMITATION OF LIABILITY: These clauses limit the amount or types of damages that can be recovered in
+case of breach. They may cap damages at a specific amount, exclude certain types of damages (like
+consequential or punitive damages), or limit liability to repair or replacement of defective goods. Courts
+scrutinize these clauses carefully and may refuse to enforce them if they are unconscionable or against
+public policy.
+
+4. INDEMNIFICATION: Indemnification clauses require one party to compensate the other for losses, damages,
+or liabilities arising from specified events or claims. These provisions allocate risk between parties and
+are particularly important in contracts involving potential third-party claims. The scope of indemnification
+can vary widely, from narrow protection for specific claims to broad coverage for any losses arising from
+the relationship.
+
+5. TERMINATION: Contract termination provisions specify the conditions under which either party may end the
+contractual relationship. These may include termination for cause (breach of contract), termination for
+convenience (with or without notice), automatic termination upon certain events, or mutual agreement.
+Termination clauses often address notice requirements, cure periods, and the parties' rights and obligations
+upon termination.
+
+6. DISPUTE RESOLUTION: These provisions establish the methods for resolving disagreements between parties.
+Options include litigation in courts, arbitration (binding resolution by a neutral arbitrator), mediation
+(facilitated negotiation), or a combination of methods. Arbitration clauses often specify the arbitration
+rules (such as AAA or JAMS), the number of arbitrators, the location of arbitration, and whether arbitration
+decisions are binding and final.
+
+7. FORCE MAJEURE: Force majeure clauses excuse performance when extraordinary events or circumstances beyond
+the parties' control prevent fulfillment of contractual obligations. These events typically include natural
+disasters, wars, pandemics, government actions, and other unforeseeable circumstances. The clause usually
+defines what constitutes a force majeure event and specifies the parties' obligations during such events,
+including notice requirements and efforts to mitigate damages.
+
+8. INTELLECTUAL PROPERTY: These provisions address rights related to patents, copyrights, trademarks, trade
+secrets, and other intellectual property. They may cover ownership of pre-existing IP, IP created during the
+contract term, licensing arrangements, and protection of proprietary information. IP clauses are crucial in
+technology, creative works, and research and development contracts.
+
+9. CONFIDENTIALITY: Confidentiality provisions (also called non-disclosure clauses) impose obligations to
+protect sensitive information shared between parties. They define what constitutes confidential information,
+specify how it must be protected, limit its disclosure and use, and establish the duration of confidentiality
+obligations. These clauses often survive contract termination and may include exceptions for information that
+is publicly available or independently developed.
+
+10. GOVERNING LAW: Governing law clauses specify which jurisdiction's laws will apply to interpret and enforce
+the contract. This is particularly important in contracts between parties in different states or countries.
+The chosen jurisdiction's laws will govern issues like contract formation, performance, breach, and remedies.
+These clauses often work in conjunction with venue or forum selection clauses that specify where disputes must
+be resolved.""" * 3  # Repeat to ensure sufficient tokens (1024+ token minimum)
+
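Note: the `* 3` repetition matters because providers enforce a minimum cacheable prefix size (Anthropic documents a 1024-token floor for Sonnet-class models, higher for some others); shorter prefixes are silently not cached and the cache-write counter stays 0. A hedged sketch of the same padding idea as a reusable helper — the name and the rough 4-characters-per-token estimate are illustrative, not part of this change:

    def build_cacheable_context(base: str, min_tokens: int = 1024) -> str:
        """Repeat base until a rough 4-chars-per-token estimate clears the
        provider's minimum cacheable prefix. A heuristic, not a tokenizer."""
        text = base
        while len(text) / 4 < min_tokens:
            text += base
        return text

 # Gemini Reasoning Test Prompts
 GEMINI_REASONING_PROMPT = [
     {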