Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ public open class AIAgentSubgraph<TInput, TOutput>(
tools = newTools,
model = llmModel ?: context.llm.model,
prompt = context.llm.prompt.copy(params = llmParams ?: context.llm.prompt.params),
responseProcessor = responseProcessor
responseProcessor = responseProcessor ?: context.llm.responseProcessor,
),
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import ai.koog.prompt.message.LLMChoice
import ai.koog.prompt.message.Message
import ai.koog.prompt.params.LLMParams
import ai.koog.prompt.processor.ResponseProcessor
import ai.koog.prompt.processor.executeProcessed
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.structure.StructureFixingParser
import ai.koog.prompt.structure.StructuredRequestConfig
Expand Down Expand Up @@ -64,7 +65,7 @@ public open class AIAgentLLMSession(
/**
 * Executes [prompt] against the session's executor and returns every response message.
 *
 * The prompt is first prepared via [preparePrompt] (merging session state with the given
 * [tools]), then dispatched through `executeProcessed` so the session's [responseProcessor]
 * (if any) is applied to the raw LLM output — presumably post-processing each response;
 * confirm against `executeProcessed`'s contract.
 *
 * @param prompt The prompt to execute.
 * @param tools Tool descriptors made available to the LLM for this call.
 * @return All response messages produced by the LLM for this request.
 */
@InternalAgentsApi
public open override suspend fun executeMultiple(prompt: Prompt, tools: List<ToolDescriptor>): List<Message.Response> {
    val preparedPrompt = preparePrompt(prompt, tools)
    return executor.executeProcessed(preparedPrompt, model, tools, responseProcessor)
}

@InternalAgentsApi
Expand Down Expand Up @@ -94,29 +95,31 @@ public open class AIAgentLLMSession(
return executeMultiple(promptWithDisabledTools, emptyList()).first { it !is Message.Reasoning }
}

/** Returns a copy of the session prompt whose params carry the supplied [toolChoice]. */
private fun preparePromptWithToolChoice(toolChoice: LLMParams.ToolChoice) =
    prompt.withUpdatedParams { this.toolChoice = toolChoice }

/**
 * Requests the LLM with tool choice forced to [LLMParams.ToolChoice.Required] and returns a
 * single response.
 *
 * Prefers the first [Message.Tool.Call] among the responses. If the model ignored the
 * tool-choice constraint and produced no tool call, falls back to the first non-reasoning
 * response instead of failing.
 *
 * @return The first tool-call response, or the first non-[Message.Reasoning] response.
 */
public open override suspend fun requestLLMOnlyCallingTools(): Message.Response {
    validateSession()
    val promptWithOnlyCallingTools = preparePromptWithToolChoice(LLMParams.ToolChoice.Required)
    val responses = executeMultiple(promptWithOnlyCallingTools, tools)
    // Prefer a tool call; otherwise surface the model's textual answer rather than erroring.
    return responses.firstOrNull { it is Message.Tool.Call }
        ?: responses.first { it !is Message.Reasoning }
}

/**
 * Requests the LLM with tool choice forced to [LLMParams.ToolChoice.Required] and returns
 * every response message the model produced.
 *
 * @return All responses for the constrained request (may include reasoning messages).
 */
public open override suspend fun requestLLMMultipleOnlyCallingTools(): List<Message.Response> {
    validateSession()
    val promptWithOnlyCallingTools = preparePromptWithToolChoice(LLMParams.ToolChoice.Required)
    return executeMultiple(promptWithOnlyCallingTools, tools)
}

/**
 * Requests the LLM while forcing it to call the specific [tool] via
 * [LLMParams.ToolChoice.Named].
 *
 * @param tool Descriptor of the tool the LLM must call; must be among the session's [tools].
 * @return The single response produced for the forced-tool request.
 * @throws IllegalStateException if [tool] is not defined in this session.
 */
public open override suspend fun requestLLMForceOneTool(tool: ToolDescriptor): Message.Response {
    validateSession()
    check(tools.contains(tool)) { "Unable to force call to tool `${tool.name}` because it is not defined" }
    val promptWithForcingOneTool = preparePromptWithToolChoice(LLMParams.ToolChoice.Named(tool.name))
    return executeSingle(promptWithForcingOneTool, tools)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ internal class AIAgentLLMWriteSessionImpl internal constructor(
return super<AIAgentLLMSession>.requestLLMWithoutTools().also { response -> appendPrompt { message(response) } }
}

/** Delegates to the base session's tool-only request, then records the response in the prompt history. */
override suspend fun requestLLMOnlyCallingTools(): Message.Response {
    val response = super<AIAgentLLMSession>.requestLLMOnlyCallingTools()
    appendPrompt { message(response) }
    return response
}

override suspend fun requestLLMMultipleOnlyCallingTools(): List<Message.Response> {
return super<AIAgentLLMSession>.requestLLMMultipleOnlyCallingTools()
.also { responses ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public inline fun <reified T> AIAgentSubgraphBuilderBase<*, *>.nodeUpdatePrompt(
* @param name Optional name for the node.
*/
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools(
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestOnlyCallingTools(
name: String? = null
): AIAgentNodeDelegate<String, Message.Response> =
node(name) { message ->
Expand All @@ -95,14 +95,48 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools(
}
}

/**
 * A node that appends a user message to the LLM prompt and gets a response where the LLM can only call tools.
 *
 * @param name Optional name for the node.
 */
@Deprecated(
    "Please use nodeLLMRequestOnlyCallingTools instead.",
    ReplaceWith("nodeLLMRequestOnlyCallingTools(name)")
)
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools(
    name: String? = null
): AIAgentNodeDelegate<String, Message.Response> {
    // Deprecated alias kept for backward compatibility; forwards to the renamed node builder.
    return nodeLLMRequestOnlyCallingTools(name)
}

/**
 * A node that appends a user message to the LLM prompt and gets multiple LLM responses where the LLM can only call tools.
 *
 * @param name Optional name for the node.
 */
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestMultipleOnlyCallingTools(
    name: String? = null
): AIAgentNodeDelegate<String, List<Message.Response>> =
    node(name) { userMessage ->
        llm.writeSession {
            // Record the incoming message, then ask for tool-only responses.
            appendPrompt { user(userMessage) }
            requestLLMMultipleOnlyCallingTools()
        }
    }

/**
 * A node that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
*
* @param name Optional node name.
* @param tool Tool descriptor the LLM is required to use.
*/
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestForceOneTool(
name: String? = null,
tool: ToolDescriptor
): AIAgentNodeDelegate<String, Message.Response> =
Expand All @@ -116,18 +150,52 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
}
}

/**
 * A node that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
 *
 * Deprecated alias retained for backward compatibility; forwards to [nodeLLMRequestForceOneTool].
 *
 * @param name Optional node name.
 * @param tool Tool descriptor the LLM is required to use.
 */
@Deprecated(
    "Please use nodeLLMRequestForceOneTool instead.",
    ReplaceWith("nodeLLMRequestForceOneTool(name, tool)")
)
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
    name: String? = null,
    tool: ToolDescriptor
): AIAgentNodeDelegate<String, Message.Response> =
    nodeLLMRequestForceOneTool(name, tool)

/**
 * A node that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
 *
 * @param name Optional node name.
 * @param tool Tool the LLM is required to use.
 */
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestForceOneTool(
    name: String? = null,
    tool: Tool<*, *>
): AIAgentNodeDelegate<String, Message.Response> {
    // Convenience overload: unwrap the tool's descriptor and delegate.
    return nodeLLMRequestForceOneTool(name, tool.descriptor)
}

/**
 * A node that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
 *
 * Deprecated alias retained for backward compatibility; forwards to [nodeLLMRequestForceOneTool].
 *
 * @param name Optional node name.
 * @param tool Tool the LLM is required to use.
 */
@Deprecated(
    "Please use nodeLLMRequestForceOneTool instead.",
    ReplaceWith("nodeLLMRequestForceOneTool(name, tool)")
)
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
    name: String? = null,
    tool: Tool<*, *>
): AIAgentNodeDelegate<String, Message.Response> =
    nodeLLMRequestForceOneTool(name, tool)

/**
* A node that appends a user message to the LLM prompt and gets a response with optional tool usage.
Expand Down Expand Up @@ -407,6 +475,27 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendToolResult(
}
}

/**
 * A node that adds a tool result to the prompt and gets an LLM response where the LLM can only call tools.
 *
 * @param name Optional node name.
 */
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendToolResultOnlyCallingTools(
    name: String? = null
): AIAgentNodeDelegate<List<ReceivedToolResult>, Message.Response> =
    node(name) { toolResults ->
        llm.writeSession {
            // Append every received tool result to the prompt before re-querying.
            appendPrompt {
                tool {
                    toolResults.forEach { result(it) }
                }
            }
            requestLLMOnlyCallingTools()
        }
    }

/**
* A node that executes multiple tool calls. These calls can optionally be executed in parallel.
*
Expand Down Expand Up @@ -481,6 +570,27 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMultipleToolResults(
}
}

/**
 * A node that adds multiple tool results to the prompt and gets multiple LLM responses where the LLM can only call tools.
 *
 * @param name Optional node name.
 */
@AIAgentBuilderDslMarker
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMultipleToolResultsOnlyCallingTools(
    name: String? = null
): AIAgentNodeDelegate<List<ReceivedToolResult>, List<Message.Response>> =
    node(name) { toolResults ->
        llm.writeSession {
            // Append every received tool result to the prompt before re-querying.
            appendPrompt {
                tool {
                    toolResults.forEach { result(it) }
                }
            }
            requestLLMMultipleOnlyCallingTools()
        }
    }

/**
* A node that calls a specific tool directly using the provided arguments.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphDelegate
import ai.koog.agents.core.dsl.builder.forwardTo
import ai.koog.agents.core.dsl.extension.nodeLLMRequest
import ai.koog.agents.core.dsl.extension.nodeLLMRequestMultiple
import ai.koog.agents.core.dsl.extension.nodeLLMSendMultipleToolResults
import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult
import ai.koog.agents.core.dsl.extension.setToolChoiceRequired
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.environment.ToolResultKind
Expand Down Expand Up @@ -491,7 +493,12 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA
// Helper node to overcome problems of the current api and repeat less code when writing routing conditions
val nodeDecide by node<List<Message.Response>, List<Message.Response>> { it }

val nodeCallLLM by nodeLLMRequestMultiple()
val nodeCallLLMDelegate = if (runMode == ToolCalls.SINGLE_RUN_SEQUENTIAL) {
nodeLLMRequest().transform { listOf(it) }
} else {
nodeLLMRequestMultiple()
}
val nodeCallLLM by nodeCallLLMDelegate

val callToolsHacked by node<List<Message.Tool.Call>, List<ReceivedToolResult>> { toolCalls ->
val (finishToolCalls, regularToolCalls) = toolCalls.partition { it.tool == finishTool.name }
Expand Down Expand Up @@ -521,8 +528,6 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA
}
}

val sendToolsResults by nodeLLMSendMultipleToolResults()

@OptIn(DetachedPromptExecutorAPI::class)
val handleAssistantMessage by node<Message.Assistant, List<Message.Response>> { response ->
if (llm.model.capabilities.contains(LLMCapability.ToolChoice)) {
Expand Down Expand Up @@ -590,9 +595,14 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA
transformed { toolsResults -> toolsResults.first() }
)

edge(callToolsHacked forwardTo sendToolsResults)

edge(sendToolsResults forwardTo nodeDecide)
if (runMode == ToolCalls.SINGLE_RUN_SEQUENTIAL) {
val sendToolResult by nodeLLMSendToolResult()
edge(callToolsHacked forwardTo sendToolResult transformed { it.first() })
edge(sendToolResult forwardTo nodeDecide transformed { listOf(it) })
} else {
val sendToolsResults by nodeLLMSendMultipleToolResults()
callToolsHacked then sendToolsResults then nodeDecide
}

edge(finalizeTask forwardTo nodeFinish)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ import ai.koog.prompt.params.LLMParams
import ai.koog.prompt.processor.ResponseProcessor
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.Serializable
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertContains
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -392,6 +395,8 @@ class AIAgentLLMWriteSessionTest {
}

@Test
// This behavior is not supported yet.
@Ignore
fun testRequestLLMOnlyCallingToolsWithThinking() = runTest {
val thinkingContent = "<thinking>Checking file...</thinking>"
val testTool = TestTool()
Expand All @@ -417,30 +422,6 @@ class AIAgentLLMWriteSessionTest {
assertEquals("test-tool", (lastTwoMessages[1] as Message.Tool.Call).tool)
}

// NOTE(review): this test pins the old contract where requestLLMOnlyCallingTools threw an
// IllegalStateException when the model produced no tool call. The surrounding change replaces
// that behavior with a non-reasoning fallback response — confirm this test should be removed
// rather than updated to assert the fallback.
@Test
fun testRequestLLMOnlyCallingToolsNoToolCallThrowsException() = runTest {
    val mockExecutor = getMockExecutor(clock = testClock) {
        // Simulate model refusing to use tools and just responding with text
        mockLLMAnswer("I cannot use tools for this request.").asDefaultResponse
    }

    val session = createSession(mockExecutor, listOf(TestTool()))

    // runCatching captures the expected failure so its type/message can be asserted below.
    val exception = kotlin.runCatching {
        session.requestLLMOnlyCallingTools()
    }.exceptionOrNull()

    assertNotNull(exception, "Expected an exception when no tool call is found")
    assertTrue(
        exception is IllegalStateException,
        "Expected IllegalStateException but got ${exception::class.simpleName}"
    )
    assertTrue(
        exception.message?.contains("expected at least one Tool.Call") == true,
        "Exception message should indicate missing tool call"
    )
}

@Test
fun testRequestLLMOnlyCallingToolsWithMultipleToolCalls() = runTest {
val testTool = TestTool()
Expand All @@ -464,8 +445,9 @@ class AIAgentLLMWriteSessionTest {
assertTrue(response is Message.Tool.Call, "Expected response to be a Tool Call")
assertEquals("test-tool", response.tool)

// Both tool calls should be in history
val lastTwoMessages = session.prompt.messages.takeLast(2)
assertTrue(lastTwoMessages.all { it is Message.Tool.Call })
// Only the first tool call should be added to the history
val lastMessage = session.prompt.messages.last()
assertIs<Message.Tool.Call>(lastMessage)
assertContains(lastMessage.content, "first")
}
}
Loading