diff --git a/agents/agents-tools/build.gradle.kts b/agents/agents-tools/build.gradle.kts index 01afa7ac1a..12a7d7d35b 100644 --- a/agents/agents-tools/build.gradle.kts +++ b/agents/agents-tools/build.gradle.kts @@ -9,6 +9,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { diff --git a/agents/agents-utils/build.gradle.kts b/agents/agents-utils/build.gradle.kts index 6ce5aef186..39e84bc006 100644 --- a/agents/agents-utils/build.gradle.kts +++ b/agents/agents-utils/build.gradle.kts @@ -9,6 +9,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { @@ -16,12 +17,6 @@ kotlin { } } - commonTest { - dependencies { - implementation(project(":test-utils")) - } - } - jvmTest { dependencies { implementation(kotlin("test-junit5")) diff --git a/build.gradle.kts b/build.gradle.kts index b9658683c0..d59f3293a9 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -106,6 +106,28 @@ subprojects { } subprojects { + tasks.matching { it.name.endsWith("Annotations") && it.name.startsWith("extract") } + .configureEach { enabled = false } + + afterEvaluate { + val releaseTypedef = layout.buildDirectory + .file("intermediates/annotations_typedef_file/release/extractReleaseAnnotations/typedefs.txt") + .get() + .asFile + if (!releaseTypedef.exists()) { + releaseTypedef.parentFile.mkdirs() + releaseTypedef.writeText("") + } + val debugTypedef = layout.buildDirectory + .file("intermediates/annotations_typedef_file/debug/extractDebugAnnotations/typedefs.txt") + .get() + .asFile + if (!debugTypedef.exists()) { + debugTypedef.parentFile.mkdirs() + debugTypedef.writeText("") + } + } + extensions.configure { outputToConsole = true coloredOutput = true @@ -220,37 +242,14 @@ dependencies { dokka(project(":agents:agents-test")) dokka(project(":agents:agents-tools")) dokka(project(":agents:agents-utils")) - dokka(project(":embeddings:embeddings-base")) - dokka(project(":embeddings:embeddings-llm")) - dokka(project(":koog-ktor")) - dokka(project(":koog-spring-boot-starter")) - dokka(project(":prompt:prompt-cache:prompt-cache-files")) - dokka(project(":prompt:prompt-cache:prompt-cache-model")) - dokka(project(":prompt:prompt-cache:prompt-cache-redis")) - dokka(project(":prompt:prompt-executor:prompt-executor-cached")) dokka(project(":prompt:prompt-executor:prompt-executor-clients")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-anthropic-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-bedrock-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-deepseek-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-google-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-mistralai-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-ollama-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client-base")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openrouter-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-dashscope-client")) - dokka(project(":prompt:prompt-executor:prompt-executor-llms")) - dokka(project(":prompt:prompt-executor:prompt-executor-llms-all")) + 
dokka(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-litertlm-client")) dokka(project(":prompt:prompt-executor:prompt-executor-model")) dokka(project(":prompt:prompt-llm")) dokka(project(":prompt:prompt-markdown")) dokka(project(":prompt:prompt-model")) - dokka(project(":prompt:prompt-processor")) - dokka(project(":prompt:prompt-structure")) - dokka(project(":prompt:prompt-tokenizer")) dokka(project(":prompt:prompt-xml")) dokka(project(":rag:rag-base")) - dokka(project(":rag:vector-storage")) dokka(project(":utils")) } diff --git a/convention-plugin-ai/src/main/kotlin/ai.kotlin.multiplatform.gradle.kts b/convention-plugin-ai/src/main/kotlin/ai.kotlin.multiplatform.gradle.kts index 5d98ea651f..fc6c6138fc 100644 --- a/convention-plugin-ai/src/main/kotlin/ai.kotlin.multiplatform.gradle.kts +++ b/convention-plugin-ai/src/main/kotlin/ai.kotlin.multiplatform.gradle.kts @@ -1,7 +1,7 @@ @file:OptIn(ExperimentalWasmDsl::class) import ai.koog.gradle.publish.maven.configureJvmJarManifest -import ai.koog.gradle.tests.configureTests +import com.android.build.api.variant.LibraryAndroidComponentsExtension import jetbrains.sign.GpgSignSignatoryProvider import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl @@ -15,38 +15,8 @@ plugins { } kotlin { - // Tiers are in accordance with - // Tier 1 - iosSimulatorArm64() - iosX64() - - // Tier 2 - iosArm64() - - // Tier 3 - - // Android androidTarget() - // jvm & js - jvm { - configureTests() - } - - js(IR) { - browser { - binaries.library() - } - - configureTests() - } - - wasmJs { - browser() - nodejs() - binaries.library() - } - sourceSets { androidUnitTest { dependencies { @@ -70,7 +40,27 @@ android { } } -configureJvmJarManifest("jvmJar") +extensions.configure("androidComponents") { + beforeVariants(selector().all()) { variantBuilder -> + val booleanType = Boolean::class.javaPrimitiveType + if (booleanType != null) { + runCatching { + variantBuilder.javaClass.getMethod("setEnableUnitTest", booleanType).invoke(variantBuilder, false) + } + runCatching { + variantBuilder.javaClass.getMethod("setUnitTestEnabled", booleanType).invoke(variantBuilder, false) + } + runCatching { + variantBuilder.javaClass.getMethod("setEnableAndroidTest", booleanType).invoke(variantBuilder, false) + } + runCatching { + variantBuilder.javaClass.getMethod("setAndroidTestEnabled", booleanType).invoke(variantBuilder, false) + } + } + } +} + +tasks.findByName("jvmJar")?.let { configureJvmJarManifest("jvmJar") } val javadocJar by tasks.registering(Jar::class) { archiveClassifier.set("javadoc") diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 7d004283a6..89ad5b8435 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,10 +1,11 @@ [versions] acp = "0.3.0" -agp = "8.12.3" +agp = "8.13.2" annotations = "26.0.2-1" assertj = "3.27.6" awaitility = "4.3.0" aws-sdk-kotlin = "1.5.16" +litertlm = "0.8.0" dokka = "2.1.0" exposed = "0.61.0" h2 = "2.4.240" @@ -14,7 +15,7 @@ jetsign = "45.47" junit = "5.8.2" knit = "0.5.0" kotest = "6.0.7" -kotlin = "2.2.21" +kotlin = "2.3.0" kotlinx-coroutines = "1.10.2" kotlinx-datetime = "0.6.2" kotlinx-io = "0.7.0" @@ -46,6 +47,8 @@ mockito-junit-jupiter = { module = "org.mockito:mockito-junit-jupiter", version. 
assertj-core = { module = "org.assertj:assertj-core", version.ref = "assertj" } awaitility = { module = "org.awaitility:awaitility-kotlin", version.ref = "awaitility" } junit-jupiter-params = { module = "org.junit.jupiter:junit-jupiter-params", version.ref = "junit" } +junit-jupiter-api = { module = "org.junit.jupiter:junit-jupiter-api", version.ref = "junit" } +junit-jupiter-engine = { module = "org.junit.jupiter:junit-jupiter-engine", version.ref = "junit" } junit-platform-launcher = { module = "org.junit.platform:junit-platform-launcher" } kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" } kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.ref = "kotest" } @@ -114,6 +117,8 @@ aws-sdk-kotlin-bedrock = { module = "aws.sdk.kotlin:bedrock", version.ref = "aws aws-sdk-kotlin-bedrockruntime = { module = "aws.sdk.kotlin:bedrockruntime", version.ref = "aws-sdk-kotlin" } aws-sdk-kotlin-sts = { module = "aws.sdk.kotlin:sts", version.ref = "aws-sdk-kotlin" } android-tools-gradle = { module = "com.android.tools.build:gradle", version.ref = "agp" } +litertlm-jvm = { module = "com.google.ai.edge.litertlm:litertlm-jvm", version.ref = "litertlm" } +litertlm-android = { module = "com.google.ai.edge.litertlm:litertlm-android", version.ref = "litertlm" } # Spring spring-boot-bom = { module = "org.springframework.boot:spring-boot-dependencies", version.ref = "spring-boot" } diff --git a/prompt/prompt-executor/prompt-executor-clients/build.gradle.kts b/prompt/prompt-executor/prompt-executor-clients/build.gradle.kts index e74a6de399..a2b5af46d0 100644 --- a/prompt/prompt-executor/prompt-executor-clients/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-clients/build.gradle.kts @@ -8,6 +8,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { @@ -29,7 +30,6 @@ kotlin { } commonTest { dependencies { - implementation(project(":test-utils")) } } jvmTest { diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/.gitignore b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/.gitignore new file mode 100644 index 0000000000..2fdd3585f7 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/.gitignore @@ -0,0 +1,10 @@ +# LiteRT-LM native binaries +# These should be obtained separately (see natives/README.md) + +# Native library staging directory +natives/** +!natives/**/.gitkeep + +# Copied native libraries in source sets +src/androidMain/jniLibs/** +!src/androidMain/jniLibs/**/.gitkeep diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/build.gradle.kts b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/build.gradle.kts new file mode 100644 index 0000000000..955c3b6f18 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/build.gradle.kts @@ -0,0 +1,104 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +plugins { + id("ai.kotlin.multiplatform") + // id("litertlm-natives-convention") // Disabled - natives come from Maven dependency + alias(libs.plugins.kotlin.serialization) +} + +group = rootProject.group +version = rootProject.version + +// Check if Kotlin version is >= 2.3 for experimental features +val kotlinVersionString = libs.versions.kotlin.get() +val kotlinMajorMinor = kotlinVersionString.split(".").take(2).joinToString(".") +val isKotlin23OrHigher = try { + val 
(major, minor) = kotlinMajorMinor.split(".").map { it.toInt() } + major > 2 || (major == 2 && minor >= 3) +} catch (e: Exception) { + false +} + +kotlin { + // LiteRT-LM has full implementation on JVM/Android, stubs on other platforms + // The convention plugin handles all target declarations + jvm() + + // Enable Kotlin 2.3+ experimental features when available + if (isKotlin23OrHigher) { + compilerOptions { + // @MustUseReturnValues checker - warns when return values are ignored + freeCompilerArgs.add("-Xreturn-value-checker=check") + // Explicit backing fields - allows field keyword in property accessors + freeCompilerArgs.add("-Xexplicit-backing-fields") + } + } + + sourceSets { + commonMain { + dependencies { + api(project(":agents:agents-tools")) + api(project(":prompt:prompt-llm")) + api(project(":prompt:prompt-model")) + api(project(":prompt:prompt-executor:prompt-executor-model")) + api(project(":prompt:prompt-executor:prompt-executor-clients")) + + api(libs.kotlinx.datetime) + api(libs.kotlinx.coroutines.core) + implementation(libs.oshai.kotlin.logging) + } + } + + jvmMain { + dependencies { + // LiteRT-LM JVM dependency + // This is marked as compileOnly - users must add this dependency to their project + // at runtime when they want to use the LiteRT-LM provider. + // See: https://github.com/google-ai-edge/LiteRT-LM + compileOnly(libs.litertlm.jvm) + } + } + + androidMain { + dependencies { + // LiteRT-LM Android dependency + // This is marked as compileOnly - users must add this dependency to their project + // at runtime when they want to use the LiteRT-LM provider. + // Native libraries are bundled separately via the litertlm-natives-convention plugin. + compileOnly(libs.litertlm.android) + } + } + + commonTest { + dependencies { + implementation(project(":test-utils")) + implementation(libs.kotlinx.coroutines.core) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + // Add LiteRT-LM for testing (when available) + implementation(libs.litertlm.jvm) + implementation(libs.mockk) + implementation(libs.kotest.assertions.core) + implementation(libs.junit.jupiter.api) + runtimeOnly(libs.junit.jupiter.engine) + } + } + } + + explicitApi() +} + +// Configure Android native library source directory +android { + sourceSets { + getByName("main") { + jniLibs.srcDirs("src/androidMain/jniLibs") + } + } +} + +publishToMaven() diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/models/.gitignore b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/models/.gitignore new file mode 100644 index 0000000000..b29f831ecd --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/models/.gitignore @@ -0,0 +1,2 @@ +# LiteRT-LM model files (too large for git, download separately) +*.litertlm diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/androidMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.android.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/androidMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.android.kt new file mode 100644 index 0000000000..ef408b7b2d --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/androidMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.android.kt @@ -0,0 +1,21 @@ +package ai.koog.prompt.executor.litertlm.client + 
+import ai.koog.prompt.executor.clients.LLMClient + +/** + * Creates a LiteRT-LM client for Android platform. + * + * Note: Android support requires the litertlm-android dependency and native libraries. + * This is a placeholder - full Android implementation coming soon. + */ +public actual fun createLiteRTLMClient(config: LiteRTLMClientConfig): LLMClient { + throw UnsupportedOperationException( + "LiteRT-LM Android support requires litertlm-android dependency. " + + "Add 'com.google.ai.edge.litertlm:litertlm-android' to your dependencies." + ) +} + +/** + * LiteRT-LM is supported on Android (requires additional setup). + */ +public actual fun isLiteRTLMSupported(): Boolean = true diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/appleMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.apple.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/appleMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.apple.kt new file mode 100644 index 0000000000..7b0b291c8c --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/appleMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.apple.kt @@ -0,0 +1,18 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.prompt.executor.clients.LLMClient + +/** + * LiteRT-LM is not available on Apple platforms (iOS/macOS). + */ +public actual fun createLiteRTLMClient(config: LiteRTLMClientConfig): LLMClient { + throw UnsupportedOperationException( + "LiteRT-LM is not supported on Apple platforms. " + + "LiteRT-LM only supports JVM and Android." + ) +} + +/** + * LiteRT-LM is not supported on Apple platforms. + */ +public actual fun isLiteRTLMSupported(): Boolean = false diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/commonMain/kotlin/ai/koog/prompt/executor/litertlm/client/Kotlin23Compat.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/commonMain/kotlin/ai/koog/prompt/executor/litertlm/client/Kotlin23Compat.kt new file mode 100644 index 0000000000..0134fd0715 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/commonMain/kotlin/ai/koog/prompt/executor/litertlm/client/Kotlin23Compat.kt @@ -0,0 +1,82 @@ +/** + * Kotlin 2.3 compatibility annotations and utilities. + * + * This file provides forward-compatible annotations that become fully functional + * with Kotlin 2.3's experimental features: + * + * ## Return Value Checker (`-Xreturn-value-checker=check`) + * + * Functions annotated with [@MustUseReturnValue] will generate compiler warnings + * when their return values are ignored. + * + * ## Explicit Backing Fields (`-Xexplicit-backing-fields`) + * + * Kotlin 2.3 introduces explicit backing fields using the `field` keyword. 
+ * This allows separating the backing field type from the property type: + * + * ```kotlin + * // Kotlin 2.3+ syntax (not available in 2.2): + * var items: List + * field = mutableListOf() // Backing field is MutableList + * get() = field.toList() // Expose as immutable + * set(value) { field.clear(); field.addAll(value) } + * + * // Pre-2.3 equivalent (current approach): + * private val _items = mutableListOf() + * val items: List get() = _items.toList() + * ``` + * + * When Kotlin 2.3 is adopted, candidates for explicit backing fields in this module: + * - Properties that need mutable backing but immutable public access + * - Properties with validation in setters + * + * @see Kotlin 2.3 What's New + */ +package ai.koog.prompt.executor.litertlm.client + +/** + * Marks a function whose return value should not be ignored. + * + * On Kotlin 2.3+ with `-Xreturn-value-checker=check`, the compiler will warn + * if the return value of a function marked with this annotation is unused. + * + * On Kotlin < 2.3, this annotation is a no-op but serves as documentation + * for the intended contract. + * + * ## Example + * + * ```kotlin + * @MustUseReturnValue + * suspend fun create(config: Config): Client + * + * // Kotlin 2.3+ will warn: + * create(config) // Warning: Return value of 'create' is unused + * + * // Correct usage: + * val client = create(config) // OK + * ``` + * + * @see Kotlin 2.3 What's New + */ +@Target( + AnnotationTarget.FUNCTION, + AnnotationTarget.CONSTRUCTOR, + AnnotationTarget.PROPERTY_GETTER +) +@Retention(AnnotationRetention.BINARY) +@MustBeDocumented +public annotation class MustUseReturnValue( + val message: String = "The return value of this function should not be ignored" +) + +/** + * Marks a return value that can be safely ignored. + * + * Use this to suppress warnings from [MustUseReturnValue] when ignoring + * the return value is intentional. + * + * On Kotlin < 2.3, this annotation is a no-op. + */ +@Target(AnnotationTarget.EXPRESSION) +@Retention(AnnotationRetention.SOURCE) +public annotation class IgnorableReturnValue diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/commonMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/commonMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.kt new file mode 100644 index 0000000000..f318a34b79 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/commonMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.kt @@ -0,0 +1,284 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.prompt.executor.clients.LLMClient +import kotlinx.datetime.Clock + +/** + * Backend for the LiteRT-LM engine. + * + * This is the Kotlin Multiplatform version of the C++'s `litert::lm::Backend`. + */ +public enum class LiteRTLMBackend { + /** CPU LiteRT backend. */ + CPU, + + /** GPU LiteRT backend. */ + GPU, + + /** NPU LiteRT backend. */ + NPU, +} + +/** + * Configuration for the LiteRT-LM engine. + * + * @property modelPath The file path to the LiteRT-LM model. + * @property backend The backend to use for the engine. + * @property visionBackend The backend to use for the vision executor. If null, vision executor will + * not be initialized. + * @property audioBackend The backend to use for the audio executor. If null, audio executor will + * not be initialized. 
+ * @property maxNumTokens The maximum number of the sum of input and output tokens. It is equivalent + * to the size of the kv-cache. When `null`, use the default value from the model or the engine. + * @property cacheDir The directory for placing cache files. It should be a directory where the + * application has write access. If not set, it uses the directory of the [modelPath]. + */ +public data class LiteRTLMEngineConfig( + val modelPath: String, + val backend: LiteRTLMBackend = LiteRTLMBackend.CPU, + val visionBackend: LiteRTLMBackend? = null, + val audioBackend: LiteRTLMBackend? = null, + val maxNumTokens: Int? = null, + val cacheDir: String? = null, +) { + init { + require(maxNumTokens == null || maxNumTokens > 0) { + "maxNumTokens must be positive or null (use the default from model or engine)." + } + } +} + +/** + * Configuration for the sampling process. + * + * @property topK The number of top logits used during sampling. + * @property topP The cumulative probability threshold for nucleus sampling. + * @property temperature The temperature to use for sampling. + * @property seed The seed to use for randomization. Default to 0 (same default as engine code). + */ +public data class LiteRTLMSamplerConfig( + val topK: Int = DEFAULT_TOP_K, + val topP: Double = DEFAULT_TOP_P, + val temperature: Double = DEFAULT_TEMPERATURE, + val seed: Int = 0, +) { + init { + require(topK > 0) { "topK should be positive, but got $topK." } + require(topP >= 0 && topP <= 1) { "topP should between 0 and 1 inclusively, but got $topP." } + require(temperature >= 0) { "temperature should be non-negative, but got $temperature." } + } + + public companion object { + public const val DEFAULT_TOP_K: Int = 40 + public const val DEFAULT_TOP_P: Double = 0.95 + public const val DEFAULT_TEMPERATURE: Double = 0.8 + } +} + +/** + * Configuration for creating a LiteRT-LM client. + * + * @property engineConfig Configuration for the LiteRT-LM engine. + * @property samplerConfig Configuration for the sampling process. If `null`, uses the engine's + * default values. + * @property clock Clock instance used for tracking response metadata timestamps. + */ +public data class LiteRTLMClientConfig( + val engineConfig: LiteRTLMEngineConfig, + val samplerConfig: LiteRTLMSamplerConfig? = null, + val clock: Clock = Clock.System, +) { + /** + * Convenience constructor for simple configurations. + * + * @param modelPath The file path to the LiteRT-LM model. + * @param backend The backend to use for the engine. + * @param cacheDir The directory for placing cache files. + * @param clock Clock instance used for tracking response metadata timestamps. + */ + public constructor( + modelPath: String, + backend: LiteRTLMBackend = LiteRTLMBackend.CPU, + cacheDir: String? = null, + clock: Clock = Clock.System, + ) : this( + engineConfig = LiteRTLMEngineConfig( + modelPath = modelPath, + backend = backend, + cacheDir = cacheDir, + ), + samplerConfig = null, + clock = clock, + ) +} + +/** + * Creates a LiteRT-LM client for on-device LLM inference. + * + * LiteRT-LM is Google's on-device inference engine that enables running LLMs locally. + * This is only supported on JVM and Android platforms. + * + * @param config Configuration for the LiteRT-LM client. + * @return An [LLMClient] implementation for LiteRT-LM. + * @throws UnsupportedOperationException on platforms where LiteRT-LM is not available (iOS, JS, WasmJS). 
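+ *
+ * Illustrative usage (the model path below is a placeholder, not a real file):
+ * ```kotlin
+ * val client = createLiteRTLMClient(liteRTLMConfig("/path/to/model.litertlm"))
+ * // ... run prompts with the client, then release native resources:
+ * client.close()
+ * ```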
+ */ +@MustUseReturnValue("The returned LLMClient must be used and eventually closed") +public expect fun createLiteRTLMClient(config: LiteRTLMClientConfig): LLMClient + +/** + * Checks if LiteRT-LM is supported on the current platform. + * + * @return true if LiteRT-LM is available, false otherwise. + */ +public expect fun isLiteRTLMSupported(): Boolean + +/** + * DSL builder for [LiteRTLMClientConfig]. + * + * Example: + * ```kotlin + * val config = liteRTLMConfig { + * modelPath = "/path/to/model.litertlm" + * backend = LiteRTLMBackend.GPU + * sampler { + * temperature = 0.7 + * topK = 50 + * } + * } + * ``` + */ +@DslMarker +public annotation class LiteRTLMConfigDsl + +/** + * Builder for [LiteRTLMClientConfig]. + */ +@LiteRTLMConfigDsl +public class LiteRTLMClientConfigBuilder { + /** Path to the LiteRT-LM model file. Required. */ + public var modelPath: String? = null + + /** Backend for the main inference engine. */ + public var backend: LiteRTLMBackend = LiteRTLMBackend.CPU + + /** Backend for vision processing, or null to disable. */ + public var visionBackend: LiteRTLMBackend? = null + + /** Backend for audio processing, or null to disable. */ + public var audioBackend: LiteRTLMBackend? = null + + /** Maximum number of tokens (input + output). */ + public var maxNumTokens: Int? = null + + /** Cache directory for model artifacts. */ + public var cacheDir: String? = null + + /** Clock for response timestamps. */ + public var clock: Clock = Clock.System + + private var samplerConfig: LiteRTLMSamplerConfig? = null + + /** + * Configure sampling parameters. + */ + public fun sampler(block: SamplerConfigBuilder.() -> Unit) { + samplerConfig = SamplerConfigBuilder().apply(block).build() + } + + /** + * Builds the [LiteRTLMClientConfig]. + * @throws IllegalStateException if modelPath is not set. + */ + public fun build(): LiteRTLMClientConfig { + val path = requireNotNull(modelPath) { "modelPath must be set" } + return LiteRTLMClientConfig( + engineConfig = LiteRTLMEngineConfig( + modelPath = path, + backend = backend, + visionBackend = visionBackend, + audioBackend = audioBackend, + maxNumTokens = maxNumTokens, + cacheDir = cacheDir, + ), + samplerConfig = samplerConfig, + clock = clock, + ) + } +} + +/** + * Builder for [LiteRTLMSamplerConfig]. + */ +@LiteRTLMConfigDsl +public class SamplerConfigBuilder { + public var topK: Int = LiteRTLMSamplerConfig.DEFAULT_TOP_K + public var topP: Double = LiteRTLMSamplerConfig.DEFAULT_TOP_P + public var temperature: Double = LiteRTLMSamplerConfig.DEFAULT_TEMPERATURE + public var seed: Int = 0 + + public fun build(): LiteRTLMSamplerConfig = LiteRTLMSamplerConfig( + topK = topK, + topP = topP, + temperature = temperature, + seed = seed, + ) +} + +/** + * Creates a [LiteRTLMClientConfig] using a DSL builder. + * + * Example: + * ```kotlin + * val config = liteRTLMConfig { + * modelPath = "/path/to/model.litertlm" + * backend = LiteRTLMBackend.GPU + * sampler { + * temperature = 0.7 + * } + * } + * ``` + */ +public fun liteRTLMConfig(block: LiteRTLMClientConfigBuilder.() -> Unit): LiteRTLMClientConfig { + return LiteRTLMClientConfigBuilder().apply(block).build() +} + +// ==================== Simple Helpers (Cactus-style) ==================== + +/** + * Creates a simple [LiteRTLMClientConfig] with just a model path. + * + * This is the simplest way to get started - just provide the model file path. + * + * Example: + * ```kotlin + * val config = liteRTLMConfig("/path/to/model.litertlm") + * ``` + * + * @param modelPath Path to the LiteRT-LM model file. 
+ * @param backend Backend for inference (default: CPU). + * @return A configured [LiteRTLMClientConfig]. + */ +public fun liteRTLMConfig( + modelPath: String, + backend: LiteRTLMBackend = LiteRTLMBackend.CPU, +): LiteRTLMClientConfig = LiteRTLMClientConfig(modelPath = modelPath, backend = backend) + +/** + * Creates a [LiteRTLMSamplerConfig] with common parameters. + * + * Example: + * ```kotlin + * val sampler = samplerConfig(temperature = 0.7, topK = 50) + * ``` + */ +public fun samplerConfig( + temperature: Double = LiteRTLMSamplerConfig.DEFAULT_TEMPERATURE, + topK: Int = LiteRTLMSamplerConfig.DEFAULT_TOP_K, + topP: Double = LiteRTLMSamplerConfig.DEFAULT_TOP_P, + seed: Int = 0, +): LiteRTLMSamplerConfig = LiteRTLMSamplerConfig( + topK = topK, + topP = topP, + temperature = temperature, + seed = seed, +) diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jsMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.js.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jsMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.js.kt new file mode 100644 index 0000000000..32754c852f --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jsMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.js.kt @@ -0,0 +1,18 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.prompt.executor.clients.LLMClient + +/** + * LiteRT-LM is not available on JavaScript platforms. + */ +public actual fun createLiteRTLMClient(config: LiteRTLMClientConfig): LLMClient { + throw UnsupportedOperationException( + "LiteRT-LM is not supported on JavaScript. " + + "LiteRT-LM only supports JVM and Android." + ) +} + +/** + * LiteRT-LM is not supported on JavaScript. 
+ */ +public actual fun isLiteRTLMSupported(): Boolean = false diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClient.kt new file mode 100644 index 0000000000..e2f9d57a8b --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClient.kt @@ -0,0 +1,831 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.prompt.dsl.ModerationResult +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.executor.clients.LLMClient +import ai.koog.prompt.executor.clients.LLMClientException +import ai.koog.prompt.llm.LLMCapability +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.message.AttachmentContent +import ai.koog.prompt.message.ContentPart +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.prompt.streaming.StreamFrame +import ai.koog.prompt.streaming.emitAppend +import ai.koog.prompt.streaming.streamFrameFlow +import com.google.ai.edge.litertlm.BenchmarkInfo +import com.google.ai.edge.litertlm.Content +import com.google.ai.edge.litertlm.Conversation +import com.google.ai.edge.litertlm.ConversationConfig +import com.google.ai.edge.litertlm.Engine +import com.google.ai.edge.litertlm.EngineConfig +import com.google.ai.edge.litertlm.ExperimentalApi +import com.google.ai.edge.litertlm.LogSeverity +import com.google.ai.edge.litertlm.SamplerConfig +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext +import kotlinx.datetime.Clock +import java.util.concurrent.atomic.AtomicReference +import com.google.ai.edge.litertlm.Backend as LiteRTBackend +import com.google.ai.edge.litertlm.Message as LiteRTMessage + +/** + * Client for interacting with LiteRT-LM for on-device LLM inference. + * + * LiteRT-LM is Google's on-device inference engine that enables running LLMs locally + * on Android and JVM platforms without requiring network connectivity. + * + * ## Basic Usage (Stateless) + * ```kotlin + * val client = LiteRTLMClient.create( + * LiteRTLMClientConfig(modelPath = "/path/to/model.litertlm") + * ) + * val response = client.execute(prompt, model, tools) + * client.close() + * ``` + * + * ## Multi-turn Conversation (Recommended) + * ```kotlin + * val client = LiteRTLMClient.create(config) + * client.conversation(systemPrompt = "You are helpful.").use { conv -> + * val r1 = conv.send("Hello") + * val r2 = conv.send("Tell me more") // Maintains context + * } + * ``` + * + * @param config The configuration for the client. + */ +private val logger = KotlinLogging.logger { } + +public class LiteRTLMClient private constructor( + private val config: LiteRTLMClientConfig, + private val engine: Engine, + private val toolExecutor: ToolExecutor? = null, +) : LLMClient { + + // Atomic reference for lock-free reads of closed state + private val closed = AtomicReference(false) + + /** Returns `true` if the client has been closed. 
*/ + public fun isClosed(): Boolean = closed.get() + + /** + * Provides the type of Language Learning Model (LLM) provider used by the client. + */ + override fun llmProvider(): LLMProvider = LLMProvider.LiteRTLM + + /** + * Creates a managed conversation for multi-turn interactions. + * + * This is the recommended API for chat-like interactions where context + * should be preserved across multiple messages. + * + * ## Tool Support + * + * If the client was created with a [ToolExecutor], you can enable tool calling + * by passing tool descriptors: + * + * ```kotlin + * client.conversation( + * systemPrompt = "You are helpful.", + * tools = listOf(weatherTool, calculatorTool) + * ).use { conv -> + * val response = conv.send("What's the weather in Paris?") + * } + * ``` + * + * @param systemPrompt Optional system prompt for the conversation. + * @param tools Optional list of tools available to the model. + * @param samplerConfig Optional sampler configuration override. + * @return A [ManagedConversation] that maintains state across turns. + */ + @MustUseReturnValue("The returned ManagedConversation should be used for sending messages") + public fun conversation( + systemPrompt: String? = null, + tools: List = emptyList(), + samplerConfig: LiteRTLMSamplerConfig? = null, + ): ManagedConversation { + checkNotClosed() + + val effectiveSampler = samplerConfig ?: config.samplerConfig + val nativeSamplerConfig = effectiveSampler?.toNative() + ?: SamplerConfig( + topK = LiteRTLMSamplerConfig.DEFAULT_TOP_K, + topP = LiteRTLMSamplerConfig.DEFAULT_TOP_P, + temperature = LiteRTLMSamplerConfig.DEFAULT_TEMPERATURE, + seed = 0, + ) + + // Create tool bridge if tools are provided and executor is available + val (toolBridge, finalSystemPrompt) = if (tools.isNotEmpty() && toolExecutor != null) { + val bridge = LiteRTLMToolBridge(tools, toolExecutor) + val toolsPrompt = bridge.generateToolsPrompt() + val enhancedSystemPrompt = buildString { + if (systemPrompt != null) { + appendLine(systemPrompt) + appendLine() + } + append(toolsPrompt) + } + bridge to enhancedSystemPrompt + } else { + null to systemPrompt + } + + val conversationConfig = ConversationConfig( + systemMessage = finalSystemPrompt?.let { LiteRTMessage.of(it) }, + tools = listOfNotNull(toolBridge), + samplerConfig = nativeSamplerConfig, + ) + + val nativeConversation = engine.createConversation(conversationConfig) + return ManagedConversation(nativeConversation, config.clock) + } + + /** + * Executes a prompt and returns the response. + * + * Note: This creates a new conversation per call. For multi-turn interactions, + * use [conversation] instead to maintain context. + */ + override suspend fun execute( + prompt: Prompt, + model: LLModel, + tools: List + ): List = withContext(Dispatchers.IO) { + require(model.provider == LLMProvider.LiteRTLM) { "Model not supported by LiteRT-LM" } + checkNotClosed() + + try { + val conversationConfig = createConversationConfig(prompt, tools) + + engine.createConversation(conversationConfig).use { conversation -> + val conversationMessages = buildConversationMessages(prompt, model) + + var response: LiteRTMessage? 
= null + for (message in conversationMessages) { + response = conversation.sendMessage(message) + } + + val responseMetadata = ResponseMetaInfo.create(config.clock) + val responseText = response?.toString() ?: "" + + listOf( + Message.Assistant( + content = responseText, + metaInfo = responseMetadata + ) + ) + } + } catch (e: Exception) { + throw LLMClientException( + clientName = clientName, + message = "Failed to execute prompt: ${e.message}", + cause = e + ) + } + } + + override fun executeStreaming( + prompt: Prompt, + model: LLModel, + tools: List + ): Flow = streamFrameFlow { + require(model.provider == LLMProvider.LiteRTLM) { "Model not supported by LiteRT-LM" } + checkNotClosed() + + val conversationConfig = createConversationConfig(prompt, tools) + + engine.createConversation(conversationConfig).use { conversation -> + val conversationMessages = buildConversationMessages(prompt, model) + + // Send all messages except the last one synchronously to build history + for (i in 0 until conversationMessages.size - 1) { + conversation.sendMessage(conversationMessages[i]) + } + + // Stream the last message + val lastMessage = conversationMessages.lastOrNull() + if (lastMessage != null) { + conversation.sendMessageAsync(lastMessage) + .collect { chunk -> + emitAppend(chunk.toString()) + } + } + } + } + + override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult { + throw LLMClientException( + clientName = clientName, + message = "Moderation is not supported by LiteRT-LM" + ) + } + + /** + * Closes the client and releases all native resources. + * + * After calling this method, the client cannot be used. + */ + override fun close() { + if (closed.compareAndSet(false, true)) { + engine.close() + logger.info { "LiteRT-LM client closed" } + } + } + + private fun createConversationConfig(prompt: Prompt, tools: List): ConversationConfig { + val baseSystemMessage = prompt.messages + .filterIsInstance() + .firstOrNull() + ?.content + + val samplerConfig = config.samplerConfig + val temperature = prompt.params.temperature + ?: samplerConfig?.temperature + ?: LiteRTLMSamplerConfig.DEFAULT_TEMPERATURE + + val finalSamplerConfig = if (samplerConfig != null) { + SamplerConfig( + topK = samplerConfig.topK, + topP = samplerConfig.topP, + temperature = temperature, + seed = samplerConfig.seed, + ) + } else { + SamplerConfig( + topK = LiteRTLMSamplerConfig.DEFAULT_TOP_K, + topP = LiteRTLMSamplerConfig.DEFAULT_TOP_P, + temperature = temperature, + seed = 0, + ) + } + + // Create tool bridge if tools are provided and executor is available + val (toolBridge, systemMessage) = if (tools.isNotEmpty() && toolExecutor != null) { + val bridge = LiteRTLMToolBridge(tools, toolExecutor) + val toolsPrompt = bridge.generateToolsPrompt() + val enhancedSystemMessage = buildString { + if (baseSystemMessage != null) { + appendLine(baseSystemMessage) + appendLine() + } + append(toolsPrompt) + } + bridge to enhancedSystemMessage + } else { + null to baseSystemMessage + } + + return ConversationConfig( + systemMessage = systemMessage?.let { LiteRTMessage.of(it) }, + tools = listOfNotNull(toolBridge), + samplerConfig = finalSamplerConfig, + ) + } + + private fun buildConversationMessages(prompt: Prompt, model: LLModel): List { + return buildList(prompt.messages.size) { + for (message in prompt.messages) { + when (message) { + is Message.System -> { + // Handled in ConversationConfig + } + is Message.User -> { + add(buildUserMessage(message, model)) + } + is Message.Assistant -> { + 
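+                        // Replay earlier assistant output as a plain text message when rebuilding the history.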
add(LiteRTMessage.of(message.content)) + } + is Message.Tool.Call -> { + add(LiteRTMessage.of("[Tool Call: ${message.tool}] ${message.content}")) + } + is Message.Tool.Result -> { + add(LiteRTMessage.of("[Tool Result: ${message.tool}] ${message.content}")) + } + is Message.Reasoning -> { + add(LiteRTMessage.of("[Reasoning] ${message.content}")) + } + } + } + }.also { messages -> + if (messages.isEmpty()) { + throw LLMClientException( + clientName = clientName, + message = "Prompt must contain at least one non-system message" + ) + } + } + } + + private fun buildUserMessage(userMessage: Message.User, model: LLModel): LiteRTMessage { + val contentParts = buildList(userMessage.parts.size) { + for (part in userMessage.parts) { + when (part) { + is ContentPart.Text -> add(Content.Text(part.text)) + is ContentPart.Image -> { + if (LLMCapability.Vision.Image !in model.capabilities) { + throw LLMClientException( + clientName = clientName, + message = "Model ${model.id} does not support image inputs" + ) + } + add(convertImageContent(part)) + } + is ContentPart.Audio -> { + if (LLMCapability.Audio !in model.capabilities) { + throw LLMClientException( + clientName = clientName, + message = "Model ${model.id} does not support audio inputs" + ) + } + add(convertAudioContent(part)) + } + is ContentPart.Video -> { + throw LLMClientException( + clientName = clientName, + message = "Video content is not yet supported by LiteRT-LM" + ) + } + is ContentPart.File -> { + val fileText = when (val content = part.content) { + is AttachmentContent.PlainText -> content.text + is AttachmentContent.Binary -> content.asBase64() + is AttachmentContent.URL -> "[File URL: ${content.url}]" + } + add(Content.Text("[File: ${part.fileName ?: "unnamed"}]\n$fileText")) + } + } + } + } + + if (contentParts.isEmpty()) { + throw LLMClientException( + clientName = clientName, + message = "User message must contain at least one content part" + ) + } + + return if (contentParts.size == 1 && contentParts.first() is Content.Text) { + LiteRTMessage.of((contentParts.first() as Content.Text).text) + } else { + LiteRTMessage.of(*contentParts.toTypedArray()) + } + } + + // Using Kotlin 2.2 guard conditions in when for cleaner conditional logic + private fun convertImageContent(image: ContentPart.Image): Content { + val content = image.content + return when (content) { + is AttachmentContent.Binary -> Content.ImageBytes(content.asBytes()) + is AttachmentContent.URL -> { + if (content.url.startsWith("file://")) { + Content.ImageFile(content.url.removePrefix("file://")) + } else { + throw LLMClientException( + clientName = clientName, + message = "Remote image URLs are not supported. Use file:// URLs or binary content." + ) + } + } + is AttachmentContent.PlainText -> throw LLMClientException( + clientName = clientName, + message = "Image cannot have plain text content" + ) + } + } + + private fun convertAudioContent(audio: ContentPart.Audio): Content { + val content = audio.content + return when (content) { + is AttachmentContent.Binary -> Content.AudioBytes(content.asBytes()) + is AttachmentContent.URL -> { + if (content.url.startsWith("file://")) { + Content.AudioFile(content.url.removePrefix("file://")) + } else { + throw LLMClientException( + clientName = clientName, + message = "Remote audio URLs are not supported. Use file:// URLs or binary content." 
+ ) + } + } + is AttachmentContent.PlainText -> throw LLMClientException( + clientName = clientName, + message = "Audio cannot have plain text content" + ) + } + } + + private fun checkNotClosed() { + check(!isClosed()) { "Client has been closed." } + } + + private fun LiteRTLMSamplerConfig.toNative() = SamplerConfig( + topK = topK, + topP = topP, + temperature = temperature, + seed = seed, + ) + + public companion object { + /** + * Creates and initializes a LiteRT-LM client. + * + * This is the recommended way to create a client. The engine is + * initialized before returning, so the client is immediately ready for use. + * + * ## Tool Support + * + * To enable tool calling, provide a [ToolExecutor]: + * + * ```kotlin + * val client = LiteRTLMClient.create(config) { toolName, args -> + * when (toolName) { + * "get_weather" -> getWeather(args["city"] as String) + * "calculate" -> calculate(args["expression"] as String) + * else -> "Unknown tool" + * } + * } + * ``` + * + * @param config Configuration for the client. + * @param toolExecutor Optional callback to execute tools. If provided, tool calling is enabled. + * @return An initialized [LiteRTLMClient]. + * @throws LLMClientException if initialization fails. + */ + @MustUseReturnValue("The returned LiteRTLMClient must be used and eventually closed") + public suspend fun create( + config: LiteRTLMClientConfig, + toolExecutor: ToolExecutor? = null, + ): LiteRTLMClient = withContext(Dispatchers.IO) { + try { + val engineConfig = EngineConfig( + modelPath = config.engineConfig.modelPath, + backend = config.engineConfig.backend.toLiteRTBackend(), + visionBackend = config.engineConfig.visionBackend?.toLiteRTBackend(), + audioBackend = config.engineConfig.audioBackend?.toLiteRTBackend(), + maxNumTokens = config.engineConfig.maxNumTokens, + cacheDir = config.engineConfig.cacheDir, + ) + + val engine = Engine(engineConfig).also { it.initialize() } + logger.info { "LiteRT-LM engine initialized with model: ${config.engineConfig.modelPath}" } + + LiteRTLMClient(config, engine, toolExecutor) + } catch (e: Exception) { + throw LLMClientException( + clientName = "LiteRTLMClient", + message = "Failed to initialize LiteRT-LM engine: ${e.message}", + cause = e + ) + } + } + + /** + * Sets the minimum log severity for all LiteRT-LM native libraries. + */ + public fun setLogSeverity(level: LiteRTLMLogSeverity) { + Engine.Companion.setNativeMinLogServerity(level.toNative()) + } + } +} + +private fun LiteRTLMBackend.toLiteRTBackend(): LiteRTBackend = when (this) { + LiteRTLMBackend.CPU -> LiteRTBackend.CPU + LiteRTLMBackend.GPU -> LiteRTBackend.GPU + LiteRTLMBackend.NPU -> LiteRTBackend.NPU +} + +/** + * Log severity levels for LiteRT-LM native libraries. + */ +public enum class LiteRTLMLogSeverity { + VERBOSE, + DEBUG, + INFO, + WARNING, + ERROR, + FATAL, + SILENT; + + internal fun toNative(): LogSeverity = when (this) { + VERBOSE -> LogSeverity.VERBOSE + DEBUG -> LogSeverity.DEBUG + INFO -> LogSeverity.INFO + WARNING -> LogSeverity.WARNING + ERROR -> LogSeverity.ERROR + FATAL -> LogSeverity.FATAL + SILENT -> LogSeverity.INFINITY + } +} + +/** + * A managed conversation that maintains state across multiple turns. + * + * This class wraps a native LiteRT-LM [Conversation] and provides a + * higher-level API for multi-turn interactions. 
+ * + * Example: + * ```kotlin + * client.conversation(systemPrompt = "You are helpful.").use { conv -> + * val greeting = conv.send("Hello!") + * val followUp = conv.send("Tell me more about that.") + * } + * ``` + */ +/** + * Represents an entry in the conversation history. + * + * @property role The role of the message sender. + * @property content The text content of the message. + * @property timestamp When the message was sent/received. + */ +public data class ConversationEntry( + val role: Role, + val content: String, + val timestamp: kotlinx.datetime.Instant, +) { + public enum class Role { USER, ASSISTANT } +} + +public class ManagedConversation internal constructor( + private val conversation: Conversation, + private val clock: Clock, +) : AutoCloseable { + + /** + * Mutex to ensure thread-safe access to the conversation. + * + * LiteRT-LM conversations are stateful and not thread-safe for concurrent + * message sending. This mutex ensures only one message is processed at a time. + */ + private val mutex = Mutex() + + /** Internal mutable history. */ + private val _history = mutableListOf() + + /** + * The conversation history as an immutable list. + * + * Each entry contains the role (USER or ASSISTANT), content, and timestamp. + * Useful for debugging, logging, or displaying conversation context. + */ + public val history: List + get() = _history.toList() + + /** Number of user messages sent in this conversation. */ + public val messageCount: Int + get() = _history.count { it.role == ConversationEntry.Role.USER } + + /** Number of turns (user + assistant pairs) in this conversation. */ + public val turnCount: Int + get() = _history.size / 2 + + // ==================== Text Messages ==================== + + /** + * Sends a text message and returns the response. + * + * This method is thread-safe. Concurrent calls will be serialized. + * + * @param text The message text to send. + * @return The assistant's response. + */ + @MustUseReturnValue("The assistant's response should be processed") + public suspend fun send(text: String): Message.Assistant = mutex.withLock { + _history.add(ConversationEntry(ConversationEntry.Role.USER, text, clock.now())) + val response = conversation.sendMessage(LiteRTMessage.of(text)) + val responseText = response.toString() + _history.add(ConversationEntry(ConversationEntry.Role.ASSISTANT, responseText, clock.now())) + Message.Assistant( + content = responseText, + metaInfo = ResponseMetaInfo.create(clock), + ) + } + + /** + * Sends a text message and streams the response. + * + * This method acquires a lock for the duration of streaming. + * Concurrent calls will wait until the stream completes. + * + * @param text The message text to send. + * @return A flow of response chunks. + */ + @MustUseReturnValue("The returned Flow must be collected to receive the response") + public fun sendStreaming(text: String): Flow = kotlinx.coroutines.flow.flow { + mutex.withLock { + _history.add(ConversationEntry(ConversationEntry.Role.USER, text, clock.now())) + val responseBuilder = StringBuilder() + conversation.sendMessageAsync(LiteRTMessage.of(text)) + .collect { message -> + val chunk = message.toString() + responseBuilder.append(chunk) + emit(chunk) + } + _history.add(ConversationEntry(ConversationEntry.Role.ASSISTANT, responseBuilder.toString(), clock.now())) + } + } + + // ==================== Image Messages ==================== + + /** + * Sends an image with optional text and returns the response. + * + * This method is thread-safe. 
Concurrent calls will be serialized. + * + * @param imageBytes The image data as bytes. + * @param text Optional text to accompany the image. + * @return The assistant's response. + */ + @MustUseReturnValue("The assistant's response should be processed") + public suspend fun sendImage(imageBytes: ByteArray, text: String? = null): Message.Assistant = + sendMultimodal( + historyDescription = text ?: "[Image: ${imageBytes.size} bytes]", + text = text, + mediaContent = Content.ImageBytes(imageBytes), + ) + + /** + * Sends an image from a file path with optional text and returns the response. + * + * This method is thread-safe. Concurrent calls will be serialized. + * + * @param imagePath The file path to the image. + * @param text Optional text to accompany the image. + * @return The assistant's response. + */ + @MustUseReturnValue("The assistant's response should be processed") + public suspend fun sendImageFile(imagePath: String, text: String? = null): Message.Assistant = + sendMultimodal( + historyDescription = text ?: "[Image: $imagePath]", + text = text, + mediaContent = Content.ImageFile(imagePath), + ) + + /** + * Sends an image with optional text and streams the response. + * + * @param imageBytes The image data as bytes. + * @param text Optional text to accompany the image. + * @return A flow of response chunks. + */ + @MustUseReturnValue("The returned Flow must be collected to receive the response") + public fun sendImageStreaming(imageBytes: ByteArray, text: String? = null): Flow = + sendMultimodalStreaming( + historyDescription = text ?: "[Image: ${imageBytes.size} bytes]", + text = text, + mediaContent = Content.ImageBytes(imageBytes), + ) + + // ==================== Audio Messages ==================== + + /** + * Sends audio with optional text and returns the response. + * + * This method is thread-safe. Concurrent calls will be serialized. + * + * @param audioBytes The audio data as bytes. + * @param text Optional text to accompany the audio. + * @return The assistant's response. + */ + @MustUseReturnValue("The assistant's response should be processed") + public suspend fun sendAudio(audioBytes: ByteArray, text: String? = null): Message.Assistant = + sendMultimodal( + historyDescription = text ?: "[Audio: ${audioBytes.size} bytes]", + text = text, + mediaContent = Content.AudioBytes(audioBytes), + ) + + /** + * Sends audio from a file path with optional text and returns the response. + * + * This method is thread-safe. Concurrent calls will be serialized. + * + * @param audioPath The file path to the audio. + * @param text Optional text to accompany the audio. + * @return The assistant's response. + */ + @MustUseReturnValue("The assistant's response should be processed") + public suspend fun sendAudioFile(audioPath: String, text: String? = null): Message.Assistant = + sendMultimodal( + historyDescription = text ?: "[Audio: $audioPath]", + text = text, + mediaContent = Content.AudioFile(audioPath), + ) + + /** + * Sends audio with optional text and streams the response. + * + * @param audioBytes The audio data as bytes. + * @param text Optional text to accompany the audio. + * @return A flow of response chunks. + */ + @MustUseReturnValue("The returned Flow must be collected to receive the response") + public fun sendAudioStreaming(audioBytes: ByteArray, text: String? 
= null): Flow = + sendMultimodalStreaming( + historyDescription = text ?: "[Audio: ${audioBytes.size} bytes]", + text = text, + mediaContent = Content.AudioBytes(audioBytes), + ) + + // ==================== Internal Helpers ==================== + + /** + * Internal helper for sending multimodal content (image/audio) synchronously. + * Reduces code duplication across sendImage, sendImageFile, sendAudio, sendAudioFile. + */ + private suspend fun sendMultimodal( + historyDescription: String, + text: String?, + mediaContent: Content, + ): Message.Assistant = mutex.withLock { + _history.add(ConversationEntry(ConversationEntry.Role.USER, historyDescription, clock.now())) + + val contents = buildList { + if (text != null) add(Content.Text(text)) + add(mediaContent) + } + val response = conversation.sendMessage(LiteRTMessage.of(*contents.toTypedArray())) + val responseText = response.toString() + _history.add(ConversationEntry(ConversationEntry.Role.ASSISTANT, responseText, clock.now())) + + Message.Assistant( + content = responseText, + metaInfo = ResponseMetaInfo.create(clock), + ) + } + + /** + * Internal helper for sending multimodal content with streaming response. + * Reduces code duplication across sendImageStreaming, sendAudioStreaming. + */ + private fun sendMultimodalStreaming( + historyDescription: String, + text: String?, + mediaContent: Content, + ): Flow = kotlinx.coroutines.flow.flow { + mutex.withLock { + _history.add(ConversationEntry(ConversationEntry.Role.USER, historyDescription, clock.now())) + + val contents = buildList { + if (text != null) add(Content.Text(text)) + add(mediaContent) + } + val responseBuilder = StringBuilder() + conversation.sendMessageAsync(LiteRTMessage.of(*contents.toTypedArray())) + .collect { message -> + val chunk = message.toString() + responseBuilder.append(chunk) + emit(chunk) + } + _history.add( + ConversationEntry(ConversationEntry.Role.ASSISTANT, responseBuilder.toString(), clock.now()) + ) + } + } + + // ==================== Control Methods ==================== + + /** + * Cancels any ongoing inference in this conversation. + * + * This can be called concurrently with [send] or [sendStreaming] to + * abort a long-running inference. + */ + public fun cancel() { + conversation.cancelProcess() + } + + /** + * Gets benchmark information for this conversation. + * + * @return Benchmark metrics including timing and token counts. + */ + @MustUseReturnValue("Benchmark info should be used for performance analysis") + @OptIn(ExperimentalApi::class) + public suspend fun getBenchmarkInfo(): BenchmarkInfo = mutex.withLock { + conversation.getBenchmarkInfo() + } + + /** + * Clears the local history tracking. + * + * Note: This does NOT clear the conversation context in the native engine. + * The model will still remember previous messages. To start fresh, create + * a new conversation. 
+ */ + public fun clearLocalHistory() { + _history.clear() + } + + override fun close() { + conversation.close() + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.jvm.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.jvm.kt new file mode 100644 index 0000000000..bc6989874b --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.jvm.kt @@ -0,0 +1,16 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.prompt.executor.clients.LLMClient +import kotlinx.coroutines.runBlocking + +/** + * Creates a LiteRT-LM client for JVM platform. + */ +public actual fun createLiteRTLMClient(config: LiteRTLMClientConfig): LLMClient { + return runBlocking { LiteRTLMClient.create(config) } +} + +/** + * LiteRT-LM is supported on JVM. + */ +public actual fun isLiteRTLMSupported(): Boolean = true diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMToolBridge.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMToolBridge.kt new file mode 100644 index 0000000000..34e6bd225a --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMToolBridge.kt @@ -0,0 +1,264 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.agents.core.tools.ToolParameterType +import com.google.ai.edge.litertlm.Tool +import com.google.ai.edge.litertlm.ToolParam +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.booleanOrNull +import kotlinx.serialization.json.doubleOrNull +import kotlinx.serialization.json.intOrNull +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.longOrNull + +/** + * Functional interface for executing tools. + * + * Implementations should handle tool execution synchronously or use appropriate + * blocking mechanisms for async operations. + */ +public fun interface ToolExecutor { + /** + * Executes a tool with the given name and arguments. + * + * @param toolName The name of the tool to execute. + * @param arguments The parsed arguments as a map. + * @return The result of the tool execution as a string. + */ + public fun execute(toolName: String, arguments: Map): String +} + +/** + * Bridges Koog's [ToolDescriptor] to LiteRT-LM's annotation-based tool system. + * + * This class acts as a dispatcher that exposes a single `@Tool` annotated method + * to LiteRT-LM. When the model calls this tool, it delegates to the appropriate + * Koog tool via the provided [ToolExecutor]. 
+ * + * ## Usage + * + * ```kotlin + * val tools = listOf(weatherTool.descriptor, calculatorTool.descriptor) + * val bridge = LiteRTLMToolBridge(tools) { name, args -> + * when (name) { + * "get_weather" -> getWeather(args["city"] as String) + * "calculate" -> calculate(args["expression"] as String) + * else -> "Unknown tool: $name" + * } + * } + * + * // Pass bridge to conversation + * val config = ConversationConfig(tools = listOf(bridge)) + * ``` + * + * @param tools The list of Koog tool descriptors to expose. + * @param executor The callback to execute when a tool is called. + */ +public class LiteRTLMToolBridge( + private val tools: List, + private val executor: ToolExecutor, +) { + private val json = Json { ignoreUnknownKeys = true } + + /** + * The tool method that LiteRT-LM will call. + * + * This is a dispatcher that routes to the appropriate Koog tool based on the + * tool name. The model will see this as the only available tool, but the + * system prompt includes descriptions of all available tools. + * + * @param toolName The name of the tool to call (must match a tool in [tools]). + * @param argsJson JSON string containing the arguments for the tool. + * @return The result of the tool execution, or an error message. + */ + @Tool(description = "Execute a tool. Use the tool name and provide arguments as JSON.") + public fun callTool( + @ToolParam(description = "The exact name of the tool to call") + toolName: String, + @ToolParam(description = "JSON object with the tool's arguments") + argsJson: String, + ): String { + val tool = tools.find { it.name == toolName } + ?: return buildErrorResponse( + "Unknown tool: '$toolName'. Available tools: ${tools.joinToString { it.name }}" + ) + + val arguments = try { + parseArguments(argsJson, tool) + } catch (e: Exception) { + return buildErrorResponse("Invalid arguments for '$toolName': ${e.message}") + } + + return try { + executor.execute(toolName, arguments) + } catch (e: Exception) { + buildErrorResponse("Tool execution failed: ${e.message}") + } + } + + /** + * Generates a system prompt section describing all available tools. + * + * Include this in the system prompt so the model knows what tools are available + * and how to call them. + * + * @return A formatted string describing all tools and their parameters. 
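`generateToolsPrompt()` is meant to be spliced into the system prompt so the model knows the dispatcher's vocabulary. A hedged sketch of that wiring; `dispatchTool` is a hypothetical application-side router, not part of this patch, and the descriptors reuse the names from the usage example above:

```kotlin
val bridge = LiteRTLMToolBridge(listOf(weatherTool, calculatorTool)) { name, args ->
    dispatchTool(name, args) // hypothetical: route to the real tool implementations
}

val systemPrompt = buildString {
    appendLine("You are a helpful assistant with access to tools.")
    appendLine()
    append(bridge.generateToolsPrompt()) // "## Available Tools" section generated below
}
```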
+ */ + public fun generateToolsPrompt(): String = buildString { + appendLine("## Available Tools") + appendLine() + appendLine("You can use the following tools by calling `callTool` with the tool name and arguments:") + appendLine() + + for (tool in tools) { + appendLine("### ${tool.name}") + appendLine(tool.description) + appendLine() + + val allParams = tool.requiredParameters + tool.optionalParameters + if (allParams.isNotEmpty()) { + appendLine("**Parameters:**") + for (param in tool.requiredParameters) { + appendLine("- `${param.name}` (${formatType(param.type)}, required): ${param.description}") + } + for (param in tool.optionalParameters) { + appendLine("- `${param.name}` (${formatType(param.type)}, optional): ${param.description}") + } + appendLine() + } + + // Example usage + appendLine("**Example:**") + appendLine("```json") + appendLine(generateExampleCall(tool)) + appendLine("```") + appendLine() + } + } + + private fun parseArguments(argsJson: String, tool: ToolDescriptor): Map { + if (argsJson.isBlank() || argsJson == "{}") { + return emptyMap() + } + + val jsonObject = json.parseToJsonElement(argsJson).jsonObject + val result = mutableMapOf() + + // Parse required parameters + for (param in tool.requiredParameters) { + val value = jsonObject[param.name] + ?: throw IllegalArgumentException("Missing required parameter: ${param.name}") + result[param.name] = convertJsonToKotlin(value, param.type) + } + + // Parse optional parameters + for (param in tool.optionalParameters) { + val value = jsonObject[param.name] + if (value != null && value !is JsonNull) { + result[param.name] = convertJsonToKotlin(value, param.type) + } + } + + return result + } + + private fun convertJsonToKotlin(element: JsonElement, type: ToolParameterType): Any? { + return when (type) { + is ToolParameterType.String -> element.jsonPrimitive.content + is ToolParameterType.Null -> null + is ToolParameterType.Integer -> + element.jsonPrimitive.intOrNull + ?: element.jsonPrimitive.longOrNull?.toInt() + ?: element.jsonPrimitive.content.toIntOrNull() + is ToolParameterType.Float -> + element.jsonPrimitive.doubleOrNull + ?: element.jsonPrimitive.content.toDoubleOrNull() + is ToolParameterType.Boolean -> + element.jsonPrimitive.booleanOrNull + ?: element.jsonPrimitive.content.toBooleanStrictOrNull() + is ToolParameterType.List -> { + element.jsonArray.map { item -> + convertJsonToKotlin(item, type.itemsType) + } + } + is ToolParameterType.Object -> { + val obj = element.jsonObject + type.properties.associate { prop -> + prop.name to obj[prop.name]?.let { convertJsonToKotlin(it, prop.type) } + } + } + is ToolParameterType.Enum -> element.jsonPrimitive.content + is ToolParameterType.AnyOf -> { + for (descriptor in type.types) { + val value = runCatching { convertJsonToKotlin(element, descriptor.type) }.getOrNull() + if (value != null) { + return value + } + } + element.jsonPrimitive.content + } + } + } + + private fun formatType(type: ToolParameterType): String = when (type) { + is ToolParameterType.String -> "string" + is ToolParameterType.Null -> "null" + is ToolParameterType.Integer -> "integer" + is ToolParameterType.Float -> "float" + is ToolParameterType.Boolean -> "boolean" + is ToolParameterType.List -> "array<${formatType(type.itemsType)}>" + is ToolParameterType.Object -> "object" + is ToolParameterType.Enum -> "enum(${type.entries.joinToString("|")})" + is ToolParameterType.AnyOf -> "anyOf(${type.types.joinToString { formatType(it.type) }})" + } + + private fun generateExampleCall(tool: ToolDescriptor): 
String { + val exampleArgs = buildMap { + for (param in tool.requiredParameters) { + put(param.name, generateExampleValue(param.type)) + } + } + return """{"toolName": "${tool.name}", "argsJson": "${json.encodeToString(JsonObject.serializer(), JsonObject(exampleArgs.mapValues { jsonValueOf(it.value) }))}"}""" + } + + private fun generateExampleValue(type: ToolParameterType): Any = when (type) { + is ToolParameterType.String -> "example" + is ToolParameterType.Null -> "null" + is ToolParameterType.Integer -> 42 + is ToolParameterType.Float -> 3.14 + is ToolParameterType.Boolean -> true + is ToolParameterType.List -> listOf(generateExampleValue(type.itemsType)) + is ToolParameterType.Object -> mapOf("key" to "value") + is ToolParameterType.Enum -> type.entries.firstOrNull() ?: "value" + is ToolParameterType.AnyOf -> generateExampleValue(type.types.first().type) + } + + private fun jsonValueOf(value: Any?): JsonElement = when (value) { + null -> JsonNull + is String -> JsonPrimitive(value) + is Number -> JsonPrimitive(value) + is Boolean -> JsonPrimitive(value) + is List<*> -> JsonArray( + value.map { item -> + jsonValueOf(item) + } + ) + is Map<*, *> -> JsonObject( + value.entries.associate { (k, v) -> + k.toString() to jsonValueOf(v) + } + ) + else -> JsonPrimitive(value.toString()) + } + + private fun buildErrorResponse(message: String): String { + return """{"error": "$message"}""" + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMChatCompletionTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMChatCompletionTest.kt new file mode 100644 index 0000000000..c4dbdca601 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMChatCompletionTest.kt @@ -0,0 +1,305 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.agents.core.tools.ToolParameterDescriptor +import ai.koog.agents.core.tools.ToolParameterType +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.llm.LiteRTLMModels +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import java.io.File + +/** + * End-to-end chat completion test using LiteRT-LM provider and executor. + * + * This test demonstrates the complete flow: + * 1. Create a LiteRT-LM client with tool support + * 2. Send a prompt with available tools + * 3. Model calls tools via the bridge + * 4. Tool executor handles the calls + * 5. Responses are streamed back + * + * ## Running this test + * + * ```bash + * LITERTLM_MODEL_PATH=/path/to/gemma-3n-e4b.litertlm \ + * ./gradlew :prompt:prompt-executor:prompt-executor-clients:prompt-executor-litertlm-client:jvmTest \ + * --tests "*.LiteRTLMChatCompletionTest" + * ``` + */ +class LiteRTLMChatCompletionTest { + + private val modelPath: String by lazy { + System.getenv("LITERTLM_MODEL_PATH") + ?: throw IllegalStateException("LITERTLM_MODEL_PATH environment variable not set") + } + + private var client: LiteRTLMClient? 
= null + + // Mock tool implementations + private val weatherData = mapOf( + "paris" to """{"city": "Paris", "temperature": 22, "condition": "sunny"}""", + "london" to """{"city": "London", "temperature": 15, "condition": "cloudy"}""", + "tokyo" to """{"city": "Tokyo", "temperature": 28, "condition": "humid"}""" + ) + + private val calculatorResults = mutableListOf() + + // Tool descriptors + private val weatherTool = ToolDescriptor( + name = "get_weather", + description = "Get the current weather for a city. Returns temperature and conditions.", + requiredParameters = listOf( + ToolParameterDescriptor( + name = "city", + description = "The city name (e.g., Paris, London, Tokyo)", + type = ToolParameterType.String + ) + ) + ) + + private val calculatorTool = ToolDescriptor( + name = "calculate", + description = "Evaluate a mathematical expression and return the result.", + requiredParameters = listOf( + ToolParameterDescriptor( + name = "expression", + description = "The math expression to evaluate (e.g., '2 + 2', '10 * 5')", + type = ToolParameterType.String + ) + ) + ) + + // Tool executor that handles tool calls + private val toolExecutor = ToolExecutor { toolName, args -> + when (toolName) { + "get_weather" -> { + val city = (args["city"] as? String)?.lowercase() ?: "unknown" + weatherData[city] ?: """{"error": "City not found: $city"}""" + } + "calculate" -> { + val expression = args["expression"] as? String ?: "" + try { + // Simple evaluation for basic operations + val result = evaluateSimpleExpression(expression) + calculatorResults.add("$expression = $result") + """{"expression": "$expression", "result": $result}""" + } catch (e: Exception) { + """{"error": "Cannot evaluate: $expression"}""" + } + } + else -> """{"error": "Unknown tool: $toolName"}""" + } + } + + @BeforeEach + fun setup() { + require(File(modelPath).exists()) { "Model file not found: $modelPath" } + LiteRTLMClient.setLogSeverity(LiteRTLMLogSeverity.ERROR) + calculatorResults.clear() + } + + @AfterEach + fun teardown() { + client?.close() + client = null + } + + @Test + fun `chat completion without tools`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + val chatPrompt = prompt("chat-no-tools") { + system("You are a helpful assistant. 
Be concise.") + user("What is the capital of France?") + } + + val responses = client!!.execute(chatPrompt, LiteRTLMModels.Google.GEMMA_3N_E4B) + + assert(responses.isNotEmpty()) { "Expected a response" } + val answer = responses.first().content + println("Answer: $answer") + assert(answer.contains("Paris", ignoreCase = true)) { + "Expected answer to mention Paris" + } + } + + @Test + fun `chat completion with streaming`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + val chatPrompt = prompt("chat-streaming") { + system("You are a helpful assistant.") + user("Count from 1 to 5, one number per line.") + } + + val chunks = mutableListOf() + val frames = client!!.executeStreaming(chatPrompt, LiteRTLMModels.Google.GEMMA_3N_E4B) + .toList() + + frames.forEach { frame -> + // Collect streaming content + println(frame) + } + + assert(frames.isNotEmpty()) { "Expected streaming frames" } + } + + @Test + fun `chat completion with tool calling - weather`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config, toolExecutor) + + val chatPrompt = prompt("chat-tools-weather") { + system( + """ + You are a helpful assistant with access to tools. + When asked about weather, use the get_weather tool. + Respond naturally after getting the tool result. + """.trimIndent() + ) + user("What's the weather like in Paris right now?") + } + + val responses = client!!.execute( + chatPrompt, + LiteRTLMModels.Google.GEMMA_3N_E4B, + tools = listOf(weatherTool) + ) + + assert(responses.isNotEmpty()) { "Expected a response" } + val answer = responses.first().content + println("Weather response: $answer") + + // The model should either call the tool and respond, or mention weather + assert(answer.isNotBlank()) { "Expected non-empty response" } + } + + @Test + fun `chat completion with tool calling - calculator`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config, toolExecutor) + + val chatPrompt = prompt("chat-tools-calculator") { + system( + """ + You are a helpful assistant with access to a calculator tool. + When asked to do math, use the calculate tool. + """.trimIndent() + ) + user("What is 42 multiplied by 7?") + } + + val responses = client!!.execute( + chatPrompt, + LiteRTLMModels.Google.GEMMA_3N_E4B, + tools = listOf(calculatorTool) + ) + + assert(responses.isNotEmpty()) { "Expected a response" } + val answer = responses.first().content + println("Calculator response: $answer") + println("Calculator was called with: $calculatorResults") + } + + @Test + fun `chat completion with multiple tools available`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config, toolExecutor) + + val chatPrompt = prompt("chat-tools-multiple") { + system( + """ + You are a helpful assistant with access to tools: + - get_weather: Get current weather for a city + - calculate: Do math calculations + + Use the appropriate tool when needed. 
+ """.trimIndent() + ) + user("What's the temperature in London, and what's 15 + 7?") + } + + val responses = client!!.execute( + chatPrompt, + LiteRTLMModels.Google.GEMMA_3N_E4B, + tools = listOf(weatherTool, calculatorTool) + ) + + assert(responses.isNotEmpty()) { "Expected a response" } + println("Multi-tool response: ${responses.first().content}") + } + + @Test + fun `multi-turn conversation with tools`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config, toolExecutor) + + client!!.conversation( + systemPrompt = "You are a helpful weather assistant. Use tools when asked about weather.", + tools = listOf(weatherTool, calculatorTool) + ).use { conv -> + // First turn + val r1 = conv.send("Hi! Can you check the weather in Tokyo?") + println("Turn 1: ${r1.content}") + + // Second turn - follow-up + val r2 = conv.send("Is it hotter there than in London?") + println("Turn 2: ${r2.content}") + + // Third turn - unrelated + val r3 = conv.send("Thanks! What's your favorite color?") + println("Turn 3: ${r3.content}") + + assert(r1.content.isNotBlank()) + assert(r2.content.isNotBlank()) + assert(r3.content.isNotBlank()) + } + } + + @Test + fun `streaming chat with managed conversation`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + client!!.conversation(systemPrompt = "You are a storyteller.").use { conv -> + print("Story: ") + conv.sendStreaming("Tell me a very short story about a robot.") + .collect { chunk -> + print(chunk) + } + println() + } + } + + // Simple expression evaluator for basic math + private fun evaluateSimpleExpression(expr: String): Double { + val cleaned = expr.replace(" ", "") + + // Handle basic operations: +, -, *, / + return when { + cleaned.contains("+") -> { + val parts = cleaned.split("+") + parts.sumOf { it.toDouble() } + } + cleaned.contains("*") -> { + val parts = cleaned.split("*") + parts.fold(1.0) { acc, s -> acc * s.toDouble() } + } + cleaned.contains("-") && !cleaned.startsWith("-") -> { + val parts = cleaned.split("-") + parts.drop(1).fold(parts.first().toDouble()) { acc, s -> acc - s.toDouble() } + } + cleaned.contains("/") -> { + val parts = cleaned.split("/") + parts.drop(1).fold(parts.first().toDouble()) { acc, s -> acc / s.toDouble() } + } + else -> cleaned.toDouble() + } + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientIntegrationTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientIntegrationTest.kt new file mode 100644 index 0000000000..4dd2fddaef --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientIntegrationTest.kt @@ -0,0 +1,226 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.llm.LiteRTLMModels +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import java.io.File + +/** + * Integration tests for LiteRTLMClient. + * + * These tests require: + * 1. The LiteRT-LM library to be available on the classpath + * 2. 
A valid model file (e.g., gemma-3n-e4b.litertlm) + * + * ## Running these tests + * + * 1. Download a compatible model from https://github.com/google-ai-edge/LiteRT-LM + * 2. Set the `LITERTLM_MODEL_PATH` environment variable to the model file location + * 3. Remove the `@Disabled` annotation from the test class + * + * Example: + * ```bash + * LITERTLM_MODEL_PATH=/path/to/gemma-3n-e4b.litertlm \ + * ./gradlew :prompt:prompt-executor:prompt-executor-clients:prompt-executor-litertlm-client:jvmTest \ + * --tests "*.LiteRTLMClientIntegrationTest" + * ``` + */ +class LiteRTLMClientIntegrationTest { + + private val modelPath: String by lazy { + System.getenv("LITERTLM_MODEL_PATH") + ?: throw IllegalStateException( + "LITERTLM_MODEL_PATH environment variable not set. " + + "Please set it to the path of your .litertlm model file." + ) + } + + private var client: LiteRTLMClient? = null + + @BeforeEach + fun setup() { + require(File(modelPath).exists()) { "Model file not found: $modelPath" } + + // Silence native logging for cleaner test output + LiteRTLMClient.setLogSeverity(LiteRTLMLogSeverity.ERROR) + } + + @AfterEach + fun teardown() { + client?.close() + client = null + } + + @Test + fun `can initialize client with factory method`() = runTest { + val config = LiteRTLMClientConfig( + modelPath = modelPath, + backend = LiteRTLMBackend.CPU + ) + + client = LiteRTLMClient.create(config) + + // Client should be ready to use immediately after create() + assert(!client!!.isClosed()) { "Client should not be closed after creation" } + } + + @Test + fun `can execute simple prompt`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + val testPrompt = prompt("simple-prompt") { + system("You are a helpful assistant. Keep responses brief.") + user("What is 2 + 2?") + } + + val responses = client!!.execute(testPrompt, LiteRTLMModels.Google.GEMMA_3N_E4B) + + assert(responses.isNotEmpty()) { "Expected at least one response" } + println("Response: ${responses.first().content}") + } + + @Test + fun `can stream responses`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + val testPrompt = prompt("streaming-prompt") { + system("You are a helpful assistant.") + user("Count from 1 to 5.") + } + + val frames = client!!.executeStreaming(testPrompt, LiteRTLMModels.Google.GEMMA_3N_E4B) + .toList() + + assert(frames.isNotEmpty()) { "Expected streaming frames" } + println("Received ${frames.size} streaming frames") + } + + @Test + fun `can use managed conversation for multi-turn chat`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + client!!.conversation(systemPrompt = "You are a helpful assistant.").use { conv -> + // First turn + val greeting = conv.send("Hello! 
My name is Alice.") + println("Assistant: ${greeting.content}") + + // Second turn - should remember context + val followUp = conv.send("What is my name?") + println("Assistant: ${followUp.content}") + + // The response should mention "Alice" since we're in a conversation + assert(followUp.content.contains("Alice", ignoreCase = true)) { + "Expected response to remember the name 'Alice' from conversation context" + } + } + } + + @Test + fun `can stream responses in managed conversation`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + client!!.conversation(systemPrompt = "You are a helpful assistant.").use { conv -> + val chunks = mutableListOf() + + conv.sendStreaming("Tell me a very short joke.") + .collect { chunk -> + chunks.add(chunk) + print(chunk) // Stream to console + } + println() // newline after streaming + + assert(chunks.isNotEmpty()) { "Expected streaming chunks" } + } + } + + @Test + fun `can use GPU backend`() = runTest { + val config = LiteRTLMClientConfig( + modelPath = modelPath, + backend = LiteRTLMBackend.GPU + ) + + client = LiteRTLMClient.create(config) + + val testPrompt = prompt("gpu-backend-prompt") { + user("Hello!") + } + + val responses = client!!.execute(testPrompt, LiteRTLMModels.Google.GEMMA_3N_E4B) + assert(responses.isNotEmpty()) + } + + @Test + fun `can configure custom sampler settings`() = runTest { + val config = LiteRTLMClientConfig( + engineConfig = LiteRTLMEngineConfig( + modelPath = modelPath, + backend = LiteRTLMBackend.CPU, + ), + samplerConfig = LiteRTLMSamplerConfig( + topK = 20, // More focused sampling + topP = 0.8, // Narrower nucleus + temperature = 0.5, // More deterministic + seed = 42, // Reproducible results + ) + ) + + client = LiteRTLMClient.create(config) + + val testPrompt = prompt("sampler-settings-prompt") { + user("What color is the sky?") + } + + val responses = client!!.execute(testPrompt, LiteRTLMModels.Google.GEMMA_3N_E4B) + assert(responses.isNotEmpty()) + println("Response with custom sampler: ${responses.first().content}") + } + + @Test + fun `client can be closed and reports closed state`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + assert(!client!!.isClosed()) { "Client should not be closed initially" } + + client!!.close() + + assert(client!!.isClosed()) { "Client should be closed after close()" } + } + + @Test + fun `conversation can be cancelled`() = runTest { + val config = LiteRTLMClientConfig(modelPath = modelPath) + client = LiteRTLMClient.create(config) + + client!!.conversation().use { conv -> + // Start a potentially long generation + val job = launch { + try { + conv.sendStreaming("Write a 1000 word essay about the history of computing.") + .collect { /* consume */ } + } catch (e: Exception) { + // Expected if cancelled + } + } + + // Give it a moment to start + kotlinx.coroutines.delay(100) + + // Cancel the inference + conv.cancel() + + // Wait for the job to complete (should be quick after cancel) + job.join() + } + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientTest.kt new file mode 100644 index 0000000000..1cb29b14ee --- /dev/null +++ 
b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientTest.kt @@ -0,0 +1,139 @@ +package ai.koog.prompt.executor.litertlm.client + +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test + +/** + * Unit tests for LiteRTLMClient configuration and validation. + * + * Note: Integration tests require the LiteRT-LM library and a valid model file. + * See [LiteRTLMClientIntegrationTest] for tests that use an actual model. + */ +class LiteRTLMClientTest { + + @Test + fun `LiteRTLMBackend enum has CPU, GPU, and NPU values`() { + LiteRTLMBackend.CPU shouldBe LiteRTLMBackend.CPU + LiteRTLMBackend.GPU shouldBe LiteRTLMBackend.GPU + LiteRTLMBackend.NPU shouldBe LiteRTLMBackend.NPU + } + + @Test + fun `LiteRTLMEngineConfig validates maxNumTokens`() { + // Null is valid (use default) + LiteRTLMEngineConfig(modelPath = "/path/to/model", maxNumTokens = null) + + // Positive values are valid + LiteRTLMEngineConfig(modelPath = "/path/to/model", maxNumTokens = 1024) + LiteRTLMEngineConfig(modelPath = "/path/to/model", maxNumTokens = 1) + + // Zero and negative should throw + shouldThrow { + LiteRTLMEngineConfig(modelPath = "/path/to/model", maxNumTokens = 0) + } + shouldThrow { + LiteRTLMEngineConfig(modelPath = "/path/to/model", maxNumTokens = -1) + } + } + + @Test + fun `LiteRTLMSamplerConfig validates parameters`() { + // Valid configuration + LiteRTLMSamplerConfig(topK = 40, topP = 0.95, temperature = 0.8) + + // topK must be positive + shouldThrow { + LiteRTLMSamplerConfig(topK = 0) + } + shouldThrow { + LiteRTLMSamplerConfig(topK = -1) + } + + // topP must be between 0 and 1 + shouldThrow { + LiteRTLMSamplerConfig(topP = -0.1) + } + shouldThrow { + LiteRTLMSamplerConfig(topP = 1.1) + } + + // temperature must be non-negative + shouldThrow { + LiteRTLMSamplerConfig(temperature = -0.1) + } + + // Edge cases that should be valid + LiteRTLMSamplerConfig(topP = 0.0) + LiteRTLMSamplerConfig(topP = 1.0) + LiteRTLMSamplerConfig(temperature = 0.0) + } + + @Test + fun `LiteRTLMClientConfig convenience constructor works`() { + val config = LiteRTLMClientConfig( + modelPath = "/path/to/model.litertlm", + backend = LiteRTLMBackend.GPU, + cacheDir = "/tmp/cache" + ) + + config.engineConfig.modelPath shouldBe "/path/to/model.litertlm" + config.engineConfig.backend shouldBe LiteRTLMBackend.GPU + config.engineConfig.cacheDir shouldBe "/tmp/cache" + config.samplerConfig shouldBe null + } + + @Test + fun `LiteRTLMClientConfig full constructor works`() { + val engineConfig = LiteRTLMEngineConfig( + modelPath = "/path/to/model.litertlm", + backend = LiteRTLMBackend.NPU, + visionBackend = LiteRTLMBackend.GPU, + audioBackend = LiteRTLMBackend.CPU, + maxNumTokens = 8192, + cacheDir = "/tmp/cache" + ) + + val samplerConfig = LiteRTLMSamplerConfig( + topK = 50, + topP = 0.9, + temperature = 1.0, + seed = 42 + ) + + val config = LiteRTLMClientConfig( + engineConfig = engineConfig, + samplerConfig = samplerConfig + ) + + config.engineConfig.backend shouldBe LiteRTLMBackend.NPU + config.engineConfig.visionBackend shouldBe LiteRTLMBackend.GPU + config.engineConfig.audioBackend shouldBe LiteRTLMBackend.CPU + config.engineConfig.maxNumTokens shouldBe 8192 + config.samplerConfig?.topK shouldBe 50 + config.samplerConfig?.seed shouldBe 42 + } + + @Test + fun `LiteRTLMLogSeverity has all expected levels`() { + LiteRTLMLogSeverity.VERBOSE + LiteRTLMLogSeverity.DEBUG + 
LiteRTLMLogSeverity.INFO + LiteRTLMLogSeverity.WARNING + LiteRTLMLogSeverity.ERROR + LiteRTLMLogSeverity.FATAL + LiteRTLMLogSeverity.SILENT + } + + @Test + fun `isLiteRTLMSupported returns true on JVM`() { + isLiteRTLMSupported() shouldBe true + } + + @Test + fun `LiteRTLMSamplerConfig default values match official API`() { + LiteRTLMSamplerConfig.DEFAULT_TOP_K shouldBe 40 + LiteRTLMSamplerConfig.DEFAULT_TOP_P shouldBe 0.95 + LiteRTLMSamplerConfig.DEFAULT_TEMPERATURE shouldBe 0.8 + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMToolBridgeTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMToolBridgeTest.kt new file mode 100644 index 0000000000..8524f53ee2 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/jvmTest/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMToolBridgeTest.kt @@ -0,0 +1,232 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.agents.core.tools.ToolParameterDescriptor +import ai.koog.agents.core.tools.ToolParameterType +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test + +class LiteRTLMToolBridgeTest { + + private val weatherTool = ToolDescriptor( + name = "get_weather", + description = "Get the current weather for a location", + requiredParameters = listOf( + ToolParameterDescriptor( + name = "city", + description = "The city name", + type = ToolParameterType.String + ) + ), + optionalParameters = listOf( + ToolParameterDescriptor( + name = "units", + description = "Temperature units (celsius or fahrenheit)", + type = ToolParameterType.String + ) + ) + ) + + private val calculatorTool = ToolDescriptor( + name = "calculate", + description = "Evaluate a mathematical expression", + requiredParameters = listOf( + ToolParameterDescriptor( + name = "expression", + description = "The math expression to evaluate", + type = ToolParameterType.String + ) + ) + ) + + @Test + fun `callTool dispatches to executor with correct arguments`() { + var capturedName: String? = null + var capturedArgs: Map? = null + + val bridge = LiteRTLMToolBridge(listOf(weatherTool)) { name, args -> + capturedName = name + capturedArgs = args + """{"temperature": 25, "unit": "celsius"}""" + } + + val result = bridge.callTool("get_weather", """{"city": "Paris"}""") + + capturedName shouldBe "get_weather" + capturedArgs shouldBe mapOf("city" to "Paris") + result shouldContain "temperature" + } + + @Test + fun `callTool handles optional parameters`() { + var capturedArgs: Map? 
= null + + val bridge = LiteRTLMToolBridge(listOf(weatherTool)) { _, args -> + capturedArgs = args + "OK" + } + + bridge.callTool("get_weather", """{"city": "London", "units": "fahrenheit"}""") + + capturedArgs shouldBe mapOf("city" to "London", "units" to "fahrenheit") + } + + @Test + fun `callTool returns error for unknown tool`() { + val bridge = LiteRTLMToolBridge(listOf(weatherTool)) { _, _ -> "OK" } + + val result = bridge.callTool("unknown_tool", "{}") + + result shouldContain "error" + result shouldContain "Unknown tool" + result shouldContain "get_weather" + } + + @Test + fun `callTool returns error for invalid JSON`() { + val bridge = LiteRTLMToolBridge(listOf(weatherTool)) { _, _ -> "OK" } + + val result = bridge.callTool("get_weather", "not valid json") + + result shouldContain "error" + result shouldContain "Invalid arguments" + } + + @Test + fun `callTool returns error for missing required parameter`() { + val bridge = LiteRTLMToolBridge(listOf(weatherTool)) { _, _ -> "OK" } + + val result = bridge.callTool("get_weather", """{"units": "celsius"}""") + + result shouldContain "error" + result shouldContain "city" + } + + @Test + fun `callTool handles executor exception`() { + val bridge = LiteRTLMToolBridge(listOf(weatherTool)) { _, _ -> + throw RuntimeException("Network error") + } + + val result = bridge.callTool("get_weather", """{"city": "Paris"}""") + + result shouldContain "error" + result shouldContain "execution failed" + } + + @Test + fun `generateToolsPrompt includes all tools`() { + val bridge = LiteRTLMToolBridge(listOf(weatherTool, calculatorTool)) { _, _ -> "OK" } + + val prompt = bridge.generateToolsPrompt() + + prompt shouldContain "get_weather" + prompt shouldContain "Get the current weather" + prompt shouldContain "city" + prompt shouldContain "calculate" + prompt shouldContain "mathematical expression" + } + + @Test + fun `generateToolsPrompt marks required vs optional parameters`() { + val bridge = LiteRTLMToolBridge(listOf(weatherTool)) { _, _ -> "OK" } + + val prompt = bridge.generateToolsPrompt() + + prompt shouldContain "city" + prompt shouldContain "required" + prompt shouldContain "units" + prompt shouldContain "optional" + } + + @Test + fun `callTool handles integer parameters`() { + val toolWithInt = ToolDescriptor( + name = "repeat", + description = "Repeat text", + requiredParameters = listOf( + ToolParameterDescriptor("text", "Text to repeat", ToolParameterType.String), + ToolParameterDescriptor("count", "Number of times", ToolParameterType.Integer) + ) + ) + + var capturedArgs: Map? = null + val bridge = LiteRTLMToolBridge(listOf(toolWithInt)) { _, args -> + capturedArgs = args + "OK" + } + + bridge.callTool("repeat", """{"text": "hello", "count": 3}""") + + capturedArgs?.get("count") shouldBe 3 + } + + @Test + fun `callTool handles boolean parameters`() { + val toolWithBool = ToolDescriptor( + name = "format", + description = "Format text", + requiredParameters = listOf( + ToolParameterDescriptor("text", "Text to format", ToolParameterType.String), + ToolParameterDescriptor("uppercase", "Convert to uppercase", ToolParameterType.Boolean) + ) + ) + + var capturedArgs: Map? 
= null + val bridge = LiteRTLMToolBridge(listOf(toolWithBool)) { _, args -> + capturedArgs = args + "OK" + } + + bridge.callTool("format", """{"text": "hello", "uppercase": true}""") + + capturedArgs?.get("uppercase") shouldBe true + } + + @Test + fun `callTool handles array parameters`() { + val toolWithArray = ToolDescriptor( + name = "sum", + description = "Sum numbers", + requiredParameters = listOf( + ToolParameterDescriptor( + "numbers", + "Numbers to sum", + ToolParameterType.List(ToolParameterType.Integer) + ) + ) + ) + + var capturedArgs: Map? = null + val bridge = LiteRTLMToolBridge(listOf(toolWithArray)) { _, args -> + capturedArgs = args + "OK" + } + + bridge.callTool("sum", """{"numbers": [1, 2, 3, 4, 5]}""") + + @Suppress("UNCHECKED_CAST") + val numbers = capturedArgs?.get("numbers") as? List + numbers shouldBe listOf(1, 2, 3, 4, 5) + } + + @Test + fun `callTool handles empty arguments`() { + val toolNoArgs = ToolDescriptor( + name = "get_time", + description = "Get current time" + ) + + var called = false + val bridge = LiteRTLMToolBridge(listOf(toolNoArgs)) { _, args -> + called = true + args shouldBe emptyMap() + "12:00" + } + + bridge.callTool("get_time", "{}") + + called shouldBe true + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/wasmJsMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.wasmJs.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/wasmJsMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.wasmJs.kt new file mode 100644 index 0000000000..4bf9e0fce0 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-litertlm-client/src/wasmJsMain/kotlin/ai/koog/prompt/executor/litertlm/client/LiteRTLMClientFactory.wasmJs.kt @@ -0,0 +1,18 @@ +package ai.koog.prompt.executor.litertlm.client + +import ai.koog.prompt.executor.clients.LLMClient + +/** + * LiteRT-LM is not available on WasmJS platform. + */ +public actual fun createLiteRTLMClient(config: LiteRTLMClientConfig): LLMClient { + throw UnsupportedOperationException( + "LiteRT-LM is not supported on WasmJS. " + + "LiteRT-LM only supports JVM and Android." + ) +} + +/** + * LiteRT-LM is not supported on WasmJS. 
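Because the factory is `expect`/`actual`, common code can probe for support instead of catching the `UnsupportedOperationException` thrown by the WasmJS actual above. A minimal sketch, assuming the application already has some remote `LLMClient` to fall back to:

```kotlin
fun selectClient(config: LiteRTLMClientConfig, remoteFallback: LLMClient): LLMClient =
    if (isLiteRTLMSupported()) {
        createLiteRTLMClient(config) // on-device inference on JVM/Android
    } else {
        remoteFallback // e.g. on WasmJS, where createLiteRTLMClient would throw
    }
```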
+ */ +public actual fun isLiteRTLMSupported(): Boolean = false diff --git a/prompt/prompt-executor/prompt-executor-model/build.gradle.kts b/prompt/prompt-executor/prompt-executor-model/build.gradle.kts index b8d953f1e4..ae7adb3a90 100644 --- a/prompt/prompt-executor/prompt-executor-model/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-model/build.gradle.kts @@ -8,6 +8,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { diff --git a/prompt/prompt-llm/build.gradle.kts b/prompt/prompt-llm/build.gradle.kts index 7885230751..1dac67c128 100644 --- a/prompt/prompt-llm/build.gradle.kts +++ b/prompt/prompt-llm/build.gradle.kts @@ -9,6 +9,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { diff --git a/prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LLMProvider.kt b/prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LLMProvider.kt index 5b3c657596..ccdd7167e0 100644 --- a/prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LLMProvider.kt +++ b/prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LLMProvider.kt @@ -137,4 +137,18 @@ public abstract class LLMProvider(public val id: String, public val display: Str */ @Serializable public data object MistralAI : LLMProvider("mistralai", "MistralAI") + + /** + * Represents the LiteRT-LM provider within the available set of large language model providers. + * + * LiteRT-LM is Google's on-device inference engine for running LLMs locally on Android and JVM platforms. + * It is identified by its unique ID ("litertlm") and display name ("LiteRT-LM"). + * It extends the `LLMProvider` sealed class, which serves as a base class for all supported language model providers. + * + * This data object adheres to the structure and serialization requirements defined by the parent class. + * It is part of the available LLM provider hierarchy, which is used to configure and identify specific + * providers for large language model functionalities and capabilities. + */ + @Serializable + public data object LiteRTLM : LLMProvider("litertlm", "LiteRT-LM") } diff --git a/prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LiteRTLMModels.kt b/prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LiteRTLMModels.kt new file mode 100644 index 0000000000..4f111717d5 --- /dev/null +++ b/prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LiteRTLMModels.kt @@ -0,0 +1,45 @@ +package ai.koog.prompt.llm + +/** + * Represents a collection of predefined Large Language Models (LLM) available for LiteRT-LM. + * LiteRT-LM is Google's on-device inference engine for running LLMs locally. + * + * Each model is configured with specific capabilities and context lengths suitable for on-device inference. + */ +public object LiteRTLMModels { + /** + * The Google object represents the configuration for Google's large language models (LLMs) + * that can be run on-device via LiteRT-LM. + * It contains the predefined model specifications for Google's LLMs, including their identifiers + * and supported capabilities. + */ + public object Google { + /** + * Represents the Gemma 3n E4B model optimized for on-device inference. + * + * Gemma 3n is a lightweight, multimodal model designed for efficient on-device execution. + * The E4B variant uses 4-bit quantization for reduced memory footprint while maintaining + * quality suitable for mobile and edge deployment. 
+ * + * This model supports: + * - Temperature adjustment for controlling response randomness + * - Vision capabilities for image understanding + * - Audio capabilities for audio processing + * + * Note: Tool support is available through LiteRT-LM's function calling mechanism. + * + * @see LiteRT-LM GitHub + */ + public val GEMMA_3N_E4B: LLModel = LLModel( + provider = LLMProvider.LiteRTLM, + id = "gemma-3n-e4b", + capabilities = listOf( + LLMCapability.Temperature, + LLMCapability.Vision.Image, + LLMCapability.Audio, + LLMCapability.Tools + ), + contextLength = 32_768, + ) + } +} diff --git a/prompt/prompt-markdown/build.gradle.kts b/prompt/prompt-markdown/build.gradle.kts index d34156a4ac..a50d06086c 100644 --- a/prompt/prompt-markdown/build.gradle.kts +++ b/prompt/prompt-markdown/build.gradle.kts @@ -8,6 +8,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { diff --git a/prompt/prompt-model/build.gradle.kts b/prompt/prompt-model/build.gradle.kts index ca7c6d5ff8..db0995362e 100644 --- a/prompt/prompt-model/build.gradle.kts +++ b/prompt/prompt-model/build.gradle.kts @@ -9,6 +9,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { @@ -23,7 +24,6 @@ kotlin { commonTest { dependencies { - implementation(project(":test-utils")) api(project(":prompt:prompt-markdown")) } } diff --git a/prompt/prompt-xml/build.gradle.kts b/prompt/prompt-xml/build.gradle.kts index d34156a4ac..a50d06086c 100644 --- a/prompt/prompt-xml/build.gradle.kts +++ b/prompt/prompt-xml/build.gradle.kts @@ -8,6 +8,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { diff --git a/rag/rag-base/build.gradle.kts b/rag/rag-base/build.gradle.kts index d77a5afa74..2588340efa 100644 --- a/rag/rag-base/build.gradle.kts +++ b/rag/rag-base/build.gradle.kts @@ -9,6 +9,7 @@ plugins { } kotlin { + jvm() sourceSets { commonMain { dependencies { diff --git a/settings.gradle.kts b/settings.gradle.kts index 921c704b18..f64e245adc 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -12,7 +12,6 @@ pluginManagement { include(":agents:agents-core") include(":agents:agents-ext") include(":agents:agents-planner") - include(":agents:agents-features:agents-features-acp") include(":agents:agents-features:agents-features-event-handler") include(":agents:agents-features:agents-features-memory") @@ -21,75 +20,17 @@ include(":agents:agents-features:agents-features-sql") include(":agents:agents-features:agents-features-trace") include(":agents:agents-features:agents-features-tokenizer") include(":agents:agents-features:agents-features-snapshot") -include(":agents:agents-features:agents-features-a2a-core") -include(":agents:agents-features:agents-features-a2a-server") -include(":agents:agents-features:agents-features-a2a-client") - include(":agents:agents-mcp") -include(":agents:agents-mcp-server") include(":agents:agents-test") include(":agents:agents-tools") include(":agents:agents-utils") - -include(":integration-tests") - -include(":koog-agents") - -include(":prompt:prompt-cache:prompt-cache-files") -include(":prompt:prompt-cache:prompt-cache-model") -include(":prompt:prompt-cache:prompt-cache-redis") - -include(":prompt:prompt-executor:prompt-executor-cached") - +include(":test-utils") include(":prompt:prompt-executor:prompt-executor-clients") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-anthropic-client") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-bedrock-client") 
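Once `settings.gradle.kts` includes the new client module (as it does just below), other modules can depend on it in the usual way. A sketch of a consuming build script; placing the dependency in `jvmMain` is an assumption, reflecting that the client only ships JVM and Android implementations:

```kotlin
kotlin {
    sourceSets {
        jvmMain {
            dependencies {
                implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-litertlm-client"))
            }
        }
    }
}
```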
-include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-deepseek-client") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-google-client") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-mistralai-client") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-ollama-client") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client-base") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openrouter-client") -include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-dashscope-client") - -include(":prompt:prompt-executor:prompt-executor-llms") -include(":prompt:prompt-executor:prompt-executor-llms-all") +include(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-litertlm-client") include(":prompt:prompt-executor:prompt-executor-model") - include(":prompt:prompt-llm") include(":prompt:prompt-markdown") include(":prompt:prompt-model") -include(":prompt:prompt-processor") -include(":prompt:prompt-structure") -include(":prompt:prompt-tokenizer") include(":prompt:prompt-xml") - -include(":embeddings:embeddings-base") -include(":embeddings:embeddings-llm") - include(":rag:rag-base") -include(":rag:vector-storage") - -include(":a2a:a2a-core") -include(":a2a:a2a-server") -include(":a2a:a2a-client") -include(":a2a:a2a-test") -include(":a2a:a2a-transport:a2a-transport-core-jsonrpc") -include(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http") -include(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http") -include(":a2a:test-tck:a2a-test-server-tck") - -include(":http-client:http-client-core") -include(":http-client:http-client-test") -include(":http-client:http-client-ktor") -include(":http-client:http-client-okhttp") -include(":http-client:http-client-java") - -include(":koog-spring-boot-starter") - -include(":koog-ktor") -include(":docs") - -include(":test-utils") include(":utils") diff --git a/utils/build.gradle.kts b/utils/build.gradle.kts index 4f48b9646a..b538acc854 100644 --- a/utils/build.gradle.kts +++ b/utils/build.gradle.kts @@ -9,6 +9,7 @@ group = rootProject.group version = rootProject.version kotlin { + jvm() sourceSets { commonMain { dependencies { @@ -19,12 +20,6 @@ kotlin { } } - commonTest { - dependencies { - implementation(project(":test-utils")) - } - } - jvmTest { dependencies { implementation(kotlin("test-junit5"))