@@ -1132,10 +1132,6 @@ abstract class ExecutorIntegrationTestBase {
model.provider !== LLMProvider.OpenRouter,
"KG-626 Error from OpenRouter on a streaming with a tool call"
)
assumeTrue(
model.provider !== LLMProvider.Bedrock,
"KG-627 Error from Bedrock executor on a streaming with a tool call"
)

val executor = getExecutor(model)

@@ -54,6 +54,7 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.filterNot
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.transform
import kotlinx.coroutines.withContext
@@ -319,40 +320,57 @@ public class BedrockLLMClient(
logger.error(exception) { exception.message }
close(exception)
}
}.map { chunkJsonString ->
try {
if (chunkJsonString.isBlank()) return@map emptyList()
when (modelFamily) {
is BedrockModelFamilies.AI21Jamba -> BedrockAI21JambaSerialization.parseJambaStreamChunk(
chunkJsonString
)
}.filterNot {
it.isBlank()
}.run {
when (modelFamily) {
is BedrockModelFamilies.AI21Jamba -> genericProcessStream(
this,
BedrockAI21JambaSerialization::parseJambaStreamChunk
)

is BedrockModelFamilies.AmazonNova -> BedrockAmazonNovaSerialization.parseNovaStreamChunk(
chunkJsonString
)
is BedrockModelFamilies.AmazonNova -> genericProcessStream(
this,
BedrockAmazonNovaSerialization::parseNovaStreamChunk
)

is BedrockModelFamilies.AnthropicClaude -> BedrockAnthropicClaudeSerialization.parseAnthropicStreamChunk(
chunkJsonString
)
is BedrockModelFamilies.Meta -> genericProcessStream(
this,
BedrockMetaLlamaSerialization::parseLlamaStreamChunk
)

is BedrockModelFamilies.Meta -> BedrockMetaLlamaSerialization.parseLlamaStreamChunk(chunkJsonString)
is BedrockModelFamilies.AnthropicClaude -> BedrockAnthropicClaudeSerialization.transformAnthropicStreamChunks(
chunkJsonStringFlow = this,
clock = clock,
)

is BedrockModelFamilies.TitanEmbedding, is BedrockModelFamilies.Cohere ->
throw LLMClientException(
clientName,
"Embedding models do not support streaming chat completions. Use embed() instead."
)
}
is BedrockModelFamilies.TitanEmbedding, is BedrockModelFamilies.Cohere ->
throw LLMClientException(
clientName,
"Embedding models do not support streaming chat completions. Use embed() instead."
)
}
}
}

/**
* Processes a flow of JSON strings into StreamFrames using the provided processor function.
* Handles exceptions by logging and re-throwing them.
*/
private fun genericProcessStream(
chunkJsonStringFlow: Flow<String>,
processor: (String) -> List<StreamFrame>
): Flow<StreamFrame> =
chunkJsonStringFlow.map { chunkJsonString ->
try {
processor(chunkJsonString)
} catch (e: Exception) {
logger.warn(e) { "Failed to parse Bedrock stream chunk: $chunkJsonString" }
throw e
}
}.transform { frames ->
frames.forEach {
emit(it)
}
frames.forEach { emit(it) }
}
}

override suspend fun embed(text: String, model: LLModel): List<Double> {
model.requireCapability(LLMCapability.Embed)
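
A note on the refactor above: genericProcessStream centralizes the parse-and-flatten shape that the per-family branches previously repeated inline, parsing each raw chunk into a List<StreamFrame> and flattening those lists into a single Flow<StreamFrame>. Below is a minimal, self-contained sketch of that shape; Frame, parseChunk, and processStream are hypothetical stand-ins, not the PR's actual types.

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.filterNot
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.transform
import kotlinx.coroutines.runBlocking

// Hypothetical stand-ins: the real code uses StreamFrame and the
// BedrockXxxSerialization::parseXxxStreamChunk function references.
data class Frame(val text: String)

fun parseChunk(chunk: String): List<Frame> = chunk.split(",").map(::Frame)

// Same shape as genericProcessStream: skip blank chunks, parse each chunk into
// a list of frames, then flatten the lists into one Flow of frames.
fun processStream(chunks: Flow<String>, parser: (String) -> List<Frame>): Flow<Frame> =
    chunks
        .filterNot { it.isBlank() }
        .map { parser(it) }
        .transform { frames -> frames.forEach { emit(it) } }

fun main() = runBlocking {
    processStream(flowOf("a,b", "", "c"), ::parseChunk)
        .collect { println(it) } // Frame(text=a), Frame(text=b), Frame(text=c)
}
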
@@ -5,7 +5,6 @@ import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNames
import kotlinx.serialization.json.JsonObject

/**
@@ -205,7 +204,7 @@ public data class BedrockAnthropicResponse(
val role: String,
val content: List<AnthropicContent>,
val model: String,
@JsonNames("stop_reason") val stopReason: String? = null,
val stopReason: String? = null,
val usage: BedrockAnthropicUsage? = null
)

@@ -220,6 +219,6 @@ public data class BedrockAnthropicResponse(
*/
@Serializable
public data class BedrockAnthropicUsage(
@SerialName("input_tokens") @JsonNames("inputTokens", "input_tokens") val inputTokens: Int,
@SerialName("output_tokens") @JsonNames("outputTokens", "output_tokens") val outputTokens: Int
val inputTokens: Int,
val outputTokens: Int
)
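
The @SerialName/@JsonNames annotations dropped above become redundant because the Claude serializer's Json instance (further down in this diff) now sets namingStrategy = JsonNamingStrategy.SnakeCase, which maps camelCase Kotlin properties to snake_case JSON keys in both directions. A minimal sketch under that assumption; the Usage class here is hypothetical, not part of the PR.

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonNamingStrategy

// Hypothetical model, not part of the PR: camelCase properties, no per-field
// @SerialName("input_tokens") / @JsonNames(...) annotations required.
@Serializable
data class Usage(val inputTokens: Int, val outputTokens: Int)

@OptIn(ExperimentalSerializationApi::class)
val json = Json {
    ignoreUnknownKeys = true
    namingStrategy = JsonNamingStrategy.SnakeCase
}

fun main() {
    // snake_case keys in the payload map onto the camelCase properties.
    val usage = json.decodeFromString(Usage.serializer(), """{"input_tokens": 12, "output_tokens": 34}""")
    println(usage) // Usage(inputTokens=12, outputTokens=34)
    println(json.encodeToString(Usage.serializer(), usage)) // {"input_tokens":12,"output_tokens":34}
}
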
@@ -3,7 +3,10 @@ package ai.koog.prompt.executor.clients.bedrock.modelfamilies.anthropic
import ai.koog.agents.core.tools.ToolDescriptor
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.executor.clients.anthropic.models.AnthropicContent
import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamDeltaContentType
import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamEventType
import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamResponse
import ai.koog.prompt.executor.clients.anthropic.models.AnthropicUsage
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModel
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelContent
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelMessage
@@ -15,9 +18,12 @@ import ai.koog.prompt.message.Message
import ai.koog.prompt.message.ResponseMetaInfo
import ai.koog.prompt.params.LLMParams
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.streaming.buildStreamFrameFlow
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.flow.Flow
import kotlinx.datetime.Clock
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonNamingStrategy
import kotlinx.serialization.json.buildJsonArray
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.encodeToJsonElement
@@ -31,6 +37,7 @@ internal object BedrockAnthropicClaudeSerialization {
ignoreUnknownKeys = true
isLenient = true
explicitNulls = false
namingStrategy = JsonNamingStrategy.SnakeCase
}

private fun buildMessagesHistory(prompt: Prompt): MutableList<BedrockAnthropicInvokeModelMessage> {
@@ -217,72 +224,103 @@ internal object BedrockAnthropicClaudeSerialization {
}
}

internal fun parseAnthropicStreamChunk(chunkJsonString: String, clock: Clock = Clock.System): List<StreamFrame> {
val streamResponse = json.decodeFromString<AnthropicStreamResponse>(chunkJsonString)

return when (streamResponse.type) {
"content_block_delta" -> {
streamResponse.delta?.let {
buildList {
it.text?.let(StreamFrame::Append)?.let(::add)
it.toolUse?.let { toolUse ->
StreamFrame.ToolCall(
id = toolUse.id,
name = toolUse.name,
content = toolUse.input.toString()
internal fun transformAnthropicStreamChunks(
chunkJsonStringFlow: Flow<String>,
clock: Clock = Clock.System
): Flow<StreamFrame> = buildStreamFrameFlow {
var inputTokens: Int? = null
var outputTokens: Int? = null

fun updateUsage(usage: AnthropicUsage) {
inputTokens = usage.inputTokens ?: inputTokens
outputTokens = usage.outputTokens ?: outputTokens
}

fun getMetaInfo(): ResponseMetaInfo = ResponseMetaInfo.create(
clock = clock,
totalTokensCount = inputTokens?.plus(outputTokens ?: 0) ?: outputTokens,
inputTokensCount = inputTokens,
outputTokensCount = outputTokens,
)

chunkJsonStringFlow.collect { chunkJsonString ->
val response = json.decodeFromString<AnthropicStreamResponse>(chunkJsonString)

when (response.type) {
AnthropicStreamEventType.MESSAGE_START.value -> {
response.message?.usage?.let(::updateUsage)
}

AnthropicStreamEventType.CONTENT_BLOCK_START.value -> {
when (val contentBlock = response.contentBlock) {
is AnthropicContent.Text -> {
emitAppend(contentBlock.text)
}

is AnthropicContent.ToolUse -> {
upsertToolCall(
index = response.index ?: error("Tool index is missing"),
id = contentBlock.id,
name = contentBlock.name,
)
}?.let(::add)
}

else -> {
contentBlock?.let { logger.warn { "Unknown Anthropic stream content block type: ${it::class}" } }
?: logger.warn { "Anthropic stream content block is missing" }
}
}
} ?: emptyList()
}
}

"message_delta" -> {
streamResponse.message?.content?.map { content ->
when (content) {
is AnthropicContent.Text ->
StreamFrame.Append(content.text)
AnthropicStreamEventType.CONTENT_BLOCK_DELTA.value -> {
response.delta?.let { delta ->
// Handles deltas for tool calls and text

is AnthropicContent.Thinking ->
StreamFrame.Append(content.thinking)
when (delta.type) {
AnthropicStreamDeltaContentType.INPUT_JSON_DELTA.value -> {
upsertToolCall(
index = response.index ?: error("Tool index is missing"),
args = delta.partialJson ?: error("Tool args are missing")
)
}

is AnthropicContent.ToolUse ->
StreamFrame.ToolCall(
id = content.id,
name = content.name,
content = content.input.toString()
)
AnthropicStreamDeltaContentType.TEXT_DELTA.value -> {
emitAppend(
delta.text ?: error("Text delta is missing")
)
}

else -> throw IllegalArgumentException(
"Unsupported AnthropicContent type in message_delta. Content: $content"
)
else -> {
logger.warn { "Unknown Anthropic stream delta type: ${delta.type}" }
}
}
}
} ?: emptyList()
}
}

"message_start" -> {
val inputTokens = streamResponse.message?.usage?.inputTokens
logger.debug { "Bedrock stream starts. Input tokens: $inputTokens" }
emptyList()
}
AnthropicStreamEventType.CONTENT_BLOCK_STOP.value -> {
tryEmitPendingToolCall()
}

"message_stop" -> {
val inputTokens = streamResponse.message?.usage?.inputTokens
val outputTokens = streamResponse.message?.usage?.outputTokens
logger.debug { "Bedrock stream stops. Output tokens: $outputTokens" }
listOf(
StreamFrame.End(
finishReason = streamResponse.message?.stopReason,
metaInfo = ResponseMetaInfo.create(
clock = clock,
totalTokensCount = inputTokens?.let { it + (outputTokens ?: 0) } ?: outputTokens,
inputTokensCount = inputTokens,
outputTokensCount = outputTokens
)
AnthropicStreamEventType.MESSAGE_DELTA.value -> {
response.usage?.let(::updateUsage)
emitEnd(
finishReason = response.delta?.stopReason,
metaInfo = getMetaInfo()
)
)
}
}

AnthropicStreamEventType.MESSAGE_STOP.value -> {
logger.debug { "Received stop message event from Anthropic" }
}

else -> emptyList()
AnthropicStreamEventType.ERROR.value -> {
error("Anthropic error: ${response.error}")
}

AnthropicStreamEventType.PING.value -> {
logger.debug { "Received ping from Anthropic" }
}
}
}
}
}
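
For context on the new streaming transform: token usage is accumulated across message_start and message_delta events and attached to the final End frame, while tool calls arrive in pieces — content_block_start carries the id and name, input_json_delta events carry partial JSON arguments, and content_block_stop flushes the completed call. The sketch below illustrates only that upsert-then-flush accumulation, using hypothetical event and frame types; it does not use the real buildStreamFrameFlow API.

// Hypothetical, simplified stand-ins for the Anthropic stream events and frames;
// the real code works on AnthropicStreamResponse inside buildStreamFrameFlow.
sealed interface Event
data class BlockStart(val index: Int, val toolId: String, val toolName: String) : Event
data class JsonDelta(val index: Int, val partialJson: String) : Event
data class BlockStop(val index: Int) : Event

data class ToolCall(val id: String, val name: String, val args: String)

// Accumulates partial tool-call JSON per content-block index and flushes the
// completed call when the block stops, mirroring upsertToolCall/tryEmitPendingToolCall.
class ToolCallAccumulator {
    private data class Pending(val id: String, val name: String, val args: StringBuilder = StringBuilder())
    private val pending = mutableMapOf<Int, Pending>()

    fun onEvent(event: Event): ToolCall? = when (event) {
        is BlockStart -> { pending[event.index] = Pending(event.toolId, event.toolName); null }
        is JsonDelta -> { pending[event.index]?.args?.append(event.partialJson); null }
        is BlockStop -> pending.remove(event.index)?.let { ToolCall(it.id, it.name, it.args.toString()) }
    }
}

fun main() {
    val acc = ToolCallAccumulator()
    val events = listOf(
        BlockStart(1, "toolu_01", "get_weather"),
        JsonDelta(1, """{"city":"""),
        JsonDelta(1, """ "Paris"}"""),
        BlockStop(1),
    )
    events.mapNotNull(acc::onEvent).forEach(::println)
    // ToolCall(id=toolu_01, name=get_weather, args={"city": "Paris"})
}
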