diff --git a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt index 79f5705280..cfca6e2d14 100644 --- a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt +++ b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt @@ -17,7 +17,6 @@ import ai.koog.agents.core.dsl.extension.nodeLLMSendMultipleToolResults import ai.koog.agents.core.dsl.extension.setToolChoiceRequired import ai.koog.agents.core.environment.ReceivedToolResult import ai.koog.agents.core.environment.ToolResultKind -import ai.koog.agents.core.environment.executeTools import ai.koog.agents.core.environment.toSafeResult import ai.koog.agents.core.tools.Tool import ai.koog.agents.core.tools.ToolDescriptor @@ -99,6 +98,17 @@ public object SubgraphWithTaskUtils { * to prevent redundancy in responses and ensure conciseness in communication. */ public const val ASSISTANT_RESPONSE_REPEAT_MAX: Int = 3 + + /** + * A message shown to the model when it does not return a tool call during the subgraphWithTask execution. + * + * The message clarifies to the model that a tool call is required here, + * And if the task is finished, the finish tool has to be called. + */ + public fun messageOnAssistantResponse(finishToolName: String): String = markdown { + h1("DO NOT CHAT WITH ME DIRECTLY! CALL TOOLS, INSTEAD.") + h2("IF YOU HAVE FINISHED, CALL `$finishToolName` TOOL!") + } } /** @@ -478,36 +488,18 @@ public inline fun AIA val nodeCallLLM by nodeLLMRequestMultiple() val callToolsHacked by node, List> { toolCalls -> - val (finishToolCalls, regularToolCalls) = toolCalls.partition { it.tool == finishTool.name } - - // Execute finish tool - val finishToolResult = finishToolCalls.firstOrNull()?.let { toolCall -> - executeFinishTool(toolCall, finishTool) - } - - // Execute regular tools - val regularToolsResults = when (runMode) { - ToolCalls.PARALLEL -> { - environment.executeTools(regularToolCalls) - } - ToolCalls.SEQUENTIAL, - ToolCalls.SINGLE_RUN_SEQUENTIAL -> { - regularToolCalls.map { toolCall -> - environment.executeTool(toolCall) - } - } - } - - buildList { - finishToolResult?.let { add(it) } - addAll(regularToolsResults) - } + // use a method for the subtask to avoid code duplication + executeMultipleToolsHacked( + toolCalls, + finishTool, + runMode == ToolCalls.PARALLEL + ) } val sendToolsResults by nodeLLMSendMultipleToolResults() @OptIn(DetachedPromptExecutorAPI::class) - val handleAssistantMessage by node> { response -> + val handleAssistantMessage by node { response -> if (llm.model.capabilities.contains(LLMCapability.ToolChoice)) { error( "Subgraph with task must always call tools, but no ${Message.Tool.Call::class.simpleName} was generated, " + @@ -526,19 +518,7 @@ public inline fun AIA ) } - llm.writeSession { - // append a new message to the history with feedback: - appendPrompt { - user { - markdown { - h1("DO NOT CHAT WITH ME DIRECTLY! CALL TOOLS, INSTEAD.") - h2("IF YOU HAVE FINISHED, CALL `${finishTool.name}` TOOL!") - } - } - } - - requestLLMMultiple() - } + SubgraphWithTaskUtils.messageOnAssistantResponse(finishTool.name) } nodeStart then setupTask then nodeCallLLM then nodeDecide @@ -555,7 +535,7 @@ public inline fun AIA transformed { responses -> responses.first() as Message.Assistant } ) - edge(handleAssistantMessage forwardTo nodeDecide) + edge(handleAssistantMessage forwardTo nodeCallLLM) // throw to terminate the agent early with exception edge( diff --git a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubtaskExt.kt b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubtaskExt.kt index 4927920ea1..b59fe4204e 100644 --- a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubtaskExt.kt +++ b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubtaskExt.kt @@ -1,6 +1,7 @@ package ai.koog.agents.ext.agent import ai.koog.agents.core.agent.ToolCalls +import ai.koog.agents.core.agent.context.AIAgentContext import ai.koog.agents.core.agent.context.AIAgentFunctionalContext import ai.koog.agents.core.agent.context.DetachedPromptExecutorAPI import ai.koog.agents.core.annotation.InternalAgentsApi @@ -13,6 +14,7 @@ import ai.koog.agents.core.dsl.extension.sendToolResult import ai.koog.agents.core.dsl.extension.setToolChoiceRequired import ai.koog.agents.core.environment.ReceivedToolResult import ai.koog.agents.core.environment.executeTools +import ai.koog.agents.core.environment.result import ai.koog.agents.core.environment.toSafeResult import ai.koog.agents.core.tools.Tool import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi @@ -208,12 +210,13 @@ internal suspend inline fun AIAgentF val toolCalls = extractToolCalls(responses) val toolResults = executeMultipleToolsHacked(toolCalls, finishTool, parallelTools = runMode == ToolCalls.PARALLEL) - responses = sendMultipleToolResults(toolResults) toolResults.firstOrNull { it.tool == finishTool.descriptor.name } ?.let { finishResult -> return finishResult.toSafeResult(finishTool).asSuccessful().result } + + responses = sendMultipleToolResults(toolResults) } else -> { @@ -249,11 +252,12 @@ internal suspend inline fun AIAgentF when { response is Message.Tool.Call -> { val toolResult = executeToolHacked(response, finishTool) - response = sendToolResult(toolResult) if (toolResult.tool == finishTool.descriptor.name) { return toolResult.toSafeResult(finishTool).asSuccessful().result } + + response = sendToolResult(toolResult) } else -> { @@ -266,10 +270,7 @@ internal suspend inline fun AIAgentF } response = requestLLM( - message = markdown { - h1("DO NOT CHAT WITH ME DIRECTLY! CALL TOOLS, INSTEAD.") - h2("IF YOU HAVE FINISHED, CALL `${finishTool.name}` TOOL!") - } + SubgraphWithTaskUtils.messageOnAssistantResponse(finishTool.name) ) } } @@ -278,17 +279,12 @@ internal suspend inline fun AIAgentF @OptIn(InternalAgentToolsApi::class, InternalAgentsApi::class) @PublishedApi -internal suspend inline fun AIAgentFunctionalContext.executeMultipleToolsHacked( +internal suspend inline fun AIAgentContext.executeMultipleToolsHacked( toolCalls: List, finishTool: Tool, parallelTools: Boolean = false ): List { - val finishTools = toolCalls.filter { it.tool == finishTool.descriptor.name } - val normalTools = toolCalls.filterNot { it.tool == finishTool.descriptor.name } - - val finishToolResults = finishTools.map { toolCall -> - executeFinishTool(toolCall, finishTool) - } + val (finishTools, normalTools) = toolCalls.partition { it.tool == finishTool.name } val normalToolResults = if (parallelTools) { environment.executeTools(normalTools) @@ -296,7 +292,24 @@ internal suspend inline fun AIAgentF normalTools.map { environment.executeTool(it) } } - return finishToolResults + normalToolResults + // if a finish tool was called, the subtask execution will be finished, + // and the normal tool results have to be appended to the prompt here, + // otherwise they will be lost + if (finishTools.isNotEmpty()) { + llm.writeSession { + appendPrompt { + tool { + normalToolResults.forEach { result(it) } + } + } + } + } + + val finishToolResults = finishTools.map { toolCall -> + executeFinishTool(toolCall, finishTool) + } + + return normalToolResults + finishToolResults } @OptIn(InternalAgentToolsApi::class) diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt index 7833cfdad1..f0a556b5d9 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.ToolCalls import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.agent.execution.path +import ai.koog.agents.core.agent.functionalStrategy import ai.koog.agents.core.agent.singleRunStrategy import ai.koog.agents.core.dsl.builder.ParallelNodeExecutionResult import ai.koog.agents.core.dsl.builder.forwardTo @@ -15,8 +16,11 @@ import ai.koog.agents.core.dsl.extension.nodeLLMRequest import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult import ai.koog.agents.core.dsl.extension.onAssistantMessage import ai.koog.agents.core.dsl.extension.onToolCall +import ai.koog.agents.core.dsl.extension.requestLLM +import ai.koog.agents.core.environment.ReceivedToolResult import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.agent.reActStrategy +import ai.koog.agents.ext.agent.subtask import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.features.eventHandler.feature.EventHandlerConfig import ai.koog.agents.snapshot.feature.Persistence @@ -26,6 +30,7 @@ import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider import ai.koog.integration.tests.utils.Models import ai.koog.integration.tests.utils.RetryUtils.withRetry import ai.koog.integration.tests.utils.tools.CalculateSumTool +import ai.koog.integration.tests.utils.tools.CalculatorTool import ai.koog.integration.tests.utils.tools.CalculatorToolNoArgs import ai.koog.integration.tests.utils.tools.DelayTool import ai.koog.integration.tests.utils.tools.GetTransactionsTool @@ -111,6 +116,13 @@ class AIAgentIntegrationTest : AIAgentTestBase() { Arguments.of(HistoryCompressionStrategy.Chunked(2), "Chunked(2)") ) } + + @JvmStatic + fun runModes(): Stream = Stream.of( + ToolCalls.SEQUENTIAL, + ToolCalls.PARALLEL, + ToolCalls.SINGLE_RUN_SEQUENTIAL, + ) } val twoToolsRegistry = ToolRegistry { @@ -171,7 +183,7 @@ class AIAgentIntegrationTest : AIAgentTestBase() { name = "compress_history", strategy = strategy ) - val compressToolResult by nodeLLMCompressHistory( + val compressToolResult by nodeLLMCompressHistory( name = "compress_history", strategy = strategy ) @@ -1074,7 +1086,7 @@ class AIAgentIntegrationTest : AIAgentTestBase() { agent.run("Hi") with(state) { - errors.shouldBeEmpty() // There should be no errors during parallel execution} + errors.shouldBeEmpty() // There should be no errors during parallel execution results.shouldNotBeEmpty().first() as String should { contain("Math result: 56") contain("Text result: Hello World") @@ -1302,4 +1314,38 @@ class AIAgentIntegrationTest : AIAgentTestBase() { } } } + + @ParameterizedTest + @MethodSource("runModes") + fun integration_testSubtaskCorrectlySavesToolMessages(runMode: ToolCalls) = runTest(timeout = 3600.seconds) { + withRetry { + val model = OpenAIModels.Chat.GPT4o + val executor = getExecutor(model) + val toolRegistry = ToolRegistry { + tool(CalculatorTool) + } + + val strategy = functionalStrategy("subtask-test") { input -> + subtask(input, runMode = runMode) { it } + requestLLM("What's the result?").content + } + + val agent = AIAgent( + strategy = strategy, + promptExecutor = executor, + agentConfig = AIAgentConfig( + prompt = prompt("subtask-test") { + system("You are a helpful assistant specialized in simple calculations.") + }, + model = model, + maxAgentIterations = 10 + ), + toolRegistry = toolRegistry + ) + + val result = agent.run("2 * 7") + + result.shouldContain("14") + } + } }