diff --git a/CHANGELOG.md b/CHANGELOG.md index 28caa31..07cdaa4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,40 @@ +# 1.0.0 + +- Added: + - Methods: + - `startBot()` + - `startBotAndConnect()` + - `appendToContext()` + - `sendClientMessage()` + - `sendClientRequest()` + - `registerFunctionCallHandler()` + - `unregisterFunctionCallHandler()` + - `unregisterAllFunctionCallHandlers()` + - `disconnectBot()` + - Callbacks: + - `onBotLLMSearchResponse()` + - `onLLMFunctionCall()` +- Removed: + - Helper classes (`LLMHelper`, `RTVIClientHelper`) and associated methods (`registerHelper`, `unregisterHelper`) + - Server configuration, including `getConfig()`, `updateConfig()`, `describeConfig()` and associated types + - Actions, including `action()`, `describeActions()` and associated types + - `expiry` field on client and transports + - `onStorageItemStored()` callback + - `RTVIClientParams` + - `sendWithResponse()` + - `sendMessage()` +- Changed: + - `connect()` is modified to no longer make a POST request to the backend, but rather pass the + specified `Value` straight to the `Transport`. See `startBotAndConnect()` for a helper method + which also includes the POST request. + - `Transport` is now passed directly into the `PipecatClient` constructor rather than using a factory + - `onBotReady()` now receives a `BotReadyData` parameter +- Renamed: + - `RTVIClient` -> `PipecatClient` + - `RTVIEventCallbacks` -> `PipecatEventCallbacks` + - `RTVIClientOptions` -> `PipecatClientOptions` + - `onPipecatMetrics()` -> `onMetrics()` + # 0.3.4 - Added `onServerMessage` callback diff --git a/README.md b/README.md index 02c1ffc..6d1a5ef 100644 --- a/README.md +++ b/README.md @@ -10,20 +10,20 @@ This Android library contains the core components and types needed to set up an When building an RTVI application, you should use the transport-specific client library (see [here](https://rtvi.mintlify.app/api-reference/transports/introduction) for available first-party -packages.) 
The base `RTVIClient` has no transport included. +packages.) The base `PipecatClient` has no transport included. ## Usage Add the following dependency to your `build.gradle` file: ``` -implementation "ai.pipecat:client:0.3.4" +implementation "ai.pipecat:client:1.0.0" ``` -Then instantiate the `RTVIClient` from your code, specifying the backend `baseUrl` and transport. +Then instantiate the `PipecatClient` from your code: ```kotlin -val callbacks = object : RTVIEventCallbacks() { +val callbacks = object : PipecatEventCallbacks() { override fun onBackendError(message: String) { Log.e(TAG, "Error from backend: $message") @@ -32,12 +32,12 @@ val callbacks = object : RTVIEventCallbacks() { // ... } -val client = RTVIClient(transport, callbacks, options) +val client = PipecatClient(transport, options) -client.start().withCallback { +client.startBotAndConnect(startBotParams).withCallback { // ... } ``` -`client.start()` (and other APIs) return a `Future`, which can give callbacks, or be awaited -using Kotlin Coroutines (`client.start().await()`). +Many `PipecatClient` APIs return a `Future`, which can deliver its result to a callback or be awaited +using Kotlin Coroutines (`client.startBotAndConnect(startBotParams).await()`). 
diff --git a/pipecat-client-android/build.gradle.kts b/pipecat-client-android/build.gradle.kts index b45e1c2..7e6f6d4 100644 --- a/pipecat-client-android/build.gradle.kts +++ b/pipecat-client-android/build.gradle.kts @@ -60,7 +60,7 @@ publishing { register("release") { groupId = "ai.pipecat" artifactId = "client" - version = "0.3.4" + version = "1.0.0" pom { name.set("Pipecat Client") diff --git a/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatClient.kt b/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatClient.kt new file mode 100644 index 0000000..11639c3 --- /dev/null +++ b/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatClient.kt @@ -0,0 +1,533 @@ +package ai.pipecat.client + +import ai.pipecat.client.result.Future +import ai.pipecat.client.result.Promise +import ai.pipecat.client.result.RTVIError +import ai.pipecat.client.result.catchExceptions +import ai.pipecat.client.result.resolvedPromiseErr +import ai.pipecat.client.result.resolvedPromiseOk +import ai.pipecat.client.transport.MsgClientToServer +import ai.pipecat.client.transport.MsgServerToClient +import ai.pipecat.client.transport.Transport +import ai.pipecat.client.transport.TransportContext +import ai.pipecat.client.types.APIRequest +import ai.pipecat.client.types.AppendToContextResultData +import ai.pipecat.client.types.BotReadyData +import ai.pipecat.client.types.DataMessage +import ai.pipecat.client.types.LLMContextMessage +import ai.pipecat.client.types.LLMFunctionCallData +import ai.pipecat.client.types.LLMFunctionCallHandler +import ai.pipecat.client.types.LLMFunctionCallResult +import ai.pipecat.client.types.MediaDeviceId +import ai.pipecat.client.types.Transcript +import ai.pipecat.client.types.TransportState +import ai.pipecat.client.types.Value +import ai.pipecat.client.utils.JSON_INSTANCE +import ai.pipecat.client.utils.ResponseWaiters +import ai.pipecat.client.utils.ThreadRef +import ai.pipecat.client.utils.post +import android.util.Log +import 
kotlinx.serialization.encodeToString +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.jsonObject +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.RequestBody.Companion.toRequestBody +import java.util.UUID + +internal const val RTVI_PROTOCOL_VERSION = "1.0.0" + +/** + * A Pipecat client. Connects to an RTVI backend and handles bidirectional audio and video + * streaming. + * + * The client must be cleaned up using the [release] method when it is no longer required. + * + * @param transport Transport for media streaming. + * @param options Options for configuring the client, including the event callbacks invoked + * when changes occur in the session. + */ +@Suppress("unused") +open class PipecatClient, ConnectParams>( + private val transport: TransportType, + private val options: PipecatClientOptions, +) { + companion object { + private const val TAG = "PipecatClient" + } + + /** + * The thread used by the PipecatClient for callbacks and other operations. 
+ */ + val thread = ThreadRef.forCurrent() + + private val responseWaiters = ResponseWaiters(thread) + private val functionCallHandlers = mutableMapOf() + + private val transportCtx = object : TransportContext { + + override val options + get() = this@PipecatClient.options + + override val callbacks + get() = options.callbacks + + override val thread = this@PipecatClient.thread + + override fun onConnectionEnd() { + thread.runOnThread { + responseWaiters.clearAll() + connection?.ready?.resolveErr(RTVIError.OperationCancelled) + connection = null + } + } + + override fun onMessage(msg: MsgServerToClient) = thread.runOnThread { + + try { + when (msg.type) { + MsgServerToClient.Type.BotReady -> { + + val data = + JSON_INSTANCE.decodeFromJsonElement(msg.data) + + this@PipecatClient.transport.setState(TransportState.Ready) + + connection?.ready?.resolveOk(Unit) + + callbacks.onBotReady(data) + } + + MsgServerToClient.Type.Error -> { + val data = + JSON_INSTANCE.decodeFromJsonElement(msg.data) + callbacks.onBackendError(data.error) + } + + MsgServerToClient.Type.ServerResponse, + MsgServerToClient.Type.AppendToContextResult -> { + try { + responseWaiters.resolve(id = msg.id!!, data = msg.data) + } catch (e: Exception) { + Log.e(TAG, "Got exception handling server response", e) + callbacks.onBackendError("Got exception while handling server response, see log (id = ${msg.id})") + } + } + + MsgServerToClient.Type.ErrorResponse -> { + + val data = + JSON_INSTANCE.decodeFromJsonElement( + msg.data + ) + + try { + responseWaiters.reject( + id = msg.id!!, + error = RTVIError.ErrorResponse(data.error) + ) + } catch (e: Exception) { + Log.e(TAG, "Got exception handling error response", e) + callbacks.onBackendError(data.error) + } + } + + MsgServerToClient.Type.UserTranscription -> { + val data = JSON_INSTANCE.decodeFromJsonElement(msg.data) + callbacks.onUserTranscript(data) + } + + MsgServerToClient.Type.BotTranscription, + MsgServerToClient.Type.BotTranscriptionLegacy -> { + 
val text = (msg.data.jsonObject.get("text") as JsonPrimitive).content + callbacks.onBotTranscript(text) + } + + MsgServerToClient.Type.UserStartedSpeaking -> { + callbacks.onUserStartedSpeaking() + } + + MsgServerToClient.Type.UserStoppedSpeaking -> { + callbacks.onUserStoppedSpeaking() + } + + MsgServerToClient.Type.BotStartedSpeaking -> { + callbacks.onBotStartedSpeaking() + } + + MsgServerToClient.Type.BotStoppedSpeaking -> { + callbacks.onBotStoppedSpeaking() + } + + MsgServerToClient.Type.BotLlmText -> { + val data: MsgServerToClient.Data.BotLLMTextData = + JSON_INSTANCE.decodeFromJsonElement(msg.data) + + callbacks.onBotLLMText(data) + } + + MsgServerToClient.Type.BotTtsText -> { + val data: MsgServerToClient.Data.BotTTSTextData = + JSON_INSTANCE.decodeFromJsonElement(msg.data) + + callbacks.onBotTTSText(data) + } + + MsgServerToClient.Type.BotLlmStarted -> callbacks.onBotLLMStarted() + MsgServerToClient.Type.BotLlmStopped -> callbacks.onBotLLMStopped() + + MsgServerToClient.Type.BotTtsStarted -> callbacks.onBotTTSStarted() + MsgServerToClient.Type.BotTtsStopped -> callbacks.onBotTTSStopped() + + MsgServerToClient.Type.ServerMessage -> { + callbacks.onServerMessage(JSON_INSTANCE.decodeFromJsonElement(msg.data)) + } + + MsgServerToClient.Type.Metrics -> { + callbacks.onMetrics(JSON_INSTANCE.decodeFromJsonElement(msg.data)) + } + + MsgServerToClient.Type.LlmFunctionCall -> { + + val functionCallData = + JSON_INSTANCE.decodeFromJsonElement(msg.data) + + callbacks.onLLMFunctionCall(functionCallData) + + val handler = functionCallHandlers[functionCallData.functionName] + + if (handler != null) { + val activeConnection = connection + + handler.handleFunctionCall(functionCallData) { resultData -> + + thread.runOnThread { + if (activeConnection == connection) { + sendMessage( + MsgClientToServer.LlmFunctionCallResult( + msgId = msg.id ?: UUID.randomUUID().toString(), + data = LLMFunctionCallResult( + functionName = functionCallData.functionName, + toolCallID = 
functionCallData.toolCallID, + arguments = functionCallData.args, + result = JSON_INSTANCE.encodeToJsonElement( + resultData + ) + ) + ) + ) + } + } + } + } + } + + MsgServerToClient.Type.BotLlmSearchResponse -> { + callbacks.onBotLLMSearchResponse(JSON_INSTANCE.decodeFromJsonElement(msg.data)) + } + + else -> { + Log.w(TAG, "Unexpected message type '${msg.type}'") + } + } + } catch (e: Exception) { + Log.e(TAG, "Exception while handling VoiceMessage", e) + } + } + } + + private inner class Connection { + val ready = Promise(thread) + } + + private var connection: Connection? = null + + init { + transport.initialize(transportCtx) + } + + /** + * Initialize local media devices such as camera and microphone. + * + * @return A Future, representing the asynchronous result of this operation. + */ + fun initDevices(): Future = transport.initDevices() + + fun startBot(startBotParams: APIRequest): Future = + thread.runOnThreadReturningFuture { + + when (transport.state()) { + TransportState.Authorizing, + TransportState.Connecting, + TransportState.Connected, + TransportState.Ready -> return@runOnThreadReturningFuture resolvedPromiseErr( + thread, + RTVIError.InvalidState( + expected = TransportState.Initialized, + actual = transport.state() + ) + ) + + else -> { + // Continue + } + } + + transport.setState(TransportState.Authorizing) + + val postResult = post( + thread = thread, + url = startBotParams.endpoint, + body = JSON_INSTANCE.encodeToString(startBotParams.requestData) + .toRequestBody("application/json".toMediaType()), + customHeaders = startBotParams.headers.toList(), + timeoutMs = startBotParams.timeoutMs + ) + + postResult.mapError { RTVIError.HttpError(it) }.chain { + try { + resolvedPromiseOk(thread, transport.deserializeConnectParams(it)) + } catch (e: Exception) { + resolvedPromiseErr(thread, RTVIError.ExceptionThrown(e)) + } + }.withCallback { + transport.setState( + if (it.ok) { + TransportState.Authorized + } else { + TransportState.Disconnected + } + ) 
+ } + } + + /** + * Initiate an RTVI session, connecting to the backend. + */ + fun connect(transportParams: ConnectParams): Future = + thread.runOnThreadReturningFuture { + + if (connection != null) { + return@runOnThreadReturningFuture resolvedPromiseErr( + thread, + RTVIError.PreviousConnectionStillActive + ) + } + + connection = Connection() + return@runOnThreadReturningFuture transport.connect(transportParams) + } + + /** + * Performs bot start request and connection in a single operation. + * + * This convenience method combines `startBot()` and `connect()` into a single call, + * handling the complete flow from authentication to established connection. + */ + fun startBotAndConnect(startBotParams: APIRequest): Future = + startBot(startBotParams).chain { connect(it) } + + /** + * Disconnect an active RTVI session. + * + * @return A Future, representing the asynchronous result of this operation. + */ + fun disconnect(): Future { + return transport.disconnect() + } + + /** + * Directly send a message to the bot via the transport. + */ + private fun sendMessage(msg: MsgClientToServer) = transport.sendMessage(msg) + + /** + * Sends a one-way message to the bot without expecting a response. + * + * Use this method to send fire-and-forget messages or notifications to the bot. + */ + fun sendClientMessage(msgType: String, data: Value = Value.Null): Future = + sendMessage( + MsgClientToServer.ClientMessage( + id = UUID.randomUUID().toString(), + msgType = msgType, + data = JSON_INSTANCE.encodeToJsonElement(data) + ) + ) + + /** + * Sends a request message to the bot and waits for a response. + * + * Use this method for request-response communication patterns with the bot. 
+ */ + fun sendClientRequest( + msgType: String, + data: Value = Value.Null + ): Future = thread.runOnThreadReturningFuture { + + val idUuid = UUID.randomUUID() + val id = idUuid.toString() + + val future = responseWaiters.waitFor(id) + + sendMessage( + MsgClientToServer.ClientMessage( + id = id, + msgType = msgType, + data = JSON_INSTANCE.encodeToJsonElement(data) + ) + ) + .withErrorCallback { responseWaiters.reject(id, it) } + .chain { future } + .mapToResult { catchExceptions { JSON_INSTANCE.decodeFromJsonElement(it) } } + } + + /** + * Appends a message to the bot's LLM conversation context. + * + * This method programmatically adds a message to the Large Language Model's conversation + * history for the current session, allowing you to inject user context, assistant responses, + * or other relevant information that will be considered when generating future responses. + * + * NOTE(review): the response id generated below is not passed to `AppendToContext(message)` + * in this method; confirm the result message can still be matched to this request. + */ + fun appendToContext( + message: LLMContextMessage + ): Future = thread.runOnThreadReturningFuture { + val idUuid = UUID.randomUUID() + val id = idUuid.toString() + + val future = responseWaiters.waitFor(id) + + sendMessage(MsgClientToServer.AppendToContext(message)) + .withErrorCallback { responseWaiters.reject(id, it) } + .chain { future } + .mapToResult { catchExceptions { JSON_INSTANCE.decodeFromJsonElement(it) } } + } + + /** + * Registers a function call handler for a specific function name. + * + * When the bot calls a function with the specified name, the registered callback + * will be invoked after the `onLLMFunctionCall` event callback has been notified. + */ + fun registerFunctionCallHandler( + functionName: String, + callback: LLMFunctionCallHandler + ) = thread.runOnThread { + functionCallHandlers[functionName] = callback + } + + /** + * Unregisters a function call handler for a specific function name. 
+ */ + fun unregisterFunctionCallHandler( + functionName: String, + ) = thread.runOnThread { + functionCallHandlers.remove(functionName) + } + + /** + * Unregisters all function call handlers (NOTE(review): CHANGELOG.md names this `unregisterAllFunctionCallHandlers()`; confirm). + */ + fun unregisterFunctionCallHandler() = thread.runOnThread { + functionCallHandlers.clear() + } + + /** + * Sends a disconnect signal to the bot while maintaining the transport connection. + * + * This method instructs the bot to gracefully end the current conversation session + * and clean up its internal state, but keeps the underlying transport connection + * (WebRTC, WebSocket, etc.) active. This is different from `disconnect()` which + * closes the entire connection. + */ + fun disconnectBot() = thread.runOnThread { + sendMessage(MsgClientToServer.DisconnectBot()).logError(TAG, "disconnectBot") + } + + /** + * The current state of the session, as reported by the transport. + */ + val state + get() = transport.state() + + /** + * Returns a list of available audio input devices. + */ + fun getAllMics() = transport.getAllMics() + + /** + * Returns a list of available video input devices. + */ + fun getAllCams() = transport.getAllCams() + + /** + * Returns the selected audio input device. + */ + val selectedMic + get() = transport.selectedMic() + + /** + * Returns the selected video input device. + */ + val selectedCam + get() = transport.selectedCam() + + /** + * Use the specified audio input device. + * + * @return A Future, representing the asynchronous result of this operation. + */ + fun updateMic(micId: MediaDeviceId) = transport.updateMic(micId) + + /** + * Use the specified video input device. + * + * @return A Future, representing the asynchronous result of this operation. + */ + fun updateCam(camId: MediaDeviceId) = transport.updateCam(camId) + + /** + * Enables or disables the audio input device. + * + * @return A Future, representing the asynchronous result of this operation. 
+ */ + fun enableMic(enable: Boolean) = transport.enableMic(enable) + + /** + * Enables or disables the video input device. + * + * @return A Future, representing the asynchronous result of this operation. + */ + fun enableCam(enable: Boolean) = transport.enableCam(enable) + + /** + * Returns true if the microphone is enabled, false otherwise. + */ + val isMicEnabled + get() = transport.isMicEnabled() + + /** + * Returns true if the camera is enabled, false otherwise. + */ + val isCamEnabled + get() = transport.isCamEnabled() + + /** + * Returns a list of participant media tracks. + */ + val tracks + get() = transport.tracks() + + /** + * Destroys this PipecatClient and cleans up any allocated resources. + */ + fun release() { + thread.assertCurrent() + responseWaiters.clearAll() + transport.release() + } +} \ No newline at end of file diff --git a/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatClientOptions.kt b/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatClientOptions.kt new file mode 100644 index 0000000..22eba6c --- /dev/null +++ b/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatClientOptions.kt @@ -0,0 +1,26 @@ +package ai.pipecat.client + +/** + * Configuration options when instantiating a [PipecatClient]. + */ +data class PipecatClientOptions( + + /** + * Event callbacks. + */ + val callbacks: PipecatEventCallbacks, + + /** + * Enable the user mic input. + * + * Defaults to true. + */ + val enableMic: Boolean = true, + + /** + * Enable user cam input. + * + * Defaults to false. 
+ */ + val enableCam: Boolean = false, +) \ No newline at end of file diff --git a/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatEventCallbacks.kt b/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatEventCallbacks.kt new file mode 100644 index 0000000..13c8fb4 --- /dev/null +++ b/pipecat-client-android/src/main/java/ai/pipecat/client/PipecatEventCallbacks.kt @@ -0,0 +1,178 @@ +package ai.pipecat.client + +import ai.pipecat.client.transport.MsgServerToClient +import ai.pipecat.client.types.BotLLMSearchResponseData +import ai.pipecat.client.types.BotReadyData +import ai.pipecat.client.types.LLMFunctionCallData +import ai.pipecat.client.types.MediaDeviceInfo +import ai.pipecat.client.types.Participant +import ai.pipecat.client.types.PipecatMetrics +import ai.pipecat.client.types.Tracks +import ai.pipecat.client.types.Transcript +import ai.pipecat.client.types.TransportState +import ai.pipecat.client.types.Value + +/** + * Callbacks invoked when changes occur in the session. + */ +@Suppress("unused") +abstract class PipecatEventCallbacks { + + /** + * Invoked when the underlying transport has connected. + */ + open fun onConnected() {} + + /** + * Invoked when the underlying transport has disconnected. + */ + open fun onDisconnected() {} + + /** + * Invoked when the session state has changed. + */ + open fun onTransportStateChanged(state: TransportState) {} + + /** + * Invoked when the bot has connected to the session. + */ + open fun onBotConnected(participant: Participant) {} + + /** + * Invoked when the bot has indicated it is ready for commands. + */ + open fun onBotReady(data: BotReadyData) {} + + /** + * An error has occurred in the RTVI backend. + */ + abstract fun onBackendError(message: String) + + /** + * Invoked when the bot has disconnected from the session. + */ + open fun onBotDisconnected(participant: Participant) {} + + /** + * Invoked when a participant has joined the session. 
+ */ + open fun onParticipantJoined(participant: Participant) {} + + /** + * Invoked when a participant has left the session. + */ + open fun onParticipantLeft(participant: Participant) {} + + /** + * Invoked when the list of available cameras has changed. + */ + open fun onAvailableCamsUpdated(cams: List) {} + + /** + * Invoked when the list of available microphones has changed. + */ + open fun onAvailableMicsUpdated(mics: List) {} + + /** + * Invoked regularly with the volume of the locally captured audio. + */ + open fun onUserAudioLevel(level: Float) {} + + /** + * Invoked regularly with the audio volume of each remote participant. + */ + open fun onRemoteAudioLevel(level: Float, participant: Participant) {} + + /** + * Invoked when the bot starts talking. + */ + open fun onBotStartedSpeaking() {} + + /** + * Invoked when the bot stops talking. + */ + open fun onBotStoppedSpeaking() {} + + /** + * Invoked when the local user starts talking. + */ + open fun onUserStartedSpeaking() {} + + /** + * Invoked when the local user stops talking. + */ + open fun onUserStoppedSpeaking() {} + + /** + * Invoked when session metrics are received. + */ + open fun onMetrics(data: PipecatMetrics) {} + + /** + * Invoked when user transcript data is available. + */ + open fun onUserTranscript(data: Transcript) {} + + /** + * Invoked when bot transcript data is available. + */ + open fun onBotTranscript(text: String) {} + + /** + * Invoked when the state of the input devices changes. + */ + open fun onInputsUpdated(camera: Boolean, mic: Boolean) {} + + /** + * Invoked when the set of available cam/mic tracks changes. + */ + open fun onTracksUpdated(tracks: Tracks) {} + + /** + * Invoked when text is generated by the bot LLM. + */ + open fun onBotLLMText(data: MsgServerToClient.Data.BotLLMTextData) {} + + /** + * Invoked when text is spoken by the bot. 
+ */ + open fun onBotTTSText(data: MsgServerToClient.Data.BotTTSTextData) {} + + /** + * Invoked when the bot starts generating LLM text. + */ + open fun onBotLLMStarted() {} + + /** + * Invoked when the bot stops generating LLM text. + */ + open fun onBotLLMStopped() {} + + /** + * Invoked when the bot starts generating TTS output. + */ + open fun onBotTTSStarted() {} + + /** + * Invoked when the bot stops generating TTS output. + */ + open fun onBotTTSStopped() {} + + /** + * Invoked when we receive a server message from the bot. + */ + open fun onServerMessage(data: Value) {} + + /** + * Invoked when the bot performs a web search. + */ + open fun onBotLLMSearchResponse(data: BotLLMSearchResponseData) {} + + /** + * Invoked when the bot makes a function call request to the client. + * + * To respond to function calls, register a handler using + * registerFunctionCallHandler(). + */ + open fun onLLMFunctionCall(functionCallData: LLMFunctionCallData) {} +} diff --git a/pipecat-client-android/src/main/java/ai/pipecat/client/RTVIClient.kt b/pipecat-client-android/src/main/java/ai/pipecat/client/RTVIClient.kt deleted file mode 100644 index a9ba39e..0000000 --- a/pipecat-client-android/src/main/java/ai/pipecat/client/RTVIClient.kt +++ /dev/null @@ -1,616 +0,0 @@ -package ai.pipecat.client - -import ai.pipecat.client.helper.RTVIClientHelper -import ai.pipecat.client.helper.RegisteredRTVIClient -import ai.pipecat.client.result.Future -import ai.pipecat.client.result.Promise -import ai.pipecat.client.result.RTVIError -import ai.pipecat.client.result.RTVIException -import ai.pipecat.client.result.Result -import ai.pipecat.client.result.resolvedPromiseErr -import ai.pipecat.client.result.withPromise -import ai.pipecat.client.result.withTimeout -import ai.pipecat.client.transport.AuthBundle -import ai.pipecat.client.transport.MsgClientToServer -import ai.pipecat.client.transport.MsgServerToClient -import ai.pipecat.client.transport.Transport -import 
ai.pipecat.client.transport.TransportContext -import ai.pipecat.client.transport.TransportFactory -import ai.pipecat.client.types.ActionDescription -import ai.pipecat.client.types.Config -import ai.pipecat.client.types.MediaDeviceId -import ai.pipecat.client.types.Option -import ai.pipecat.client.types.RegisteredHelper -import ai.pipecat.client.types.ServiceConfig -import ai.pipecat.client.types.ServiceConfigDescription -import ai.pipecat.client.types.Transcript -import ai.pipecat.client.types.TransportState -import ai.pipecat.client.types.Value -import ai.pipecat.client.utils.ConnectionBundle -import ai.pipecat.client.utils.JSON_INSTANCE -import ai.pipecat.client.utils.ThreadRef -import ai.pipecat.client.utils.parseServerSentEvents -import ai.pipecat.client.utils.post -import ai.pipecat.client.utils.valueFrom -import android.util.Log -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonPrimitive -import kotlinx.serialization.json.decodeFromJsonElement -import kotlinx.serialization.json.jsonObject -import okhttp3.MediaType.Companion.toMediaType -import okhttp3.RequestBody.Companion.toRequestBody - -private const val RTVI_PROTOCOL_VERSION = "0.3.0" - -/** - * An RTVI client. Connects to an RTVI backend and handles bidirectional audio and video - * streaming. - * - * The client must be cleaned up using the [release] method when it is no longer required. - * - * @param transport Transport for media streaming. - * @param callbacks Callbacks invoked when changes occur in the voice session. - * @param options Additional options for configuring the client and backend. - */ -@Suppress("unused") -open class RTVIClient( - transport: TransportFactory, - callbacks: RTVIEventCallbacks, - private var options: RTVIClientOptions, -) { - companion object { - private const val TAG = "RTVIClient" - } - - /** - * The thread used by the VoiceClient for callbacks and other operations. 
- */ - val thread = ThreadRef.forCurrent() - - private val callbacks = CallbackInterceptor(object : RTVIEventCallbacks() { - override fun onBackendError(message: String) {} - - override fun onDisconnected() { - discardWaitingResponses() - connection?.ready?.resolveErr(RTVIError.OperationCancelled) - connection = null - } - }, callbacks) - - private val helpers = mutableMapOf() - - private val awaitingServerResponse = - mutableMapOf) -> Unit>() - - private inline fun handleResponse( - msg: MsgServerToClient, - action: ((Result) -> Unit) -> Unit - ) { - val id = msg.id ?: throw Exception("${msg.type} missing ID") - - if (id == "END") { - return - } - - val respondTo = awaitingServerResponse.remove(id) - ?: throw Exception("${msg.type}: no responder for $id") - - action(respondTo) - } - - private val transportCtx = object : TransportContext { - - override val options - get() = this@RTVIClient.options - - override val callbacks - get() = this@RTVIClient.callbacks - - override val thread = this@RTVIClient.thread - - override fun onMessage(msg: MsgServerToClient) = thread.runOnThread { - - try { - when (msg.type) { - MsgServerToClient.Type.BotReady -> { - - val data = - JSON_INSTANCE.decodeFromJsonElement(msg.data) - - this@RTVIClient.transport.setState(TransportState.Ready) - - connection?.ready?.resolveOk(Unit) - - callbacks.onBotReady( - version = data.version, - config = data.config - ) - } - - MsgServerToClient.Type.Error -> { - val data = - JSON_INSTANCE.decodeFromJsonElement(msg.data) - callbacks.onBackendError(data.error) - } - - MsgServerToClient.Type.ErrorResponse -> { - - val data = - JSON_INSTANCE.decodeFromJsonElement( - msg.data - ) - - try { - handleResponse(msg) { respondTo -> - respondTo(Result.Err(RTVIError.ErrorResponse(data.error))) - } - } catch (e: Exception) { - Log.e(TAG, "Got exception handling error response", e) - callbacks.onBackendError(data.error) - } - } - - MsgServerToClient.Type.ActionResponse, - 
MsgServerToClient.Type.DescribeActionsResponse, - MsgServerToClient.Type.DescribeConfigResponse, - MsgServerToClient.Type.GetOrUpdateConfigResponse -> { - handleResponse(msg) { respondTo -> - respondTo(Result.Ok(msg.data)) - } - } - - MsgServerToClient.Type.UserTranscription -> { - val data = JSON_INSTANCE.decodeFromJsonElement(msg.data) - callbacks.onUserTranscript(data) - } - - MsgServerToClient.Type.BotTranscription, - MsgServerToClient.Type.BotTranscriptionLegacy -> { - val text = (msg.data.jsonObject.get("text") as JsonPrimitive).content - callbacks.onBotTranscript(text) - } - - MsgServerToClient.Type.UserStartedSpeaking -> { - callbacks.onUserStartedSpeaking() - } - - MsgServerToClient.Type.UserStoppedSpeaking -> { - callbacks.onUserStoppedSpeaking() - } - - MsgServerToClient.Type.BotStartedSpeaking -> { - callbacks.onBotStartedSpeaking() - } - - MsgServerToClient.Type.BotStoppedSpeaking -> { - callbacks.onBotStoppedSpeaking() - } - - MsgServerToClient.Type.BotLlmText -> { - val data: MsgServerToClient.Data.BotLLMTextData = - JSON_INSTANCE.decodeFromJsonElement(msg.data) - - callbacks.onBotLLMText(data) - } - - MsgServerToClient.Type.BotTtsText -> { - val data: MsgServerToClient.Data.BotTTSTextData = - JSON_INSTANCE.decodeFromJsonElement(msg.data) - - callbacks.onBotTTSText(data) - } - - MsgServerToClient.Type.BotLlmStarted -> callbacks.onBotLLMStarted() - MsgServerToClient.Type.BotLlmStopped -> callbacks.onBotLLMStopped() - - MsgServerToClient.Type.BotTtsStarted -> callbacks.onBotTTSStarted() - MsgServerToClient.Type.BotTtsStopped -> callbacks.onBotTTSStopped() - - MsgServerToClient.Type.StorageItemStored -> { - val data: MsgServerToClient.Data.StorageItemStoredData = - JSON_INSTANCE.decodeFromJsonElement(msg.data) - - callbacks.onStorageItemStored(data) - } - - MsgServerToClient.Type.ServerMessage -> { - callbacks.onServerMessage(JSON_INSTANCE.decodeFromJsonElement(msg.data)) - } - - else -> { - - var match = false - - helpers.values - .filter { 
it.supportedMessages.contains(msg.type) } - .forEach { entry -> - match = true - entry.helper.handleMessage(msg) - } - - if (!match) { - Log.w(TAG, "Unexpected message type '${msg.type}'") - - callbacks.onGenericMessage(msg) - } - } - } - } catch (e: Exception) { - Log.e(TAG, "Exception while handling VoiceMessage", e) - } - } - } - - private val transport: Transport = transport.createTransport(transportCtx) - - private inner class Connection { - val ready = Promise(thread) - } - - private var connection: Connection? = null - - /** - * Initialize local media devices such as camera and microphone. - * - * @return A Future, representing the asynchronous result of this operation. - */ - fun initDevices(): Future = transport.initDevices() - - /** - * Initiate an RTVI session, connecting to the backend. - */ - fun connect(): Future = thread.runOnThreadReturningFuture { - - if (connection != null) { - return@runOnThreadReturningFuture resolvedPromiseErr( - thread, - RTVIError.PreviousConnectionStillActive - ) - } - - if (options.params.baseUrl != null && options.params.endpoints.connect != null) { - - transport.setState(TransportState.Authorizing) - - // Send POST request to the provided base_url to connect and start the bot - - val connectionData = ConnectionData.from(options) - - val body = ConnectionBundle( - services = options.services?.associate { it.service to it.value }, - config = connectionData.config - ) - .serializeWithCustomParams(connectionData.requestData) - .toRequestBody("application/json".toMediaType()) - - val currentConnection = Connection().apply { connection = this } - - return@runOnThreadReturningFuture post( - thread = thread, - url = options.params.baseUrl + options.params.endpoints.connect, - body = body, - customHeaders = connectionData.headers - ) - .mapError { - RTVIError.HttpError(it) - } - .chain { authBundle -> - if (currentConnection == connection) { - transport.connect(AuthBundle(authBundle)) - } else { - resolvedPromiseErr(thread, 
RTVIError.OperationCancelled) - } - } - .chain { currentConnection.ready } - .withTimeout(30000) - .withErrorCallback { - disconnect() - } - - } else { - // No connection endpoint - Log.w(TAG, "No connect endpoint specified, skipping auth request") - connection = Connection() - return@runOnThreadReturningFuture transport.connect(null) - } - } - - /** - * Disconnect an active RTVI session. - * - * @return A Future, representing the asynchronous result of this operation. - */ - fun disconnect(): Future { - return transport.disconnect() - } - - /** - * Directly send a message to the bot via the transport. - */ - fun sendMessage(msg: MsgClientToServer) = transport.sendMessage(msg) - - private fun discardWaitingResponses() { - - thread.assertCurrent() - - awaitingServerResponse.values.forEach { - it(Result.Err(RTVIError.OperationCancelled)) - } - - awaitingServerResponse.clear() - } - - /** - * Registers a new helper with the client. - * - * @param service Target service for this helper - * @param helper Helper instance - */ - @Throws(RTVIException::class) - fun registerHelper(service: String, helper: E): E { - - thread.assertCurrent() - - if (helpers.containsKey(service)) { - throw RTVIException(RTVIError.OtherError("Helper targeting service '$service' already registered")) - } - - helper.registerVoiceClient(RegisteredRTVIClient(this, service)) - - val entry = RegisteredHelper( - helper = helper, - supportedMessages = HashSet(helper.getMessageTypes()) - ) - - helpers[service] = entry - - return helper - } - - /** - * Unregisters a helper from the client. 
- */ - @Throws(RTVIException::class) - fun unregisterHelper(service: String) { - - thread.assertCurrent() - - val entry = helpers.remove(service) - ?: throw RTVIException(RTVIError.OtherError("Helper targeting service '$service' not found")) - - entry.helper.unregisterVoiceClient() - } - - private inline fun sendWithResponse( - msg: MsgClientToServer, - allowSingleTurn: Boolean = false, - crossinline filter: (M) -> R - ) = withPromise(thread) { promise -> - thread.runOnThread { - - awaitingServerResponse[msg.id] = { result -> - when (result) { - is Result.Err -> promise.resolveErr(result.error) - is Result.Ok -> { - val data = JSON_INSTANCE.decodeFromJsonElement(result.value) - promise.resolveOk(filter(data)) - } - } - } - - when (transport.state()) { - TransportState.Connected, TransportState.Ready -> { - transport.sendMessage(msg) - .withTimeout(10000) - .withErrorCallback { - awaitingServerResponse.remove(msg.id) - promise.resolveErr(it) - } - } - - else -> if (allowSingleTurn) { - - val connectionData = ConnectionData.from(options) - - post( - thread = thread, - url = options.params.baseUrl + options.params.endpoints.action, - body = JSON_INSTANCE.encodeToString( - Value.serializer(), Value.Object( - (connectionData.requestData + listOf( - "actions" to Value.Array( - valueFrom(MsgClientToServer.serializer(), msg) - ) - )).toMap() - ) - ).toRequestBody("application/json".toMediaType()), - customHeaders = connectionData.headers, - responseHandler = { inputStream -> - inputStream.parseServerSentEvents { msg -> - transportCtx.onMessage(JSON_INSTANCE.decodeFromString(msg)) - } - } - ).withCallback { - promise.resolveErr( - when (it) { - is Result.Err -> RTVIError.HttpError(it.error) - is Result.Ok -> RTVIError.OtherError("Connection ended before result received") - } - ) - } - } - } - } - } - - /** - * Instruct a backend service to perform an action. - */ - fun action( - service: String, - action: String, - arguments: List