From 9be46bad6fb9c311d1b52dc2fe81b2101b4e09e0 Mon Sep 17 00:00:00 2001 From: Nathaniel Lombardi Date: Wed, 31 Dec 2025 14:35:55 -0500 Subject: [PATCH] Parse tool usage in bedrock anthropic streaming --- .../executor/ExecutorIntegrationTestBase.kt | 4 - .../clients/bedrock/BedrockLLMClient.kt | 66 +++-- .../modelfamilies/BedrockDataClasses.kt | 7 +- .../BedrockAnthropicClaudeSerialization.kt | 148 ++++++---- ...BedrockAnthropicClaudeSerializationTest.kt | 253 +++++++++++------- 5 files changed, 291 insertions(+), 187 deletions(-) diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt index 7f6397afee..7db0d74153 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt @@ -1132,10 +1132,6 @@ abstract class ExecutorIntegrationTestBase { model.provider !== LLMProvider.OpenRouter, "KG-626 Error from OpenRouter on a streaming with a tool call" ) - assumeTrue( - model.provider !== LLMProvider.Bedrock, - "KG-627 Error from Bedrock executor on a streaming with a tool call" - ) val executor = getExecutor(model) diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt index 2e9a7b5d50..fe74730a53 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt @@ -54,6 +54,7 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.filterNot import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.transform import kotlinx.coroutines.withContext @@ -319,40 +320,57 @@ public class BedrockLLMClient( logger.error(exception) { exception.message } close(exception) } - }.map { chunkJsonString -> - try { - if (chunkJsonString.isBlank()) return@map emptyList() - when (modelFamily) { - is BedrockModelFamilies.AI21Jamba -> BedrockAI21JambaSerialization.parseJambaStreamChunk( - chunkJsonString - ) + }.filterNot { + it.isBlank() + }.run { + when (modelFamily) { + is BedrockModelFamilies.AI21Jamba -> genericProcessStream( + this, + BedrockAI21JambaSerialization::parseJambaStreamChunk + ) - is BedrockModelFamilies.AmazonNova -> BedrockAmazonNovaSerialization.parseNovaStreamChunk( - chunkJsonString - ) + is BedrockModelFamilies.AmazonNova -> genericProcessStream( + this, + BedrockAmazonNovaSerialization::parseNovaStreamChunk + ) - is BedrockModelFamilies.AnthropicClaude -> BedrockAnthropicClaudeSerialization.parseAnthropicStreamChunk( - chunkJsonString - ) + is BedrockModelFamilies.Meta -> genericProcessStream( + this, + BedrockMetaLlamaSerialization::parseLlamaStreamChunk + ) - is BedrockModelFamilies.Meta -> BedrockMetaLlamaSerialization.parseLlamaStreamChunk(chunkJsonString) + is BedrockModelFamilies.AnthropicClaude -> BedrockAnthropicClaudeSerialization.transformAnthropicStreamChunks( + chunkJsonStringFlow = this, + clock = clock, + ) - is BedrockModelFamilies.TitanEmbedding, is BedrockModelFamilies.Cohere -> - throw LLMClientException( - clientName, - "Embedding models do not support streaming chat completions. Use embed() instead." - ) - } + is BedrockModelFamilies.TitanEmbedding, is BedrockModelFamilies.Cohere -> + throw LLMClientException( + clientName, + "Embedding models do not support streaming chat completions. Use embed() instead." + ) + } + } + } + + /** + * Processes a flow of JSON strings into StreamFrames using the provided processor function. + * Handles exceptions by logging and re-throwing them. + */ + private fun genericProcessStream( + chunkJsonStringFlow: Flow, + processor: (String) -> List + ): Flow = + chunkJsonStringFlow.map { chunkJsonString -> + try { + processor(chunkJsonString) } catch (e: Exception) { logger.warn(e) { "Failed to parse Bedrock stream chunk: $chunkJsonString" } throw e } }.transform { frames -> - frames.forEach { - emit(it) - } + frames.forEach { emit(it) } } - } override suspend fun embed(text: String, model: LLModel): List { model.requireCapability(LLMCapability.Embed) diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/BedrockDataClasses.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/BedrockDataClasses.kt index f8872df903..954eaccde2 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/BedrockDataClasses.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/BedrockDataClasses.kt @@ -5,7 +5,6 @@ import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonNames import kotlinx.serialization.json.JsonObject /** @@ -205,7 +204,7 @@ public data class BedrockAnthropicResponse( val role: String, val content: List, val model: String, - @JsonNames("stop_reason") val stopReason: String? = null, + val stopReason: String? = null, val usage: BedrockAnthropicUsage? = null ) @@ -220,6 +219,6 @@ public data class BedrockAnthropicResponse( */ @Serializable public data class BedrockAnthropicUsage( - @SerialName("input_tokens") @JsonNames("inputTokens", "input_tokens") val inputTokens: Int, - @SerialName("output_tokens") @JsonNames("outputTokens", "output_tokens") val outputTokens: Int + val inputTokens: Int, + val outputTokens: Int ) diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt index 9b55eb669e..843ceabbfb 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt @@ -3,7 +3,10 @@ package ai.koog.prompt.executor.clients.bedrock.modelfamilies.anthropic import ai.koog.agents.core.tools.ToolDescriptor import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.executor.clients.anthropic.models.AnthropicContent +import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamDeltaContentType +import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamEventType import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamResponse +import ai.koog.prompt.executor.clients.anthropic.models.AnthropicUsage import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModel import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelContent import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelMessage @@ -15,9 +18,12 @@ import ai.koog.prompt.message.Message import ai.koog.prompt.message.ResponseMetaInfo import ai.koog.prompt.params.LLMParams import ai.koog.prompt.streaming.StreamFrame +import ai.koog.prompt.streaming.buildStreamFrameFlow import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.flow.Flow import kotlinx.datetime.Clock import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonNamingStrategy import kotlinx.serialization.json.buildJsonArray import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.encodeToJsonElement @@ -31,6 +37,7 @@ internal object BedrockAnthropicClaudeSerialization { ignoreUnknownKeys = true isLenient = true explicitNulls = false + namingStrategy = JsonNamingStrategy.SnakeCase } private fun buildMessagesHistory(prompt: Prompt): MutableList { @@ -217,72 +224,103 @@ internal object BedrockAnthropicClaudeSerialization { } } - internal fun parseAnthropicStreamChunk(chunkJsonString: String, clock: Clock = Clock.System): List { - val streamResponse = json.decodeFromString(chunkJsonString) - - return when (streamResponse.type) { - "content_block_delta" -> { - streamResponse.delta?.let { - buildList { - it.text?.let(StreamFrame::Append)?.let(::add) - it.toolUse?.let { toolUse -> - StreamFrame.ToolCall( - id = toolUse.id, - name = toolUse.name, - content = toolUse.input.toString() + internal fun transformAnthropicStreamChunks( + chunkJsonStringFlow: Flow, + clock: Clock = Clock.System + ): Flow = buildStreamFrameFlow { + var inputTokens: Int? = null + var outputTokens: Int? = null + + fun updateUsage(usage: AnthropicUsage) { + inputTokens = usage.inputTokens ?: inputTokens + outputTokens = usage.outputTokens ?: outputTokens + } + + fun getMetaInfo(): ResponseMetaInfo = ResponseMetaInfo.create( + clock = clock, + totalTokensCount = inputTokens?.plus(outputTokens ?: 0) ?: outputTokens, + inputTokensCount = inputTokens, + outputTokensCount = outputTokens, + ) + + chunkJsonStringFlow.collect { chunkJsonString -> + val response = json.decodeFromString(chunkJsonString) + + when (response.type) { + AnthropicStreamEventType.MESSAGE_START.value -> { + response.message?.usage?.let(::updateUsage) + } + + AnthropicStreamEventType.CONTENT_BLOCK_START.value -> { + when (val contentBlock = response.contentBlock) { + is AnthropicContent.Text -> { + emitAppend(contentBlock.text) + } + + is AnthropicContent.ToolUse -> { + upsertToolCall( + index = response.index ?: error("Tool index is missing"), + id = contentBlock.id, + name = contentBlock.name, ) - }?.let(::add) + } + + else -> { + contentBlock?.let { logger.warn { "Unknown Anthropic stream content block type: ${it::class}" } } + ?: logger.warn { "Anthropic stream content block is missing" } + } } - } ?: emptyList() - } + } - "message_delta" -> { - streamResponse.message?.content?.map { content -> - when (content) { - is AnthropicContent.Text -> - StreamFrame.Append(content.text) + AnthropicStreamEventType.CONTENT_BLOCK_DELTA.value -> { + response.delta?.let { delta -> + // Handles deltas for tool calls and text - is AnthropicContent.Thinking -> - StreamFrame.Append(content.thinking) + when (delta.type) { + AnthropicStreamDeltaContentType.INPUT_JSON_DELTA.value -> { + upsertToolCall( + index = response.index ?: error("Tool index is missing"), + args = delta.partialJson ?: error("Tool args are missing") + ) + } - is AnthropicContent.ToolUse -> - StreamFrame.ToolCall( - id = content.id, - name = content.name, - content = content.input.toString() - ) + AnthropicStreamDeltaContentType.TEXT_DELTA.value -> { + emitAppend( + delta.text ?: error("Text delta is missing") + ) + } - else -> throw IllegalArgumentException( - "Unsupported AnthropicContent type in message_delta. Content: $content" - ) + else -> { + logger.warn { "Unknown Anthropic stream delta type: ${delta.type}" } + } + } } - } ?: emptyList() - } + } - "message_start" -> { - val inputTokens = streamResponse.message?.usage?.inputTokens - logger.debug { "Bedrock stream starts. Input tokens: $inputTokens" } - emptyList() - } + AnthropicStreamEventType.CONTENT_BLOCK_STOP.value -> { + tryEmitPendingToolCall() + } - "message_stop" -> { - val inputTokens = streamResponse.message?.usage?.inputTokens - val outputTokens = streamResponse.message?.usage?.outputTokens - logger.debug { "Bedrock stream stops. Output tokens: $outputTokens" } - listOf( - StreamFrame.End( - finishReason = streamResponse.message?.stopReason, - metaInfo = ResponseMetaInfo.create( - clock = clock, - totalTokensCount = inputTokens?.let { it + (outputTokens ?: 0) } ?: outputTokens, - inputTokensCount = inputTokens, - outputTokensCount = outputTokens - ) + AnthropicStreamEventType.MESSAGE_DELTA.value -> { + response.usage?.let(::updateUsage) + emitEnd( + finishReason = response.delta?.stopReason, + metaInfo = getMetaInfo() ) - ) - } + } + + AnthropicStreamEventType.MESSAGE_STOP.value -> { + logger.debug { "Received stop message event from Anthropic" } + } - else -> emptyList() + AnthropicStreamEventType.ERROR.value -> { + error("Anthropic error: ${response.error}") + } + + AnthropicStreamEventType.PING.value -> { + logger.debug { "Received ping from Anthropic" } + } + } } } } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt index 056ea9346c..ee45c99aeb 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt @@ -8,11 +8,13 @@ import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInv import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelContent import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelMessage import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicToolChoice -import ai.koog.prompt.executor.clients.bedrock.modelfamilies.anthropic.BedrockAnthropicClaudeSerialization.parseAnthropicStreamChunk import ai.koog.prompt.message.Message import ai.koog.prompt.message.ResponseMetaInfo import ai.koog.prompt.params.LLMParams import ai.koog.prompt.streaming.StreamFrame +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest import kotlinx.datetime.Clock import kotlinx.datetime.Instant import kotlinx.serialization.json.JsonObject @@ -176,10 +178,10 @@ class BedrockAnthropicClaudeSerializationTest { } ], "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "stopReason": "$stopReason", + "stop_reason": "$stopReason", "usage": { - "inputTokens": 25, - "outputTokens": 20 + "input_tokens": 25, + "output_tokens": 20 } } """.trimIndent() @@ -219,10 +221,10 @@ class BedrockAnthropicClaudeSerializationTest { } ], "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "stopReason": "tool_use", + "stop_reason": "tool_use", "usage": { - "inputTokens": 25, - "outputTokens": 15 + "input_tokens": 25, + "output_tokens": 15 } } """.trimIndent() @@ -268,10 +270,10 @@ class BedrockAnthropicClaudeSerializationTest { } ], "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "stopReason": "tool_use", + "stop_reason": "tool_use", "usage": { - "inputTokens": 25, - "outputTokens": 30 + "input_tokens": 25, + "output_tokens": 30 } } """.trimIndent() @@ -292,112 +294,163 @@ class BedrockAnthropicClaudeSerializationTest { } @Test - fun `parseAnthropicStreamChunk with content_block_delta`() { - val chunkJson = """ - { - "type": "content_block_delta", - "index": 0, - "delta": { - "type": "text_delta", - "text": "Paris is " + fun `transformAnthropicStreamChunks with simple message`() = runTest { + val chunkJsonStringFlow = flowOf( + """ + { + "type" : "content_block_start", + "index" : 0, + "content_block" : { + "type" : "text", + "text" : "hello" + } } - } - """.trimIndent() + """.trimIndent(), + """ + { + "type" : "content_block_delta", + "index" : 0, + "delta" : { + "type" : "text_delta", + "text" : "world" + } + } + """.trimIndent(), + """ + { + "type" : "content_block_stop", + "index" : 0 + } + """.trimIndent(), + ) - val content = parseAnthropicStreamChunk(chunkJson) - assertEquals(listOf("Paris is ").map(StreamFrame::Append), content) + val content = + BedrockAnthropicClaudeSerialization.transformAnthropicStreamChunks(chunkJsonStringFlow, mockClock).toList() + val expected = listOf( + StreamFrame.Append("hello"), + StreamFrame.Append("world"), + ) + assertEquals(expected, content) } @Test - fun `parseAnthropicStreamChunk with message_delta`() { - val chunkJson = """ - { - "type": "message_delta", - "delta": { - "type": "text_delta", - "stopReason": "end_turn" - }, - "message": { - "id": "msg_01234567", - "type": "message", - "role": "assistant", - "content": [ - { - "type": "text", - "text": "the capital of France." + fun `transformAnthropicStreamChunks with metainfo`() = runTest { + val stopReason = "end_turn" + val chunkJsonStringFlow = flowOf( + """ + { + "type" : "message_start", + "message" : { + "model" : "claude-3-5-haiku-20241022", + "id" : "msg_12345", + "type" : "message", + "role" : "assistant", + "content" : [ ], + "stop_reason" : null, + "stop_sequence" : null, + "usage" : { + "input_tokens" : 22, + "cache_creation_input_tokens" : 0, + "cache_read_input_tokens" : 0, + "output_tokens" : 3 } - ], - "model": "anthropic.claude-3-sonnet-20240229-v1:0" + } } - } - """.trimIndent() - - val content = parseAnthropicStreamChunk(chunkJson) - assertEquals(listOf("the capital of France.").map(StreamFrame::Append), content) - } - - @Test - fun `parseAnthropicStreamChunk with message_start`() { - val chunkJson = """ - { - "type": "message_start", - "message": { - "id": "msg_01234567", - "type": "message", - "role": "assistant", - "content": [], - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "usage": { - "inputTokens": 25, - "outputTokens": 0 + """.trimIndent(), + """ + { + "type" : "message_delta", + "delta" : { + "stop_reason" : "$stopReason", + "stop_sequence" : null + }, + "usage" : { + "output_tokens" : 13 } } - } - """.trimIndent() + """.trimIndent(), + """ + { + "type" : "message_stop", + "amazon-bedrock-invocationMetrics" : { + "inputTokenCount" : 22, + "outputTokenCount" : 13, + "invocationLatency" : 536, + "firstByteLatency" : 421 + } + } + """.trimIndent() + ) - val content = parseAnthropicStreamChunk(chunkJson) - assertEquals(emptyList(), content) + val content = + BedrockAnthropicClaudeSerialization.transformAnthropicStreamChunks(chunkJsonStringFlow, mockClock).toList() + val expected = listOf( + StreamFrame.End( + finishReason = stopReason, + metaInfo = ResponseMetaInfo.create( + clock = mockClock, + totalTokensCount = 35, + inputTokensCount = 22, + outputTokensCount = 13 + ) + ) + ) + assertEquals(expected, content) } @Test - fun `parseAnthropicStreamChunk with message_stop`() { - val chunkJson = """ - { - "type": "message_stop", - "message": { - "id": "msg_01234567", - "type": "message", - "role": "assistant", - "content": [ - { - "type": "text", - "text": "Paris is the capital of France." - } - ], - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "stopReason": "end_turn", - "usage": { - "inputTokens": 25, - "outputTokens": 20 + fun `transformAnthropicStreamChunks with single tool call`() = runTest { + val chunkJsonStringFlow = flowOf( + """ + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "tool_use", + "id": "$toolId", + "name": "$toolName", + "input": {} } } - } - """.trimIndent() - val content = parseAnthropicStreamChunk(chunkJson, mockClock) - assertEquals( - expected = listOf( - StreamFrame.End( - finishReason = "end_turn", - metaInfo = ResponseMetaInfo.create( - clock = mockClock, - totalTokensCount = 45, - inputTokensCount = 25, - outputTokensCount = 20 - ) - ) - ), - actual = content + """.trimIndent(), + """ + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "input_json_delta", + "partial_json": "{\"location\":" + } + } + """.trimIndent(), + """ + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "input_json_delta", + "partial_json": "\"Paris\"}" + } + } + """.trimIndent(), + """ + { + "type": "content_block_stop", + "index": 0 + } + """.trimIndent() + ) + + val content = + BedrockAnthropicClaudeSerialization.transformAnthropicStreamChunks(chunkJsonStringFlow, mockClock).toList() + val expected = listOf( + StreamFrame.ToolCall( + id = toolId, + name = toolName, + content = "{\"location\":\"Paris\"}" + ) ) + assertEquals(expected, content) } @Test