Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 211 additions & 1 deletion core/providers/bedrock/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,62 @@ var (
}
)

// assertBedrockRequestEqual compares two BedrockConverseRequest objects
// but ignores the order of tools in ToolConfig
func assertBedrockRequestEqual(t *testing.T, expected, actual *bedrock.BedrockConverseRequest) {
t.Helper()

assert.Equal(t, expected.ModelID, actual.ModelID)
assert.Equal(t, expected.Messages, actual.Messages)
assert.Equal(t, expected.System, actual.System)
assert.Equal(t, expected.InferenceConfig, actual.InferenceConfig)
assert.Equal(t, expected.GuardrailConfig, actual.GuardrailConfig)
assert.Equal(t, expected.AdditionalModelRequestFields, actual.AdditionalModelRequestFields)
assert.Equal(t, expected.AdditionalModelResponseFieldPaths, actual.AdditionalModelResponseFieldPaths)
assert.Equal(t, expected.PerformanceConfig, actual.PerformanceConfig)
assert.Equal(t, expected.PromptVariables, actual.PromptVariables)
assert.Equal(t, expected.RequestMetadata, actual.RequestMetadata)
assert.Equal(t, expected.ServiceTier, actual.ServiceTier)
assert.Equal(t, expected.Stream, actual.Stream)
assert.Equal(t, expected.ExtraParams, actual.ExtraParams)
assert.Equal(t, expected.Fallbacks, actual.Fallbacks)

if expected.ToolConfig == nil {
assert.Nil(t, actual.ToolConfig)
return
}

require.NotNil(t, actual.ToolConfig)
assert.Equal(t, expected.ToolConfig.ToolChoice, actual.ToolConfig.ToolChoice)

expectedTools := expected.ToolConfig.Tools
actualTools := actual.ToolConfig.Tools

assert.Equal(t, len(expectedTools), len(actualTools), "Tool count mismatch")

expectedToolMap := make(map[string]bedrock.BedrockTool)
for _, tool := range expectedTools {
if tool.ToolSpec != nil {
expectedToolMap[tool.ToolSpec.Name] = tool
}
}

actualToolMap := make(map[string]bedrock.BedrockTool)
for _, tool := range actualTools {
if tool.ToolSpec != nil {
actualToolMap[tool.ToolSpec.Name] = tool
}
}

for name, expectedTool := range expectedToolMap {
actualTool, exists := actualToolMap[name]
assert.True(t, exists, "Tool %s not found in actual tools", name)
if exists {
assert.Equal(t, expectedTool, actualTool, "Tool %s differs", name)
}
}
}

func TestBedrock(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -448,6 +504,156 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) {
AdditionalModelResponseFieldPaths: []string{"field1", "field2"},
},
},
{
name: "ParallelToolCalls",
input: &schemas.BifrostChatRequest{
Model: "claude-3-sonnet",
Input: []schemas.ChatMessage{
{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("Invoke all tools in parallel that are available to you"),
},
},
{
Role: schemas.ChatMessageRoleAssistant,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("I'll invoke both available tools in parallel for you."),
},
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: []schemas.ChatAssistantMessageToolCall{
{
Index: 0,
Type: schemas.Ptr("function"),
ID: schemas.Ptr("tooluse_Yl388l8ES0G_3TQtDcKq_g"),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr("hello"),
Arguments: "{}",
},
},
{
Index: 1,
Type: schemas.Ptr("function"),
ID: schemas.Ptr("tooluse_eARDw2iqRXak8uyRC2KxXw"),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr("world"),
Arguments: "{}",
},
},
},
},
},
{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("Hello"),
},
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: schemas.Ptr("tooluse_Yl388l8ES0G_3TQtDcKq_g"),
},
},
{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("World"),
},
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: schemas.Ptr("tooluse_eARDw2iqRXak8uyRC2KxXw"),
},
},
},
},
expected: &bedrock.BedrockConverseRequest{
ModelID: "claude-3-sonnet",
Messages: []bedrock.BedrockMessage{
{
Role: bedrock.BedrockMessageRoleUser,
Content: []bedrock.BedrockContentBlock{
{
Text: schemas.Ptr("Invoke all tools in parallel that are available to you"),
},
},
},
{
Role: bedrock.BedrockMessageRoleAssistant,
Content: []bedrock.BedrockContentBlock{
{
Text: schemas.Ptr("I'll invoke both available tools in parallel for you."),
},
{
ToolUse: &bedrock.BedrockToolUse{
ToolUseID: "tooluse_Yl388l8ES0G_3TQtDcKq_g",
Name: "hello",
Input: map[string]interface{}{},
},
},
{
ToolUse: &bedrock.BedrockToolUse{
ToolUseID: "tooluse_eARDw2iqRXak8uyRC2KxXw",
Name: "world",
Input: map[string]interface{}{},
},
},
},
},
{
Role: bedrock.BedrockMessageRoleUser,
Content: []bedrock.BedrockContentBlock{
{
ToolResult: &bedrock.BedrockToolResult{
ToolUseID: "tooluse_Yl388l8ES0G_3TQtDcKq_g",
Content: []bedrock.BedrockContentBlock{
{
Text: schemas.Ptr("Hello"),
},
},
Status: schemas.Ptr("success"),
},
},
{
ToolResult: &bedrock.BedrockToolResult{
ToolUseID: "tooluse_eARDw2iqRXak8uyRC2KxXw",
Content: []bedrock.BedrockContentBlock{
{
Text: schemas.Ptr("World"),
},
},
Status: schemas.Ptr("success"),
},
},
},
},
},
ToolConfig: &bedrock.BedrockToolConfig{
Tools: []bedrock.BedrockTool{
{
ToolSpec: &bedrock.BedrockToolSpec{
Name: "hello",
Description: schemas.Ptr("Tool extracted from conversation history"),
InputSchema: bedrock.BedrockToolInputSchema{
JSON: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
},
},
},
},
{
ToolSpec: &bedrock.BedrockToolSpec{
Name: "world",
Description: schemas.Ptr("Tool extracted from conversation history"),
InputSchema: bedrock.BedrockToolInputSchema{
JSON: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
},
},
},
},
},
},
},
},
{
name: "NilRequest",
input: nil,
Expand Down Expand Up @@ -478,7 +684,11 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) {
}
} else {
require.NoError(t, err)
assert.Equal(t, tt.expected, actual)
if tt.name == "ParallelToolCalls" {
assertBedrockRequestEqual(t, tt.expected, actual)
} else {
assert.Equal(t, tt.expected, actual)
}
}
})
}
Expand Down
Loading