diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriber.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriber.kt index cf4ffcd18b..d4b21e981c 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriber.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriber.kt @@ -61,7 +61,10 @@ public interface ToolCallDescriber { buildJsonObject { message.id?.let { put("tool_call_id", JsonPrimitive(it)) } put("tool_name", JsonPrimitive(message.tool)) - put("tool_args", message.contentJson) + message.contentJsonResult.fold( + onSuccess = { put("tool_args", it) }, + onFailure = { put("tool_args_error", JsonPrimitive("Failed to parse tool arguments: $it")) } + ) } ), metaInfo = message.metaInfo diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentEdges.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentEdges.kt index 5277a84d6c..5798b5a083 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentEdges.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentEdges.kt @@ -8,6 +8,7 @@ import ai.koog.agents.core.environment.toSafeResult import ai.koog.agents.core.tools.Tool import ai.koog.prompt.message.ContentPart import ai.koog.prompt.message.Message +import kotlin.coroutines.cancellation.CancellationException import kotlin.reflect.KClass /** @@ -98,7 +99,13 @@ public inline fun - val args = tool.decodeArgs(toolCall.contentJson) + val args = try { + tool.decodeArgs(toolCall.contentJsonResult.getOrNull() ?: return@onCondition false) + } catch (e: CancellationException) { + throw e + } catch (_: Exception) { + return@onCondition false + } block(args) } } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/ContextualAgentEnvironment.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/ContextualAgentEnvironment.kt index 8bee510fa1..e10cfb7b74 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/ContextualAgentEnvironment.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/ContextualAgentEnvironment.kt @@ -2,8 +2,11 @@ package ai.koog.agents.core.environment import ai.koog.agents.core.agent.context.AIAgentContext import ai.koog.agents.core.agent.execution.AgentExecutionInfo +import ai.koog.agents.core.feature.model.toAgentError import ai.koog.prompt.message.Message import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.serialization.json.JsonObject +import kotlin.coroutines.cancellation.CancellationException import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -20,6 +23,40 @@ public class ContextualAgentEnvironment( override suspend fun executeTool(toolCall: Message.Tool.Call): ReceivedToolResult { @OptIn(ExperimentalUuidApi::class) val eventId = Uuid.random().toString() + val toolDescription = context.llm.toolRegistry.getToolOrNull(toolCall.tool)?.descriptor?.description + + val toolArgs = try { + toolCall.contentJson + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + logger.error { "Failed to execute tool call with id '${toolCall.id}' while parsing args: ${e.message}" } + + val tool = toolCall.tool + val toolArgs = JsonObject(emptyMap()) + val message = "Failed to parse tool arguments: ${e.message}" + context.pipeline.onToolValidationFailed( + eventId = eventId, + executionInfo = context.executionInfo, + runId = context.runId, + toolCallId = toolCall.id, + toolName = tool, + toolDescription = toolDescription, + toolArgs = toolArgs, + message = message, + error = e.toAgentError(), + context = context + ) + return ReceivedToolResult( + id = toolCall.id, + tool = tool, + toolArgs = toolArgs, + toolDescription = null, + content = message, + resultKind = ToolResultKind.ValidationError(e.toAgentError()), + result = null + ) + } logger.trace { "Executing tool call (" + @@ -27,7 +64,7 @@ public class ContextualAgentEnvironment( "run id: ${context.runId}, " + "tool call id: ${toolCall.id}, " + "tool: ${toolCall.tool}, " + - "args: ${toolCall.contentJson})" + "args: $toolArgs)" } context.pipeline.onToolCallStarting( @@ -36,8 +73,8 @@ public class ContextualAgentEnvironment( runId = context.runId, toolCallId = toolCall.id, toolName = toolCall.tool, - toolDescription = context.llm.toolRegistry.getToolOrNull(toolCall.tool)?.descriptor?.description, - toolArgs = toolCall.contentJson, + toolDescription = toolDescription, + toolArgs = toolArgs, context = context ) @@ -52,7 +89,7 @@ public class ContextualAgentEnvironment( "tool call id: ${toolCall.id}, " + "tool: ${toolCall.tool}, " + "tool description: ${toolResult.toolDescription}, " + - "args: ${toolCall.contentJson}) " + + "args: $toolArgs) " + "with result: $toolResult" } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt index 2dec105e78..e35cc3e0fd 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt @@ -7,6 +7,8 @@ import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi import ai.koog.prompt.message.Message import io.github.oshai.kotlinlogging.KLogger +import kotlinx.serialization.json.JsonObject +import kotlin.coroutines.cancellation.CancellationException /** * Represents base agent environment with generic abstractions. @@ -19,7 +21,7 @@ public class GenericAgentEnvironment( override suspend fun executeTool(toolCall: Message.Tool.Call): ReceivedToolResult { logger.info { - formatLog("Executing tool (name: ${toolCall.tool}, args: ${toolCall.contentJson})") + formatLog("Executing tool (name: ${toolCall.tool}, args: ${toolCall.contentJsonResult.getOrElse { "Failed to parse tool arguments: ${it.message}" }})") } val environmentToolResult = processToolCall(toolCall) @@ -45,7 +47,21 @@ public class GenericAgentEnvironment( // Tool val id = toolCall.id val toolName = toolCall.tool - val toolArgsJson = toolCall.contentJson + val toolArgsJson = try { + toolCall.contentJson + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + return ReceivedToolResult( + id = id, + tool = toolName, + toolArgs = JsonObject(emptyMap()), + toolDescription = null, + content = "Tool with name '$toolName' failed to parse arguments due to the error: ${e.message}", + resultKind = ToolResultKind.Failure(e.toAgentError()), + result = null, + ) + } val tool = toolRegistry.getToolOrNull(toolName) ?: run { @@ -66,6 +82,8 @@ public class GenericAgentEnvironment( // Tool Args val toolArgs = try { tool.decodeArgs(toolArgsJson) + } catch (e: CancellationException) { + throw e } catch (e: Exception) { logger.error(e) { formatLog("Tool with name '$toolName' failed to parse arguments: $toolArgsJson") } return ReceivedToolResult( @@ -82,6 +100,8 @@ public class GenericAgentEnvironment( val toolResult = try { @Suppress("UNCHECKED_CAST") (tool as Tool).execute(toolArgs) + } catch (e: CancellationException) { + throw e } catch (e: ToolException) { return ReceivedToolResult( id = id, @@ -108,14 +128,31 @@ public class GenericAgentEnvironment( logger.trace { "Completed execution of the tool '$toolName' with result: $toolResult" } + val (content, result) = try { + tool.encodeResultToStringUnsafe(toolResult) to tool.encodeResult(toolResult) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + logger.error(e) { "Tool with name '$toolName' failed to encode result: $toolResult" } + return ReceivedToolResult( + id = id, + tool = toolName, + toolArgs = toolArgsJson, + toolDescription = toolDescription, + content = "Tool with name '$toolName' failed to serialize result due to the error: ${e.message}!", + resultKind = ToolResultKind.Failure(e.toAgentError()), + result = null + ) + } + return ReceivedToolResult( id = id, tool = toolName, toolArgs = toolArgsJson, toolDescription = toolDescription, - content = tool.encodeResultToStringUnsafe(toolResult), + content = content, resultKind = ToolResultKind.Success, - result = tool.encodeResult(toolResult) + result = result ) } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/SafeTool.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/SafeTool.kt index 9c09933dc9..995a7acce9 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/SafeTool.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/SafeTool.kt @@ -6,6 +6,7 @@ import ai.koog.agents.core.tools.Tool import ai.koog.prompt.message.Message import ai.koog.prompt.message.ResponseMetaInfo import kotlinx.datetime.Clock +import kotlin.coroutines.cancellation.CancellationException /** * A wrapper class designed to safely execute a tool within a given AI agent environment. @@ -190,11 +191,15 @@ public data class SafeTool( * @return A [SafeTool.Result] which will either be a [SafeTool.Result.Failure] or [SafeTool.Result.Success] * based on the presence and validity of the `result` in the [ReceivedToolResult]. */ -public fun ReceivedToolResult.toSafeResult(tool: Tool<*, TResult>): SafeTool.Result = when (result) { - null -> { - SafeTool.Result.Failure(message = content) - } - else -> { - SafeTool.Result.Success(result = tool.decodeResult(this.result), content = content) +public fun ReceivedToolResult.toSafeResult(tool: Tool<*, TResult>): SafeTool.Result { + val encodedResult = result ?: return SafeTool.Result.Failure(message = content) + val decodedResult = try { + tool.decodeResult(encodedResult) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + return SafeTool.Result.Failure("Tool with name '${tool.name}' failed to deserialize result with error: ${e.message}") } + + return SafeTool.Result.Success(result = decodedResult, content = content) } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt index d5a27f16ea..eb4b80b0d4 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt @@ -17,6 +17,7 @@ import ai.koog.agents.core.dsl.extension.setToolChoiceRequired import ai.koog.agents.core.environment.ReceivedToolResult import ai.koog.agents.core.environment.ToolResultKind import ai.koog.agents.core.environment.toSafeResult +import ai.koog.agents.core.feature.model.toAgentError import ai.koog.agents.core.tools.Tool import ai.koog.agents.core.tools.ToolDescriptor import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi @@ -28,7 +29,9 @@ import ai.koog.prompt.message.Message import ai.koog.prompt.params.LLMParams import ai.koog.prompt.processor.ResponseProcessor import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.serializer +import kotlin.coroutines.cancellation.CancellationException import kotlin.reflect.KClass /** @@ -586,7 +589,7 @@ public inline fun AIA edge( callToolsHacked forwardTo finalizeTask - onCondition { toolResults -> toolResults.firstOrNull()?.let { it.tool == finishTool.name } == true } + onCondition { toolResults -> toolResults.firstOrNull()?.let { it.tool == finishTool.name && it.resultKind is ToolResultKind.Success } == true } transformed { toolsResults -> toolsResults.first() } ) @@ -603,11 +606,26 @@ internal suspend fun AIAgentContext.executeFinishToo toolCall: Message.Tool.Call, finishTool: Tool, ): ReceivedToolResult { + val toolDescription = finishTool.descriptor.description // Execute Finish tool directly and get a result - val args = finishTool.decodeArgs(toolCall.contentJson) - val toolResult = finishTool.execute(args = args) + val encodedResult = try { + val args = finishTool.decodeArgs(toolCall.contentJson) + val toolResult = finishTool.execute(args = args) + finishTool.encodeResult(toolResult) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + return ReceivedToolResult( + id = toolCall.id, + tool = finishTool.name, + toolArgs = toolCall.contentJsonResult.getOrElse { JsonObject(emptyMap()) }, + toolDescription = toolDescription, + content = "Failed to execute '${finishTool.name}' with error: ${e.message}'", + resultKind = ToolResultKind.Failure(e.toAgentError()), + result = null, + ) + } - val encodedResult = finishTool.encodeResult(toolResult) // Append a final tool call result to the prompt for further LLM calls // to see it (otherwise they would fail) llm.writeSession { @@ -624,7 +642,7 @@ internal suspend fun AIAgentContext.executeFinishToo toolArgs = toolCall.contentJson, content = toolCall.content, resultKind = ToolResultKind.Success, - toolDescription = finishTool.descriptor.description, + toolDescription = toolDescription, result = encodedResult ) } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriberTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriberTest.kt index dade61a0d8..2f53f5c0f7 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriberTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/config/ToolCallDescriberTest.kt @@ -7,7 +7,6 @@ import kotlinx.datetime.Clock import kotlinx.datetime.Instant.Companion.fromEpochMilliseconds import kotlin.test.Test import kotlin.test.assertEquals -import kotlin.test.assertFailsWith import kotlin.test.assertTrue class ToolCallDescriberTest { @@ -162,9 +161,12 @@ class ToolCallDescriberTest { metaInfo = ResponseMetaInfo.create(testClock), ) - assertFailsWith { - describer.describeToolCall(invalidJsonToolCall) - } + val result = describer.describeToolCall(invalidJsonToolCall) + + assertTrue(result.content.contains("\"tool_call_id\":\"test-call-id\"")) + assertTrue(result.content.contains("\"tool_name\":\"test-tool\"")) + assertTrue(result.content.contains("\"tool_args_error\":\"Failed to parse tool arguments:")) + assertEquals(invalidJsonToolCall.metaInfo, result.metaInfo) } @Test @@ -227,9 +229,13 @@ class ToolCallDescriberTest { metaInfo = ResponseMetaInfo.create(testClock), ) - assertFailsWith { - describer.describeToolCall(nullContentToolCall) - } + val result = describer.describeToolCall(nullContentToolCall) + + assertTrue(result.content.contains("\"tool_call_id\":\"test-call-id\"")) + assertTrue(result.content.contains("\"tool_name\":\"test-tool\"")) + assertTrue(result.content.contains("\"tool_args_error\":\"Failed to parse tool arguments:")) + assertTrue(result.content.contains("IllegalArgumentException")) + assertEquals(nullContentToolCall.metaInfo, result.metaInfo) } @Test @@ -281,9 +287,12 @@ class ToolCallDescriberTest { metaInfo = ResponseMetaInfo.create(testClock), ) - assertFailsWith { - describer.describeToolCall(nonJsonToolCall) - } + val result = describer.describeToolCall(nonJsonToolCall) + + assertTrue(result.content.contains("\"tool_call_id\":\"test-call-id\"")) + assertTrue(result.content.contains("\"tool_name\":\"test-tool\"")) + assertTrue(result.content.contains("\"tool_args_error\":\"Failed to parse tool arguments:")) + assertEquals(nonJsonToolCall.metaInfo, result.metaInfo) } @Test diff --git a/agents/agents-core/src/jvmMain/kotlin/ai/koog/agents/core/environment/SafeTool.jvm.kt b/agents/agents-core/src/jvmMain/kotlin/ai/koog/agents/core/environment/SafeTool.jvm.kt index 90200a3d88..0afa483a97 100644 --- a/agents/agents-core/src/jvmMain/kotlin/ai/koog/agents/core/environment/SafeTool.jvm.kt +++ b/agents/agents-core/src/jvmMain/kotlin/ai/koog/agents/core/environment/SafeTool.jvm.kt @@ -8,6 +8,7 @@ import ai.koog.agents.core.tools.reflect.asTool import ai.koog.prompt.message.Message import ai.koog.prompt.message.ResponseMetaInfo import kotlinx.datetime.Clock +import kotlin.coroutines.cancellation.CancellationException import kotlin.reflect.KFunction /** @@ -187,11 +188,15 @@ public data class SafeToolFromCallable( * @return A `SafeToolFromCallable.Result` object, either a `Success` with the extracted result * and content or a `Failure` with an appropriate message. */ -private fun ReceivedToolResult.toSafeResultFromCallable(tool: Tool): SafeToolFromCallable.Result = - when (result) { - null -> SafeToolFromCallable.Result.Failure(message = content) - else -> SafeToolFromCallable.Result.Success( - result = tool.decodeResult(result), - content = content - ) +private fun ReceivedToolResult.toSafeResultFromCallable(tool: Tool): SafeToolFromCallable.Result { + val encodedResult = result ?: return SafeToolFromCallable.Result.Failure(message = content) + val decodedResult = try { + tool.decodeResult(encodedResult) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + return SafeToolFromCallable.Result.Failure("Tool with name '${tool.name}' failed to deserialize result with error: ${e.message}") } + + return SafeToolFromCallable.Result.Success(result = decodedResult, content = content) +} diff --git a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironmentTest.kt b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironmentTest.kt new file mode 100644 index 0000000000..8f8d9e08c3 --- /dev/null +++ b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironmentTest.kt @@ -0,0 +1,65 @@ +package ai.koog.agents.core.environment + +import ai.koog.agents.core.tools.SimpleTool +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.prompt.message.Message.Tool +import ai.koog.prompt.message.ResponseMetaInfo +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class GenericAgentEnvironmentTest { + @Serializable + private data class RequiredArgs(val required: String) + + private class RequiredArgsTool : SimpleTool( + argsSerializer = RequiredArgs.serializer(), + name = "required_args", + description = "Tool that requires a single argument.", + ) { + override suspend fun execute(args: RequiredArgs): String = "Ok" + } + + @Test + fun testInvalidJsonArgsReturnsFailure() = runTest { + val environment = GenericAgentEnvironment( + agentId = "test_agent", + logger = KotlinLogging.logger { }, + toolRegistry = ToolRegistry { tool(RequiredArgsTool()) }, + ) + + val toolCall = Tool.Call( + id = "1", + tool = "required_args", + content = "not-json", + metaInfo = ResponseMetaInfo.Empty, + ) + + val result = environment.executeTool(toolCall) + assertEquals("required_args", result.tool) + assertTrue(result.resultKind is ToolResultKind.Failure) + } + + @Test + fun testMissingFieldReturnsFailure() = runTest { + val environment = GenericAgentEnvironment( + agentId = "test_agent", + logger = KotlinLogging.logger { }, + toolRegistry = ToolRegistry { tool(RequiredArgsTool()) }, + ) + + val toolCall = Tool.Call( + id = "1", + tool = "required_args", + content = "{}", + metaInfo = ResponseMetaInfo.Empty, + ) + + val result = environment.executeTool(toolCall) + assertEquals("required_args", result.tool) + assertTrue(result.resultKind is ToolResultKind.Failure) + } +} diff --git a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment/SafeToolTest.kt b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment/SafeToolTest.kt index 117fc20546..715fffac6f 100644 --- a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment/SafeToolTest.kt +++ b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment/SafeToolTest.kt @@ -2,13 +2,18 @@ package ai.koog.agents.core.environment import ai.koog.agents.core.CalculatorChatExecutor.testClock import ai.koog.agents.core.feature.model.toAgentError +import ai.koog.agents.core.tools.Tool import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi import ai.koog.prompt.message.Message import kotlinx.coroutines.test.runTest import kotlinx.serialization.Serializable +import kotlinx.serialization.builtins.serializer import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.encodeToJsonElement import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.put import kotlinx.serialization.serializer import org.junit.jupiter.api.assertThrows import kotlin.test.Test @@ -109,6 +114,15 @@ class SafeToolTest { } } + private object StringEchoTool : Tool( + argsSerializer = String.serializer(), + resultSerializer = String.serializer(), + name = "string_echo", + description = "String echo tool" + ) { + override suspend fun execute(args: String): String = args + } + @Test fun testExecuteSuccess() = runTest { val mockEnvironment = MockEnvironment(shouldSucceed = true) @@ -145,6 +159,26 @@ class SafeToolTest { assertEquals("Raw result content", result) } + @Test + fun testDecodeFailureReturnsFailure() { + val badResult = buildJsonObject { + put("value", "not-a-string-result") + } + + val toolResult = ReceivedToolResult( + id = "1", + tool = StringEchoTool.name, + toolArgs = JsonObject(emptyMap()), + toolDescription = null, + content = "Bad result", + resultKind = ToolResultKind.Success, + result = badResult + ) + + val safeResult = toolResult.toSafeResult(StringEchoTool) + assertTrue(safeResult.isFailure()) + } + @Test fun testResultSuccessHelpers() = runTest { val success = SafeToolFromCallable.Result.Success(TEST_RESULT, "Success content") diff --git a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/ToolCallFailureEventsTest.kt b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/ToolCallFailureEventsTest.kt new file mode 100644 index 0000000000..e51a261d2b --- /dev/null +++ b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/ToolCallFailureEventsTest.kt @@ -0,0 +1,176 @@ +package ai.koog.agents.core.feature + +import ai.koog.agents.core.agent.GraphAIAgent +import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.agent.entity.AIAgentStorageKey +import ai.koog.agents.core.dsl.builder.forwardTo +import ai.koog.agents.core.dsl.builder.strategy +import ai.koog.agents.core.dsl.extension.nodeExecuteTool +import ai.koog.agents.core.environment.ReceivedToolResult +import ai.koog.agents.core.feature.config.FeatureConfig +import ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext +import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext +import ai.koog.agents.core.feature.pipeline.AIAgentGraphPipeline +import ai.koog.agents.core.tools.SimpleTool +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable +import kotlin.reflect.typeOf +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ToolCallFailureEventsTest { + @Serializable + private data class RequiredArgs(val required: String) + + private class RequiredArgsTool : SimpleTool( + argsSerializer = RequiredArgs.serializer(), + name = "required_args", + description = "Tool that requires a single argument.", + ) { + override suspend fun execute(args: RequiredArgs): String = "Ok" + } + + private class BadResultTool : SimpleTool( + argsSerializer = RequiredArgs.serializer(), + name = "bad_result", + description = "Tool that fails on result serialization.", + ) { + override suspend fun execute(args: RequiredArgs): String = "Ok" + override fun encodeResultToString(result: String): String { + throw IllegalStateException("Serialization failed") + } + } + + private class ToolFailureCaptureConfig : FeatureConfig() { + var onToolCallFailed: (ToolCallFailedContext) -> Unit = {} + var onToolValidationFailed: (ToolValidationFailedContext) -> Unit = {} + } + + private object ToolFailureCaptureFeature : AIAgentGraphFeature { + override val key = AIAgentStorageKey("tool_failure_capture") + override fun createInitialConfig(): ToolFailureCaptureConfig = ToolFailureCaptureConfig() + override fun install(config: ToolFailureCaptureConfig, pipeline: AIAgentGraphPipeline) { + pipeline.interceptToolCallFailed(this) { eventContext -> + config.onToolCallFailed(eventContext) + } + pipeline.interceptToolValidationFailed(this) { eventContext -> + config.onToolValidationFailed(eventContext) + } + } + } + + @Test + fun testInvalidJsonTriggersToolValidationFailedEvent() = runTest { + var toolValidationFailed: ToolValidationFailedContext? = null + + val strategy = strategy("tool_failure_strategy") { + val executeTool by nodeExecuteTool() + edge(nodeStart forwardTo executeTool) + edge(executeTool forwardTo nodeFinish) + } + + val agent = GraphAIAgent( + inputType = typeOf(), + outputType = typeOf(), + promptExecutor = getMockExecutor { }, + agentConfig = AIAgentConfig.withSystemPrompt("test"), + strategy = strategy, + toolRegistry = ToolRegistry { tool(RequiredArgsTool()) }, + installFeatures = { + install(ToolFailureCaptureFeature) { + onToolValidationFailed = { toolValidationFailed = it } + } + } + ) + + val toolCall = Message.Tool.Call( + id = "1", + tool = "required_args", + content = "not-json", + metaInfo = ResponseMetaInfo.Empty, + ) + + agent.run(toolCall) + val capturedFailure = assertNotNull(toolValidationFailed) + assertEquals("required_args", capturedFailure.toolName) + assertTrue(capturedFailure.message.contains("Failed to parse tool arguments")) + } + + @Test + fun testMissingFieldTriggersToolCallFailedEvent() = runTest { + var toolCallFailed: ToolCallFailedContext? = null + + val strategy = strategy("tool_failure_strategy") { + val executeTool by nodeExecuteTool() + edge(nodeStart forwardTo executeTool) + edge(executeTool forwardTo nodeFinish) + } + + val agent = GraphAIAgent( + inputType = typeOf(), + outputType = typeOf(), + promptExecutor = getMockExecutor { }, + agentConfig = AIAgentConfig.withSystemPrompt("test"), + strategy = strategy, + toolRegistry = ToolRegistry { tool(RequiredArgsTool()) }, + installFeatures = { + install(ToolFailureCaptureFeature) { + onToolCallFailed = { toolCallFailed = it } + } + } + ) + + val toolCall = Message.Tool.Call( + id = "1", + tool = "required_args", + content = "{}", + metaInfo = ResponseMetaInfo.Empty, + ) + + agent.run(toolCall) + val captureFailure = assertNotNull(toolCallFailed) + assertEquals("required_args", captureFailure.toolName) + } + + @Test + fun testResultSerializationFailureTriggersToolCallFailedEvent() = runTest { + var toolCallFailed: ToolCallFailedContext? = null + + val strategy = strategy("tool_failure_strategy") { + val executeTool by nodeExecuteTool() + edge(nodeStart forwardTo executeTool) + edge(executeTool forwardTo nodeFinish) + } + + val agent = GraphAIAgent( + inputType = typeOf(), + outputType = typeOf(), + promptExecutor = getMockExecutor { }, + agentConfig = AIAgentConfig.withSystemPrompt("test"), + strategy = strategy, + toolRegistry = ToolRegistry { tool(BadResultTool()) }, + installFeatures = { + install(ToolFailureCaptureFeature) { + onToolCallFailed = { toolCallFailed = it } + } + } + ) + + val toolCall = Message.Tool.Call( + id = "1", + tool = "bad_result", + content = "{\"required\": \"value\"}", + metaInfo = ResponseMetaInfo.Empty, + ) + + agent.run(toolCall) + val capturedFailure = assertNotNull(toolCallFailed) + assertEquals("bad_result", capturedFailure.toolName) + } +} diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithTaskTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithTaskTest.kt index 695c158b8c..45d7d8999f 100644 --- a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithTaskTest.kt +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithTaskTest.kt @@ -4,28 +4,39 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.GraphAIAgent.FeatureContext import ai.koog.agents.core.agent.ToolCalls import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.agent.entity.ToolSelectionStrategy import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.tools.Tool +import ai.koog.agents.core.tools.ToolDescriptor import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.TestBlankTool import ai.koog.agents.testing.tools.TestFinishTool import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.prompt.dsl.ModerationResult +import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.clients.openai.OpenAIModels import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.llm.LLModel import ai.koog.prompt.llm.OllamaModels import ai.koog.prompt.message.Message +import ai.koog.prompt.message.ResponseMetaInfo import ai.koog.prompt.params.LLMParams +import ai.koog.prompt.streaming.StreamFrame import ai.koog.utils.io.use import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable +import kotlinx.serialization.builtins.serializer import kotlin.js.JsName import kotlin.test.Test import kotlin.test.assertContentEquals import kotlin.test.assertEquals import kotlin.test.assertFails +import kotlin.test.assertTrue class SubgraphWithTaskTest { @@ -693,6 +704,87 @@ class SubgraphWithTaskTest { //endregion Model Without tool_choice Support + //region Invalid Finish Tool Args Recovery Test + + @Serializable + private data class StrictFinishArgs(val value: String) + + private object StrictFinishTool : Tool( + argsSerializer = StrictFinishArgs.serializer(), + resultSerializer = String.serializer(), + name = "strict_finish_tool", + description = "Strict finish tool", + ) { + override suspend fun execute(args: StrictFinishArgs): String = args.value + } + + private class InvalidFinishThenValidExecutor( + private val finishToolName: String, + private val invalidArgsJson: String, + private val validArgsJson: String, + ) : PromptExecutor { + var callCount = 0 + + override suspend fun execute(prompt: Prompt, model: LLModel, tools: List): List { + callCount += 1 + val content = if (callCount == 1) invalidArgsJson else validArgsJson + return listOf( + Message.Tool.Call( + id = callCount.toString(), + tool = finishToolName, + content = content, + metaInfo = ResponseMetaInfo.Empty, + ) + ) + } + + override fun executeStreaming(prompt: Prompt, model: LLModel, tools: List): Flow = + emptyFlow() + + override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult = + ModerationResult(isHarmful = false, categories = emptyMap()) + + override fun close() { } + } + + @Test + @JsName("testSubgraphWithTaskRecoversFromInvalidFinishToolArgs") + fun `test subgraphWithTask recovers fom invalid finish tool args`() = runTest { + val executor = InvalidFinishThenValidExecutor(StrictFinishTool.name, "{}", "{\"value\": \"ok\"}") + + val strategy = strategy("test_strategy") { + val subgraph by subgraphWithTask( + toolSelectionStrategy = ToolSelectionStrategy.ALL, + finishTool = StrictFinishTool, + runMode = ToolCalls.SINGLE_RUN_SEQUENTIAL, + ) { input -> input } + + nodeStart then subgraph then nodeFinish + } + + val agentConfig = AIAgentConfig( + prompt = prompt("test_agent") { + system("You are a test agent.") + }, + model = OpenAIModels.Chat.GPT5, + maxAgentIterations = 50, + ) + + AIAgent( + promptExecutor = executor, + strategy = strategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry.EMPTY, + ).use { agent -> + val result = agent.run("Test input") + assertEquals("ok", result) + } + + assertTrue(executor.callCount >= 2, "Expected at least 2 LLM calls for recovery") + } + + //endregion + //region Private Methods fun createAgent( diff --git a/agents/agents-features/agents-features-acp/src/jvmMain/kotlin/ai/koog/agents/features/acp/MessageConverters.kt b/agents/agents-features/agents-features-acp/src/jvmMain/kotlin/ai/koog/agents/features/acp/MessageConverters.kt index 5928620760..fb4a7e9fd7 100644 --- a/agents/agents-features/agents-features-acp/src/jvmMain/kotlin/ai/koog/agents/features/acp/MessageConverters.kt +++ b/agents/agents-features/agents-features-acp/src/jvmMain/kotlin/ai/koog/agents/features/acp/MessageConverters.kt @@ -254,7 +254,7 @@ public fun Message.Response.toAcpEvents(tools: List = emptyList( ?: UNKNOWN_TOOL_DESCRIPTION, // TODO: Support kind for tools status = ToolCallStatus.PENDING, - rawInput = response.contentJson, + rawInput = response.contentJsonResult.getOrNull(), ) ) ) diff --git a/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt b/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt index d5bd40666f..eb19d39da6 100644 --- a/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt +++ b/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt @@ -459,7 +459,7 @@ public class OpenTelemetry { } is Message.Tool.Call -> { - ChoiceEvent(provider, message, arguments = message.contentJson) + ChoiceEvent(provider, message, arguments = message.contentJsonResult.getOrNull()) } is Message.Tool.Result -> { @@ -507,7 +507,7 @@ public class OpenTelemetry { } is Message.Tool.Call -> { - add(ChoiceEvent(provider, message, arguments = message.contentJson, index = index)) + add(ChoiceEvent(provider, message, arguments = message.contentJsonResult.getOrNull(), index = index)) } } } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt index 0b5d4bc978..c6fdfa4649 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt @@ -23,6 +23,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.datetime.Clock import kotlinx.datetime.Instant import kotlinx.serialization.json.JsonElement +import kotlin.coroutines.cancellation.CancellationException import kotlin.reflect.KType import kotlin.time.ExperimentalTime import kotlin.uuid.ExperimentalUuidApi @@ -368,7 +369,13 @@ public class Persistence( .reversed() .forEach { toolCall -> rollbackToolRegistry.getRollbackTool(toolCall.tool)?.let { rollbackTool -> - val toolArgs = rollbackTool.decodeArgs(toolCall.contentJson) + val toolArgs = try { + toolCall.contentJsonResult.getOrNull()?.let { rollbackTool.decodeArgs(it) } + } catch (e: CancellationException) { + throw e + } catch (_: Exception) { + null + } rollbackTool.executeUnsafe(toolArgs) } diff --git a/agents/agents-mcp-server/src/commonMain/kotlin/ai/koog/agents/mcp/server/McpServer.kt b/agents/agents-mcp-server/src/commonMain/kotlin/ai/koog/agents/mcp/server/McpServer.kt index e9f39d7276..222da920bb 100644 --- a/agents/agents-mcp-server/src/commonMain/kotlin/ai/koog/agents/mcp/server/McpServer.kt +++ b/agents/agents-mcp-server/src/commonMain/kotlin/ai/koog/agents/mcp/server/McpServer.kt @@ -28,6 +28,7 @@ import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put import kotlinx.serialization.json.putJsonArray import kotlinx.serialization.json.putJsonObject +import kotlin.coroutines.cancellation.CancellationException import io.modelcontextprotocol.kotlin.sdk.types.Tool as SdkTool /** @@ -123,7 +124,16 @@ public fun Server.addTool( tool: Tool<*, *>, ) { addTool(tool.descriptor.asSdkTool()) { request -> - val args = tool.decodeArgs(request.arguments ?: EmptyJsonObject) + val args = try { + tool.decodeArgs(request.arguments ?: EmptyJsonObject) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + return@addTool CallToolResult( + content = listOf(TextContent("Failed to parse arguments for tool '${tool.name}': ${e.message}")), + isError = true, + ) + } val result = tool.executeUnsafe(args) CallToolResult( diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt index 8eaeffa17f..d0107812d9 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt @@ -10,6 +10,7 @@ import ai.koog.prompt.message.ResponseMetaInfo import ai.koog.prompt.streaming.StreamFrame import kotlinx.datetime.Clock import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.buildJsonObject import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -66,7 +67,7 @@ internal object BedrockAmazonNovaSerialization { toolUse = NovaToolUse( toolUseId = msg.id ?: Uuid.random().toString(), name = msg.tool, - input = msg.contentJson, + input = msg.contentJsonResult.getOrElse { JsonObject(emptyMap()) }, ) ) ) diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt index 843ceabbfb..f1dd972e40 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt @@ -24,6 +24,7 @@ import kotlinx.coroutines.flow.Flow import kotlinx.datetime.Clock import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonNamingStrategy +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.buildJsonArray import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.encodeToJsonElement @@ -106,7 +107,7 @@ internal object BedrockAnthropicClaudeSerialization { BedrockAnthropicInvokeModelContent.ToolCall( msg.id!!, msg.tool, - json.decodeFromString(msg.content) + msg.contentJsonResult.getOrElse { JsonObject(emptyMap()) } ) ) ) diff --git a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt index 1bc400602b..b2058b3f98 100644 --- a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt +++ b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt @@ -238,12 +238,19 @@ public sealed interface Message { this(id, tool, ContentPart.Text(content), metaInfo) /** - * Lazily parses the content of the tool call as a JSON object. + * Lazily parses and caches the result of parsing [content] as a JSON object. */ - val contentJson: JsonObject by lazy { - Json.parseToJsonElement(content).jsonObject + val contentJsonResult: kotlin.Result by lazy { + runCatching { Json.parseToJsonElement(content).jsonObject } } + /** + * Lazily parses the content of the tool call as a JSON object. + * Can throw an exception when parsing fails. + */ + val contentJson: JsonObject + get() = contentJsonResult.getOrThrow() + override fun copy(updatedMetaInfo: ResponseMetaInfo): Call = this.copy(metaInfo = updatedMetaInfo) } diff --git a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/streaming/StreamFrame.kt b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/streaming/StreamFrame.kt index 3f6b19abc3..8d44126080 100644 --- a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/streaming/StreamFrame.kt +++ b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/streaming/StreamFrame.kt @@ -37,11 +37,18 @@ public sealed interface StreamFrame { ) : StreamFrame { /** - * Lazily parses the content of the tool call as a JSON object. + * Lazily parses and caches the result of parsing [content] as a JSON object. */ - val contentJson: JsonObject by lazy { - Json.parseToJsonElement(content).jsonObject + val contentJsonResult: Result by lazy { + runCatching { Json.parseToJsonElement(content).jsonObject } } + + /** + * Lazily parses the content of the tool call as a JSON object. + * Can throw an exception when parsing fails. + */ + val contentJson: JsonObject + get() = contentJsonResult.getOrThrow() } /**