Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ public interface ToolCallDescriber {
buildJsonObject {
message.id?.let { put("tool_call_id", JsonPrimitive(it)) }
put("tool_name", JsonPrimitive(message.tool))
put("tool_args", message.contentJson)
message.contentJsonResult.fold(
onSuccess = { put("tool_args", it) },
onFailure = { put("tool_args_error", JsonPrimitive("Failed to parse tool arguments: $it")) }
)
}
),
metaInfo = message.metaInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import ai.koog.agents.core.environment.toSafeResult
import ai.koog.agents.core.tools.Tool
import ai.koog.prompt.message.ContentPart
import ai.koog.prompt.message.Message
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass

/**
Expand Down Expand Up @@ -98,7 +99,13 @@ public inline fun <IncomingOutput, IntermediateOutput, OutgoingInput, reified Ar
return onIsInstance(Message.Tool.Call::class)
.onCondition { it.tool == tool.name }
.onCondition { toolCall ->
val args = tool.decodeArgs(toolCall.contentJson)
val args = try {
tool.decodeArgs(toolCall.contentJsonResult.getOrNull() ?: return@onCondition false)
} catch (e: CancellationException) {
throw e
} catch (_: Exception) {
return@onCondition false
}
block(args)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package ai.koog.agents.core.environment

import ai.koog.agents.core.agent.context.AIAgentContext
import ai.koog.agents.core.agent.execution.AgentExecutionInfo
import ai.koog.agents.core.feature.model.toAgentError
import ai.koog.prompt.message.Message
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.serialization.json.JsonObject
import kotlin.coroutines.cancellation.CancellationException
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid

Expand All @@ -20,14 +23,48 @@ public class ContextualAgentEnvironment(
override suspend fun executeTool(toolCall: Message.Tool.Call): ReceivedToolResult {
@OptIn(ExperimentalUuidApi::class)
val eventId = Uuid.random().toString()
val toolDescription = context.llm.toolRegistry.getToolOrNull(toolCall.tool)?.descriptor?.description

val toolArgs = try {
toolCall.contentJson
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
logger.error { "Failed to execute tool call with id '${toolCall.id}' while parsing args: ${e.message}" }

val tool = toolCall.tool
val toolArgs = JsonObject(emptyMap())
val message = "Failed to parse tool arguments: ${e.message}"
context.pipeline.onToolValidationFailed(
eventId = eventId,
executionInfo = context.executionInfo,
runId = context.runId,
toolCallId = toolCall.id,
toolName = tool,
toolDescription = toolDescription,
toolArgs = toolArgs,
message = message,
error = e.toAgentError(),
context = context
)
return ReceivedToolResult(
id = toolCall.id,
tool = tool,
toolArgs = toolArgs,
toolDescription = null,
content = message,
resultKind = ToolResultKind.ValidationError(e.toAgentError()),
result = null
)
}

logger.trace {
"Executing tool call (" +
"event id: $eventId, " +
"run id: ${context.runId}, " +
"tool call id: ${toolCall.id}, " +
"tool: ${toolCall.tool}, " +
"args: ${toolCall.contentJson})"
"args: $toolArgs)"
}

context.pipeline.onToolCallStarting(
Expand All @@ -36,8 +73,8 @@ public class ContextualAgentEnvironment(
runId = context.runId,
toolCallId = toolCall.id,
toolName = toolCall.tool,
toolDescription = context.llm.toolRegistry.getToolOrNull(toolCall.tool)?.descriptor?.description,
toolArgs = toolCall.contentJson,
toolDescription = toolDescription,
toolArgs = toolArgs,
context = context
)

Expand All @@ -52,7 +89,7 @@ public class ContextualAgentEnvironment(
"tool call id: ${toolCall.id}, " +
"tool: ${toolCall.tool}, " +
"tool description: ${toolResult.toolDescription}, " +
"args: ${toolCall.contentJson}) " +
"args: $toolArgs) " +
"with result: $toolResult"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import ai.koog.agents.core.tools.ToolRegistry
import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi
import ai.koog.prompt.message.Message
import io.github.oshai.kotlinlogging.KLogger
import kotlinx.serialization.json.JsonObject
import kotlin.coroutines.cancellation.CancellationException

/**
* Represents base agent environment with generic abstractions.
Expand All @@ -19,7 +21,7 @@ public class GenericAgentEnvironment(

override suspend fun executeTool(toolCall: Message.Tool.Call): ReceivedToolResult {
logger.info {
formatLog("Executing tool (name: ${toolCall.tool}, args: ${toolCall.contentJson})")
formatLog("Executing tool (name: ${toolCall.tool}, args: ${toolCall.contentJsonResult.getOrElse { "Failed to parse tool arguments: ${it.message}" }})")
}

val environmentToolResult = processToolCall(toolCall)
Expand All @@ -45,7 +47,21 @@ public class GenericAgentEnvironment(
// Tool
val id = toolCall.id
val toolName = toolCall.tool
val toolArgsJson = toolCall.contentJson
val toolArgsJson = try {
toolCall.contentJson
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
return ReceivedToolResult(
id = id,
tool = toolName,
toolArgs = JsonObject(emptyMap()),
toolDescription = null,
content = "Tool with name '$toolName' failed to parse arguments due to the error: ${e.message}",
resultKind = ToolResultKind.Failure(e.toAgentError()),
result = null,
)
}

val tool = toolRegistry.getToolOrNull(toolName)
?: run {
Expand All @@ -66,6 +82,8 @@ public class GenericAgentEnvironment(
// Tool Args
val toolArgs = try {
tool.decodeArgs(toolArgsJson)
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
logger.error(e) { formatLog("Tool with name '$toolName' failed to parse arguments: $toolArgsJson") }
return ReceivedToolResult(
Expand All @@ -82,6 +100,8 @@ public class GenericAgentEnvironment(
val toolResult = try {
@Suppress("UNCHECKED_CAST")
(tool as Tool<Any?, Any?>).execute(toolArgs)
} catch (e: CancellationException) {
throw e
} catch (e: ToolException) {
return ReceivedToolResult(
id = id,
Expand All @@ -108,14 +128,31 @@ public class GenericAgentEnvironment(

logger.trace { "Completed execution of the tool '$toolName' with result: $toolResult" }

val (content, result) = try {
tool.encodeResultToStringUnsafe(toolResult) to tool.encodeResult(toolResult)
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
logger.error(e) { "Tool with name '$toolName' failed to encode result: $toolResult" }
return ReceivedToolResult(
id = id,
tool = toolName,
toolArgs = toolArgsJson,
toolDescription = toolDescription,
content = "Tool with name '$toolName' failed to serialize result due to the error: ${e.message}!",
resultKind = ToolResultKind.Failure(e.toAgentError()),
result = null
)
}

return ReceivedToolResult(
id = id,
tool = toolName,
toolArgs = toolArgsJson,
toolDescription = toolDescription,
content = tool.encodeResultToStringUnsafe(toolResult),
content = content,
resultKind = ToolResultKind.Success,
result = tool.encodeResult(toolResult)
result = result
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import ai.koog.agents.core.tools.Tool
import ai.koog.prompt.message.Message
import ai.koog.prompt.message.ResponseMetaInfo
import kotlinx.datetime.Clock
import kotlin.coroutines.cancellation.CancellationException

/**
* A wrapper class designed to safely execute a tool within a given AI agent environment.
Expand Down Expand Up @@ -190,11 +191,15 @@ public data class SafeTool<TArgs, TResult>(
* @return A [SafeTool.Result] which will either be a [SafeTool.Result.Failure] or [SafeTool.Result.Success]
* based on the presence and validity of the `result` in the [ReceivedToolResult].
*/
public fun <TResult> ReceivedToolResult.toSafeResult(tool: Tool<*, TResult>): SafeTool.Result<TResult> = when (result) {
null -> {
SafeTool.Result.Failure(message = content)
}
else -> {
SafeTool.Result.Success(result = tool.decodeResult(this.result), content = content)
public fun <TResult> ReceivedToolResult.toSafeResult(tool: Tool<*, TResult>): SafeTool.Result<TResult> {
val encodedResult = result ?: return SafeTool.Result.Failure(message = content)
val decodedResult = try {
tool.decodeResult(encodedResult)
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
return SafeTool.Result.Failure("Tool with name '${tool.name}' failed to deserialize result with error: ${e.message}")
}

return SafeTool.Result.Success(result = decodedResult, content = content)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import ai.koog.agents.core.dsl.extension.setToolChoiceRequired
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.environment.ToolResultKind
import ai.koog.agents.core.environment.toSafeResult
import ai.koog.agents.core.feature.model.toAgentError
import ai.koog.agents.core.tools.Tool
import ai.koog.agents.core.tools.ToolDescriptor
import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi
Expand All @@ -28,7 +29,9 @@ import ai.koog.prompt.message.Message
import ai.koog.prompt.params.LLMParams
import ai.koog.prompt.processor.ResponseProcessor
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.serializer
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass

/**
Expand Down Expand Up @@ -586,7 +589,7 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA

edge(
callToolsHacked forwardTo finalizeTask
onCondition { toolResults -> toolResults.firstOrNull()?.let { it.tool == finishTool.name } == true }
onCondition { toolResults -> toolResults.firstOrNull()?.let { it.tool == finishTool.name && it.resultKind is ToolResultKind.Success } == true }
transformed { toolsResults -> toolsResults.first() }
)

Expand All @@ -603,11 +606,26 @@ internal suspend fun <Output, OutputTransformed> AIAgentContext.executeFinishToo
toolCall: Message.Tool.Call,
finishTool: Tool<Output, OutputTransformed>,
): ReceivedToolResult {
val toolDescription = finishTool.descriptor.description
// Execute Finish tool directly and get a result
val args = finishTool.decodeArgs(toolCall.contentJson)
val toolResult = finishTool.execute(args = args)
val encodedResult = try {
val args = finishTool.decodeArgs(toolCall.contentJson)
val toolResult = finishTool.execute(args = args)
finishTool.encodeResult(toolResult)
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
return ReceivedToolResult(
id = toolCall.id,
tool = finishTool.name,
toolArgs = toolCall.contentJsonResult.getOrElse { JsonObject(emptyMap()) },
toolDescription = toolDescription,
content = "Failed to execute '${finishTool.name}' with error: ${e.message}'",
resultKind = ToolResultKind.Failure(e.toAgentError()),
result = null,
)
}

val encodedResult = finishTool.encodeResult(toolResult)
// Append a final tool call result to the prompt for further LLM calls
// to see it (otherwise they would fail)
llm.writeSession {
Expand All @@ -624,7 +642,7 @@ internal suspend fun <Output, OutputTransformed> AIAgentContext.executeFinishToo
toolArgs = toolCall.contentJson,
content = toolCall.content,
resultKind = ToolResultKind.Success,
toolDescription = finishTool.descriptor.description,
toolDescription = toolDescription,
result = encodedResult
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import kotlinx.datetime.Clock
import kotlinx.datetime.Instant.Companion.fromEpochMilliseconds
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue

class ToolCallDescriberTest {
Expand Down Expand Up @@ -162,9 +161,12 @@ class ToolCallDescriberTest {
metaInfo = ResponseMetaInfo.create(testClock),
)

assertFailsWith<Exception> {
describer.describeToolCall(invalidJsonToolCall)
}
val result = describer.describeToolCall(invalidJsonToolCall)

assertTrue(result.content.contains("\"tool_call_id\":\"test-call-id\""))
assertTrue(result.content.contains("\"tool_name\":\"test-tool\""))
assertTrue(result.content.contains("\"tool_args_error\":\"Failed to parse tool arguments:"))
assertEquals(invalidJsonToolCall.metaInfo, result.metaInfo)
}

@Test
Expand Down Expand Up @@ -227,9 +229,13 @@ class ToolCallDescriberTest {
metaInfo = ResponseMetaInfo.create(testClock),
)

assertFailsWith<IllegalArgumentException> {
describer.describeToolCall(nullContentToolCall)
}
val result = describer.describeToolCall(nullContentToolCall)

assertTrue(result.content.contains("\"tool_call_id\":\"test-call-id\""))
assertTrue(result.content.contains("\"tool_name\":\"test-tool\""))
assertTrue(result.content.contains("\"tool_args_error\":\"Failed to parse tool arguments:"))
assertTrue(result.content.contains("IllegalArgumentException"))
assertEquals(nullContentToolCall.metaInfo, result.metaInfo)
}

@Test
Expand Down Expand Up @@ -281,9 +287,12 @@ class ToolCallDescriberTest {
metaInfo = ResponseMetaInfo.create(testClock),
)

assertFailsWith<Exception> {
describer.describeToolCall(nonJsonToolCall)
}
val result = describer.describeToolCall(nonJsonToolCall)

assertTrue(result.content.contains("\"tool_call_id\":\"test-call-id\""))
assertTrue(result.content.contains("\"tool_name\":\"test-tool\""))
assertTrue(result.content.contains("\"tool_args_error\":\"Failed to parse tool arguments:"))
assertEquals(nonJsonToolCall.metaInfo, result.metaInfo)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import ai.koog.agents.core.tools.reflect.asTool
import ai.koog.prompt.message.Message
import ai.koog.prompt.message.ResponseMetaInfo
import kotlinx.datetime.Clock
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KFunction

/**
Expand Down Expand Up @@ -187,11 +188,15 @@ public data class SafeToolFromCallable<TResult>(
* @return A `SafeToolFromCallable.Result` object, either a `Success` with the extracted result
* and content or a `Failure` with an appropriate message.
*/
private fun <TResult> ReceivedToolResult.toSafeResultFromCallable(tool: Tool<ToolFromCallable.VarArgs, TResult>): SafeToolFromCallable.Result<TResult> =
when (result) {
null -> SafeToolFromCallable.Result.Failure(message = content)
else -> SafeToolFromCallable.Result.Success(
result = tool.decodeResult(result),
content = content
)
private fun <TResult> ReceivedToolResult.toSafeResultFromCallable(tool: Tool<ToolFromCallable.VarArgs, TResult>): SafeToolFromCallable.Result<TResult> {
val encodedResult = result ?: return SafeToolFromCallable.Result.Failure(message = content)
val decodedResult = try {
tool.decodeResult(encodedResult)
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
return SafeToolFromCallable.Result.Failure("Tool with name '${tool.name}' failed to deserialize result with error: ${e.message}")
}

return SafeToolFromCallable.Result.Success(result = decodedResult, content = content)
}
Loading