11package com.openai.helpers
22
3- import com.openai.core.JsonField
43import com.openai.core.JsonNull
54import com.openai.core.JsonValue
65import com.openai.errors.OpenAIInvalidDataException
@@ -68,6 +67,27 @@ class ChatCompletionAccumulator private constructor() {
6867 */
6968 private val logprobsBuilders = mutableMapOf<Long , ChatCompletion .Choice .Logprobs .Builder >()
7069
70+ /* *
71+ * The accumulated tool call builders for each message. The "outer" keys correspond to the
72+ * indexes in [messageBuilders] (the choice index). The "inner" keys correspond to the position
73+ * of each tool call in the message's list of tool calls (the tool call index).
74+ */
75+ private val toolCallBuilders =
76+ mutableMapOf<Long , MutableMap <Long , ChatCompletionMessageToolCall .Builder >>()
77+
78+ /* *
79+ * The accumulated tool call function builders for the tool call builders of each message. The
80+ * entries correspond to those in [toolCallBuilders].
81+ */
82+ private val toolCallFunctionBuilders =
83+ mutableMapOf<Long , MutableMap <Long , ChatCompletionMessageToolCall .Function .Builder >>()
84+
85+ /* *
86+ * The accumulated tool call function arguments that will be set on the function builders when
87+ * completed. The entries correspond to those in [toolCallFunctionBuilders].
88+ */
89+ private val toolCallFunctionArgs = mutableMapOf<Long , MutableMap <Long , String >>()
90+
7191 /* *
7292 * The finished status of each of the `n` completions. When a chunk with a `finishReason` is
7393 * encountered, its index is recorded against a `true` value. When a `true` has been recorded
@@ -80,27 +100,6 @@ class ChatCompletionAccumulator private constructor() {
80100 companion object {
81101 @JvmStatic fun create () = ChatCompletionAccumulator ()
82102
83- @JvmSynthetic
84- internal fun convertToolCall (chunkToolCall : ChatCompletionChunk .Choice .Delta .ToolCall ) =
85- ChatCompletionMessageToolCall .builder()
86- .id(chunkToolCall._id ())
87- .function(convertToolCallFunction(chunkToolCall._function ()))
88- .additionalProperties(chunkToolCall._additionalProperties ())
89- // Let the `type` default to "function".
90- .build()
91-
92- @JvmSynthetic
93- internal fun convertToolCallFunction (
94- chunkToolCallFunction : JsonField <ChatCompletionChunk .Choice .Delta .ToolCall .Function >
95- ): JsonField <ChatCompletionMessageToolCall .Function > =
96- chunkToolCallFunction.map { function ->
97- ChatCompletionMessageToolCall .Function .builder()
98- .name(function._name ())
99- .arguments(function._arguments ())
100- .additionalProperties(function._additionalProperties ())
101- .build()
102- }
103-
104103 @JvmSynthetic
105104 internal fun convertFunctionCall (
106105 chunkFunctionCall : ChatCompletionChunk .Choice .Delta .FunctionCall
@@ -253,14 +252,48 @@ class ChatCompletionAccumulator private constructor() {
253252 delta.role().ifPresent { messageBuilder.role(JsonValue .from(it.asString())) }
254253 delta.functionCall().ifPresent { messageBuilder.functionCall(convertFunctionCall(it)) }
255254
256- // Add the `ToolCall` objects in the order in which they are encountered.
257- // (`...Delta.ToolCall.index` is not documented, so it is ignored here.)
258- delta.toolCalls().ifPresent { it.map { messageBuilder.addToolCall(convertToolCall(it)) } }
255+ delta.toolCalls().ifPresent {
256+ it.map { deltaToolCall ->
257+ // The first chunk delta will carry the tool call ID and the function name. Later
258+ // deltas will carry only fragments of the function arguments, but the tool call
259+ // index will identify the function to which those argument fragments belong.
260+ val messageToolCallBuilders = toolCallBuilders.getOrPut(index) { mutableMapOf () }
261+
262+ messageToolCallBuilders.getOrPut(deltaToolCall.index()) {
263+ ChatCompletionMessageToolCall .builder()
264+ .id(deltaToolCall._id ())
265+ .additionalProperties(deltaToolCall._additionalProperties ())
266+ // Must wait until the `function` is accumulated and built before adding it to
267+ // the tool call later when `buildChoices` is called.
268+ }
269+
270+ val messageToolCallFunctionBuilders =
271+ toolCallFunctionBuilders.getOrPut(index) { mutableMapOf () }
272+
273+ messageToolCallFunctionBuilders.getOrPut(deltaToolCall.index()) {
274+ ChatCompletionMessageToolCall .Function .builder()
275+ .name(ensureFunction(deltaToolCall.function())._name ())
276+ .additionalProperties(deltaToolCall._additionalProperties ())
277+ }
278+
279+ val messageToolCallFunctionArgs =
280+ toolCallFunctionArgs.getOrPut(index) { mutableMapOf () }
281+
282+ messageToolCallFunctionArgs[deltaToolCall.index()] =
283+ (messageToolCallFunctionArgs[deltaToolCall.index()] ? : " " ) +
284+ (ensureFunction(deltaToolCall.function()).arguments().getOrNull() ? : " " )
285+ }
286+ }
287+
259288 messageBuilder.putAllAdditionalProperties(delta._additionalProperties ())
260289 }
261290
262- @JvmSynthetic
263- internal fun buildChoices () =
291+ private fun ensureFunction (
292+ function : Optional <ChatCompletionChunk .Choice .Delta .ToolCall .Function >
293+ ): ChatCompletionChunk .Choice .Delta .ToolCall .Function =
294+ function.orElseThrow { OpenAIInvalidDataException (" Tool call chunk missing function." ) }
295+
296+ private fun buildChoices () =
264297 choiceBuilders.entries
265298 .sortedBy { it.key }
266299 .map {
@@ -270,13 +303,41 @@ class ChatCompletionAccumulator private constructor() {
270303 .build()
271304 }
272305
273- @JvmSynthetic
274- internal fun buildMessage (index : Long ) =
306+ private fun buildMessage (index : Long ) =
275307 messageBuilders
276308 .getOrElse(index) {
277309 throw OpenAIInvalidDataException (" Missing message for index $index ." )
278310 }
279311 .content(messageContents[index])
280312 .refusal(messageRefusals[index])
313+ .toolCalls(buildToolCalls(index))
281314 .build()
315+
316+ private fun buildToolCalls (index : Long ): List <ChatCompletionMessageToolCall > =
317+ // It is OK for a message not to have any tool calls; most will not and an empty list will
318+ // be returned. An entry (if it exists) will be a collection of tool call builders and each
319+ // has a function that needs to be set.
320+ toolCallBuilders[index]
321+ ?.entries
322+ ?.sortedBy { it.key }
323+ ?.map { messageToolCallBuilderEntry ->
324+ messageToolCallBuilderEntry.value
325+ .function(buildFunction(index, messageToolCallBuilderEntry.key))
326+ .build()
327+ } ? : listOf ()
328+
329+ private fun buildFunction (index : Long , toolCallIndex : Long ) =
330+ // Every tool call is expected to have a function with arguments.
331+ toolCallFunctionBuilders[index]
332+ ?.get(toolCallIndex)
333+ ?.arguments(
334+ toolCallFunctionArgs[index]?.get(toolCallIndex)
335+ ? : throw OpenAIInvalidDataException (
336+ " Missing function arguments for index $index .$toolCallIndex ."
337+ )
338+ )
339+ ?.build()
340+ ? : throw OpenAIInvalidDataException (
341+ " Missing function builder for index $index .$toolCallIndex ."
342+ )
282343}
0 commit comments