From f3e1c6494df23f6c1a70987ca7ee79c1965c933f Mon Sep 17 00:00:00 2001 From: D Gardner Date: Fri, 2 May 2025 15:24:11 +0100 Subject: [PATCH 1/9] structured-outputs: updates and more unit tests. --- openai-java-core/build.gradle.kts | 2 + .../com/openai/core/StructuredOutputs.kt | 67 + .../completions/ChatCompletionCreateParams.kt | 8 + .../completions/StructuredChatCompletion.kt | 169 ++ .../StructuredChatCompletionCreateParams.kt | 744 +++++++++ .../StructuredChatCompletionMessage.kt | 92 ++ .../blocking/chat/ChatCompletionService.kt | 12 + .../openai/core/JsonSchemaValidatorTest.kt | 1403 +++++++++++++++++ .../ChatCompletionCreateParamsTest.kt | 32 + ...tructuredChatCompletionCreateParamsTest.kt | 499 ++++++ .../StructuredChatCompletionMessageTest.kt | 141 ++ .../StructuredChatCompletionTest.kt | 405 +++++ .../StructuredOutputsClassExample.java | 73 + 13 files changed, 3647 insertions(+) create mode 100644 openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt create mode 100644 openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java diff --git a/openai-java-core/build.gradle.kts b/openai-java-core/build.gradle.kts index 08a91e0d..894f0e23 100644 --- a/openai-java-core/build.gradle.kts +++ b/openai-java-core/build.gradle.kts @@ -27,6 +27,8 @@ dependencies { implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.18.2") implementation("org.apache.httpcomponents.core5:httpcore5:5.2.4") implementation("org.apache.httpcomponents.client5:httpclient5:5.3.1") + implementation("com.github.victools:jsonschema-generator:4.38.0") + implementation("com.github.victools:jsonschema-module-jackson:4.38.0") testImplementation(kotlin("test")) testImplementation(project(":openai-java-client-okhttp")) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt new file mode 100644 index 00000000..7f18d237 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -0,0 +1,67 @@ +package com.openai.core + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.json.JsonMapper +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule +import com.fasterxml.jackson.module.kotlin.kotlinModule +import com.github.victools.jsonschema.generator.Option +import com.github.victools.jsonschema.generator.OptionPreset +import com.github.victools.jsonschema.generator.SchemaGenerator +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder +import com.github.victools.jsonschema.module.jackson.JacksonModule +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.ResponseFormatJsonSchema + +// The SDK `ObjectMappers.jsonMapper()` requires that all fields of classes be marked with +// `@JsonProperty`, which is not desirable in this context, as it impedes usability. Therefore, a +// custom JSON mapper configuration is required. +private val MAPPER = + JsonMapper.builder() + .addModule(kotlinModule()) + .addModule(Jdk8Module()) + .addModule(JavaTimeModule()) + .build() + +fun fromClass(type: Class) = + ResponseFormatJsonSchema.builder() + .jsonSchema( + ResponseFormatJsonSchema.JsonSchema.builder() + .name("json-schema-from-${type.simpleName}") + .schema(JsonValue.from(extractSchema(type))) + .build() + ) + .build() + +internal fun extractSchema(type: Class): JsonNode { + val configBuilder = + SchemaGeneratorConfigBuilder( + com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, + OptionPreset.PLAIN_JSON, + ) + // Add `"additionalProperties" : false` to all object schemas (see OpenAI). + .with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT) + // Use `JacksonModule` to support the use of Jackson annotations to set property and + // class names and descriptions and to mark fields with `@JsonIgnore`. + .with(JacksonModule()) + + configBuilder + .forFields() + // For OpenAI schemas, _all_ properties _must_ be required. Override the interpretation of + // the Jackson `required` parameter to the `@JsonProperty` annotation: it will always be + // assumed to be `true`, even if explicitly `false` and even if there is no `@JsonProperty` + // annotation present. + .withRequiredCheck { true } + + return SchemaGenerator(configBuilder.build()).generateSchema(type) +} + +fun fromJson(json: String, type: Class): T = + try { + MAPPER.readValue(json, type) + } catch (e: Exception) { + // The JSON document is included in the exception message to aid diagnosis of the problem. + // It is the responsibility of the SDK user to ensure that exceptions that may contain + // sensitive data are not exposed in logs. + throw OpenAIInvalidDataException("Error parsing JSON: $json", e) + } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index a3281dc6..cb3459fe 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -1297,6 +1297,14 @@ private constructor( body.responseFormat(jsonObject) } + /** + * Sets the class that defines the structured outputs response format. This changes the + * builder to a type-safe [StructuredChatCompletionCreateParams.Builder] that will build a + * [StructuredChatCompletionCreateParams] instance when `build()` is called. + */ + fun responseFormat(responseFormat: Class) = + StructuredChatCompletionCreateParams.builder().wrap(responseFormat, this) + /** * This feature is in Beta. If specified, our system will make a best effort to sample * deterministically, such that repeated requests with the same `seed` and parameters should diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt new file mode 100644 index 00000000..6ca931a5 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt @@ -0,0 +1,169 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.chat.completions.ChatCompletion.Choice.FinishReason +import com.openai.models.chat.completions.ChatCompletion.Choice.Logprobs +import com.openai.models.chat.completions.ChatCompletion.ServiceTier +import com.openai.models.completions.CompletionUsage +import java.util.Objects +import java.util.Optional + +class StructuredChatCompletion( + val responseFormat: Class, + val chatCompletion: ChatCompletion, +) { + /** @see ChatCompletion.id */ + fun id(): String = chatCompletion.id() + + private val choices by lazy { + chatCompletion._choices().map { choices -> choices.map { Choice(responseFormat, it) } } + } + + /** @see ChatCompletion.choices */ + fun choices(): List> = choices.getRequired("choices") + + /** @see ChatCompletion.created */ + fun created(): Long = chatCompletion.created() + + /** @see ChatCompletion.model */ + fun model(): String = chatCompletion.model() + + /** @see ChatCompletion._object_ */ + fun _object_(): JsonValue = chatCompletion._object_() + + /** @see ChatCompletion.serviceTier */ + fun serviceTier(): Optional = chatCompletion.serviceTier() + + /** @see ChatCompletion.systemFingerprint */ + fun systemFingerprint(): Optional = chatCompletion.systemFingerprint() + + /** @see ChatCompletion.usage */ + fun usage(): Optional = chatCompletion.usage() + + /** @see ChatCompletion._id */ + fun _id(): JsonField = chatCompletion._id() + + /** @see ChatCompletion._choices */ + fun _choices(): JsonField>> = choices + + /** @see ChatCompletion._created */ + fun _created(): JsonField = chatCompletion._created() + + /** @see ChatCompletion._model */ + fun _model(): JsonField = chatCompletion._model() + + /** @see ChatCompletion._serviceTier */ + fun _serviceTier(): JsonField = chatCompletion._serviceTier() + + /** @see ChatCompletion._systemFingerprint */ + fun _systemFingerprint(): JsonField = chatCompletion._systemFingerprint() + + /** @see ChatCompletion._usage */ + fun _usage(): JsonField = chatCompletion._usage() + + /** @see ChatCompletion._additionalProperties */ + fun _additionalProperties(): Map = chatCompletion._additionalProperties() + + class Choice + internal constructor( + internal val responseFormat: Class, + internal val choice: ChatCompletion.Choice, + ) { + /** @see ChatCompletion.Choice.finishReason */ + fun finishReason(): FinishReason = choice.finishReason() + + /** @see ChatCompletion.Choice.index */ + fun index(): Long = choice.index() + + /** @see ChatCompletion.Choice.logprobs */ + fun logprobs(): Optional = choice.logprobs() + + /** @see ChatCompletion.Choice._finishReason */ + fun _finishReason(): JsonField = choice._finishReason() + + private val message by lazy { + choice._message().map { StructuredChatCompletionMessage(responseFormat, it) } + } + + /** @see ChatCompletion.Choice.message */ + fun message(): StructuredChatCompletionMessage = message.getRequired("message") + + /** @see ChatCompletion.Choice._index */ + fun _index(): JsonField = choice._index() + + /** @see ChatCompletion.Choice._logprobs */ + fun _logprobs(): JsonField = choice._logprobs() + + /** @see ChatCompletion.Choice._message */ + fun _message(): JsonField> = message + + /** @see ChatCompletion.Choice._additionalProperties */ + fun _additionalProperties(): Map = choice._additionalProperties() + + /** @see ChatCompletion.Choice.validate */ + fun validate(): Choice = apply { + message().validate() + choice.validate() + } + + /** @see ChatCompletion.Choice.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is Choice<*> && + responseFormat == other.responseFormat && + choice == other.choice + } + + private val hashCode: Int by lazy { Objects.hash(responseFormat, choice) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, choice=$choice}" + } + + /** @see ChatCompletion.validate */ + fun validate() = apply { + choices().forEach { it.validate() } + chatCompletion.validate() + } + + /** @see ChatCompletion.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletion<*> && + responseFormat == other.responseFormat && + chatCompletion == other.chatCompletion + } + + private val hashCode: Int by lazy { Objects.hash(responseFormat, chatCompletion) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, chatCompletion=$chatCompletion}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt new file mode 100644 index 00000000..ae1ea1be --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -0,0 +1,744 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.checkRequired +import com.openai.core.fromClass +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.models.ChatModel +import com.openai.models.ReasoningEffort +import java.util.Objects +import java.util.Optional + +class StructuredChatCompletionCreateParams +internal constructor( + val responseFormat: Class, + /** + * The raw, underlying chat completion create parameters wrapped by this structured instance of + * the parameters. + */ + @get:JvmName("rawParams") val rawParams: ChatCompletionCreateParams, +) { + + companion object { + @JvmStatic fun builder() = Builder() + } + + class Builder internal constructor() { + private var responseFormat: Class? = null + private var paramsBuilder = ChatCompletionCreateParams.builder() + + @JvmSynthetic + internal fun wrap( + responseFormat: Class, + paramsBuilder: ChatCompletionCreateParams.Builder, + ) = apply { + this.responseFormat = responseFormat + this.paramsBuilder = paramsBuilder + // Convert the class to a JSON schema and apply it to the delegate `Builder`. + responseFormat(responseFormat) + } + + /** Injects a given `ChatCompletionCreateParams.Builder`. For use only when testing. */ + @JvmSynthetic + internal fun inject(paramsBuilder: ChatCompletionCreateParams.Builder) = apply { + this.paramsBuilder = paramsBuilder + } + + /** @see ChatCompletionCreateParams.Builder.body */ + fun body(body: ChatCompletionCreateParams.Body) = apply { paramsBuilder.body(body) } + + /** @see ChatCompletionCreateParams.Builder.messages */ + fun messages(messages: List) = apply { + paramsBuilder.messages(messages) + } + + /** @see ChatCompletionCreateParams.Builder.messages */ + fun messages(messages: JsonField>) = apply { + paramsBuilder.messages(messages) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(message: ChatCompletionMessageParam) = apply { + paramsBuilder.addMessage(message) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(developer: ChatCompletionDeveloperMessageParam) = apply { + paramsBuilder.addMessage(developer) + } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessage */ + fun addDeveloperMessage(content: ChatCompletionDeveloperMessageParam.Content) = apply { + paramsBuilder.addDeveloperMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessage */ + fun addDeveloperMessage(text: String) = apply { paramsBuilder.addDeveloperMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessageOfArrayOfContentParts */ + fun addDeveloperMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addDeveloperMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(system: ChatCompletionSystemMessageParam) = apply { + paramsBuilder.addMessage(system) + } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessage */ + fun addSystemMessage(content: ChatCompletionSystemMessageParam.Content) = apply { + paramsBuilder.addSystemMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessage */ + fun addSystemMessage(text: String) = apply { paramsBuilder.addSystemMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessageOfArrayOfContentParts */ + fun addSystemMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addSystemMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(user: ChatCompletionUserMessageParam) = apply { + paramsBuilder.addMessage(user) + } + + /** @see ChatCompletionCreateParams.Builder.addUserMessage */ + fun addUserMessage(content: ChatCompletionUserMessageParam.Content) = apply { + paramsBuilder.addUserMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addUserMessage */ + fun addUserMessage(text: String) = apply { paramsBuilder.addUserMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addUserMessageOfArrayOfContentParts */ + fun addUserMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addUserMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(assistant: ChatCompletionAssistantMessageParam) = apply { + paramsBuilder.addMessage(assistant) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(assistant: ChatCompletionMessage) = apply { + paramsBuilder.addMessage(assistant) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(tool: ChatCompletionToolMessageParam) = apply { + paramsBuilder.addMessage(tool) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + @Deprecated("deprecated") + fun addMessage(function: ChatCompletionFunctionMessageParam) = apply { + paramsBuilder.addMessage(function) + } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(model: ChatModel) = apply { paramsBuilder.model(model) } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(model: JsonField) = apply { paramsBuilder.model(model) } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(value: String) = apply { paramsBuilder.model(value) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: ChatCompletionAudioParam?) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: Optional) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: JsonField) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Double?) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Double) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Optional) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: JsonField) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCall: ChatCompletionCreateParams.FunctionCall) = apply { + paramsBuilder.functionCall(functionCall) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCall: JsonField) = apply { + paramsBuilder.functionCall(functionCall) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(mode: ChatCompletionCreateParams.FunctionCall.FunctionCallMode) = apply { + paramsBuilder.functionCall(mode) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCallOption: ChatCompletionFunctionCallOption) = apply { + paramsBuilder.functionCall(functionCallOption) + } + + /** @see ChatCompletionCreateParams.Builder.functions */ + @Deprecated("deprecated") + fun functions(functions: List) = apply { + paramsBuilder.functions(functions) + } + + /** @see ChatCompletionCreateParams.Builder.functions */ + @Deprecated("deprecated") + fun functions(functions: JsonField>) = apply { + paramsBuilder.functions(functions) + } + + /** @see ChatCompletionCreateParams.Builder.addFunction */ + @Deprecated("deprecated") + fun addFunction(function: ChatCompletionCreateParams.Function) = apply { + paramsBuilder.addFunction(function) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: ChatCompletionCreateParams.LogitBias?) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: Optional) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: JsonField) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Boolean?) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Boolean) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Optional) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: JsonField) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Long?) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Long) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Optional) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: JsonField) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Long?) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Long) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Optional) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: JsonField) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: ChatCompletionCreateParams.Metadata?) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: Optional) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: JsonField) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: List?) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: Optional>) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: JsonField>) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.addModality */ + fun addModality(modality: ChatCompletionCreateParams.Modality) = apply { + paramsBuilder.addModality(modality) + } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Long?) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Long) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Optional) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: JsonField) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Boolean) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ChatCompletionCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: JsonField) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: ChatCompletionPredictionContent?) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: Optional) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: JsonField) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Double?) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Double) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Optional) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: JsonField) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: ReasoningEffort?) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: Optional) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: JsonField) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** Sets the response format to a JSON schema derived from the given class. */ + fun responseFormat(responseFormat: Class) = apply { + this.responseFormat = responseFormat + paramsBuilder.responseFormat(fromClass(responseFormat)) + } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Long?) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Long) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Optional) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: JsonField) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: ChatCompletionCreateParams.ServiceTier?) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: Optional) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: JsonField) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: ChatCompletionCreateParams.Stop?) = apply { paramsBuilder.stop(stop) } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: Optional) = apply { + paramsBuilder.stop(stop) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: JsonField) = apply { + paramsBuilder.stop(stop) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(string: String) = apply { paramsBuilder.stop(string) } + + /** @see ChatCompletionCreateParams.Builder.stopOfStrings */ + fun stopOfStrings(strings: List) = apply { paramsBuilder.stopOfStrings(strings) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Boolean?) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Boolean) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Optional) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: JsonField) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: ChatCompletionStreamOptions?) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: Optional) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: JsonField) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Double?) = apply { paramsBuilder.temperature(temperature) } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Double) = apply { paramsBuilder.temperature(temperature) } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Optional) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: JsonField) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: ChatCompletionToolChoiceOption) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: JsonField) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(auto: ChatCompletionToolChoiceOption.Auto) = apply { + paramsBuilder.toolChoice(auto) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(namedToolChoice: ChatCompletionNamedToolChoice) = apply { + paramsBuilder.toolChoice(namedToolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.tools */ + fun tools(tools: List) = apply { paramsBuilder.tools(tools) } + + /** @see ChatCompletionCreateParams.Builder.tools */ + fun tools(tools: JsonField>) = apply { paramsBuilder.tools(tools) } + + /** @see ChatCompletionCreateParams.Builder.addTool */ + fun addTool(tool: ChatCompletionTool) = apply { paramsBuilder.addTool(tool) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Long?) = apply { paramsBuilder.topLogprobs(topLogprobs) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Long) = apply { paramsBuilder.topLogprobs(topLogprobs) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Optional) = apply { + paramsBuilder.topLogprobs(topLogprobs) + } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: JsonField) = apply { + paramsBuilder.topLogprobs(topLogprobs) + } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Double?) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Double) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Optional) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: JsonField) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.user */ + fun user(user: String) = apply { paramsBuilder.user(user) } + + /** @see ChatCompletionCreateParams.Builder.user */ + fun user(user: JsonField) = apply { paramsBuilder.user(user) } + + /** @see ChatCompletionCreateParams.Builder.webSearchOptions */ + fun webSearchOptions(webSearchOptions: ChatCompletionCreateParams.WebSearchOptions) = + apply { + paramsBuilder.webSearchOptions(webSearchOptions) + } + + /** @see ChatCompletionCreateParams.Builder.webSearchOptions */ + fun webSearchOptions( + webSearchOptions: JsonField + ) = apply { paramsBuilder.webSearchOptions(webSearchOptions) } + + /** @see ChatCompletionCreateParams.Builder.additionalBodyProperties */ + fun additionalBodyProperties(additionalBodyProperties: Map) = apply { + paramsBuilder.additionalBodyProperties(additionalBodyProperties) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalBodyProperty */ + fun putAdditionalBodyProperty(key: String, value: JsonValue) = apply { + paramsBuilder.putAdditionalBodyProperty(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalBodyProperties */ + fun putAllAdditionalBodyProperties(additionalBodyProperties: Map) = + apply { + paramsBuilder.putAllAdditionalBodyProperties(additionalBodyProperties) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalBodyProperty */ + fun removeAdditionalBodyProperty(key: String) = apply { + paramsBuilder.removeAdditionalBodyProperty(key) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalBodyProperties */ + fun removeAllAdditionalBodyProperties(keys: Set) = apply { + paramsBuilder.removeAllAdditionalBodyProperties(keys) + } + + /** @see ChatCompletionCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalHeader */ + fun putAdditionalHeader(name: String, value: String) = apply { + paramsBuilder.putAdditionalHeader(name, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalHeaders */ + fun putAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.putAdditionalHeaders(name, values) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, value: String) = apply { + paramsBuilder.replaceAdditionalHeaders(name, value) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalHeaders(name, values) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalHeaders */ + fun removeAdditionalHeaders(name: String) = apply { + paramsBuilder.removeAdditionalHeaders(name) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalHeaders */ + fun removeAllAdditionalHeaders(names: Set) = apply { + paramsBuilder.removeAllAdditionalHeaders(names) + } + + /** @see ChatCompletionCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: Map>) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalQueryParam */ + fun putAdditionalQueryParam(key: String, value: String) = apply { + paramsBuilder.putAdditionalQueryParam(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalQueryParams */ + fun putAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.putAdditionalQueryParams(key, values) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, value: String) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, values) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalQueryParams */ + fun removeAdditionalQueryParams(key: String) = apply { + paramsBuilder.removeAdditionalQueryParams(key) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalQueryParams */ + fun removeAllAdditionalQueryParams(keys: Set) = apply { + paramsBuilder.removeAllAdditionalQueryParams(keys) + } + + /** + * Returns an immutable instance of [StructuredChatCompletionCreateParams]. + * + * Further updates to this [Builder] will not mutate the returned instance. + * + * The following fields are required: + * ```java + * .messages() + * .model() + * .responseFormat() + * ``` + * + * @throws IllegalStateException If any required field is unset. + */ + fun build() = + StructuredChatCompletionCreateParams( + checkRequired("responseFormat", responseFormat), + paramsBuilder.build(), + ) + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletionCreateParams<*> && + responseFormat == other.responseFormat && + rawParams == other.rawParams + } + + override fun hashCode(): Int = Objects.hash(responseFormat, rawParams) + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, params=$rawParams}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt new file mode 100644 index 00000000..519596ef --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt @@ -0,0 +1,92 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.fromJson +import com.openai.models.chat.completions.ChatCompletionMessage.FunctionCall +import java.util.Objects +import java.util.Optional + +class StructuredChatCompletionMessage +internal constructor( + val responseFormat: Class, + val chatCompletionMessage: ChatCompletionMessage, +) { + + private val content: JsonField by lazy { + chatCompletionMessage._content().map { fromJson(it, responseFormat) } + } + + /** @see ChatCompletionMessage.content */ + fun content(): Optional = content.getOptional("content") + + /** @see ChatCompletionMessage.refusal */ + fun refusal(): Optional = chatCompletionMessage.refusal() + + /** @see ChatCompletionMessage._role */ + fun _role(): JsonValue = chatCompletionMessage._role() + + /** @see ChatCompletionMessage.annotations */ + fun annotations(): Optional> = + chatCompletionMessage.annotations() + + /** @see ChatCompletionMessage.audio */ + fun audio(): Optional = chatCompletionMessage.audio() + + /** @see ChatCompletionMessage.functionCall */ + @Deprecated("deprecated") + fun functionCall(): Optional = chatCompletionMessage.functionCall() + + /** @see ChatCompletionMessage.toolCalls */ + fun toolCalls(): Optional> = + chatCompletionMessage.toolCalls() + + /** @see ChatCompletionMessage._content */ + fun _content(): JsonField = content + + /** @see ChatCompletionMessage._refusal */ + fun _refusal(): JsonField = chatCompletionMessage._refusal() + + /** @see ChatCompletionMessage._annotations */ + fun _annotations(): JsonField> = + chatCompletionMessage._annotations() + + /** @see ChatCompletionMessage._audio */ + fun _audio(): JsonField = chatCompletionMessage._audio() + + /** @see ChatCompletionMessage._functionCall */ + @Deprecated("deprecated") + fun _functionCall(): JsonField = chatCompletionMessage._functionCall() + + /** @see ChatCompletionMessage._toolCalls */ + fun _toolCalls(): JsonField> = + chatCompletionMessage._toolCalls() + + /** @see ChatCompletionMessage._additionalProperties */ + fun _additionalProperties(): Map = + chatCompletionMessage._additionalProperties() + + /** @see ChatCompletionMessage.validate */ + // `content()` is not included in the validation by the delegate method, so just call it. + fun validate(): ChatCompletionMessage = chatCompletionMessage.validate() + + /** @see ChatCompletionMessage.isValid */ + fun isValid(): Boolean = chatCompletionMessage.isValid() + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletionMessage<*> && + responseFormat == other.responseFormat && + chatCompletionMessage == other.chatCompletionMessage + } + + private val hashCode: Int by lazy { Objects.hash(responseFormat, chatCompletionMessage) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, chatCompletionMessage=$chatCompletionMessage}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt index 8febcf12..28818c45 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt @@ -15,6 +15,8 @@ import com.openai.models.chat.completions.ChatCompletionListPage import com.openai.models.chat.completions.ChatCompletionListParams import com.openai.models.chat.completions.ChatCompletionRetrieveParams import com.openai.models.chat.completions.ChatCompletionUpdateParams +import com.openai.models.chat.completions.StructuredChatCompletion +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams import com.openai.services.blocking.chat.completions.MessageService interface ChatCompletionService { @@ -53,6 +55,16 @@ interface ChatCompletionService { requestOptions: RequestOptions = RequestOptions.none(), ): ChatCompletion + /** @see create */ + fun create( + params: StructuredChatCompletionCreateParams + ): StructuredChatCompletion = + StructuredChatCompletion( + params.responseFormat, + // Normal, non-generic create method call via `ChatCompletionCreateParams`. + create(params.rawParams), + ) + /** * **Starting a new project?** We recommend trying * [Responses](https://platform.openai.com/docs/api-reference/responses) to take advantage of diff --git a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt new file mode 100644 index 00000000..31768c04 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt @@ -0,0 +1,1403 @@ +package com.openai.core + +import com.fasterxml.jackson.annotation.JsonClassDescription +import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.annotation.JsonPropertyDescription +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.AfterTestExecutionCallback +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.extension.RegisterExtension + +/** Tests the [JsonSchemaValidator] and, in passing, tests the [extractSchema] function. */ +internal class JsonSchemaValidatorTest { + companion object { + private const val SCHEMA = "\$schema" + private const val SCHEMA_VER = "https://json-schema.org/draft/2020-12/schema" + private const val DEFS = "\$defs" + private const val REF = "\$ref" + + /** + * `true` to print the schema and validation errors for all executed tests, or `false` to + * print them only for failed tests. + */ + private const val VERBOSE_MODE = false + } + + /** + * A validator that can be used by each unit test. A new validation instance is created for each + * test, as each test is run from its own instance of the test class. If a test fails, any + * validation errors are automatically printed to standard output to aid diagnosis. + */ + val validator = JsonSchemaValidator.create() + + /** + * The schema that was created by the unit test. This may be printed out after a test fails to + * aid in diagnosing the cause of the failure. In that case, this property must be set, or an + * error will occur. However, it will only be printed if the failed test method has the name + * prefix `schemaTest_`, so only test methods with that naming pattern need to set this field. + */ + lateinit var schema: JsonNode + + /** + * An extension to JUnit that prints the [schema] and the validation status (including any + * errors) when a test fails. This applies only to test methods whose names are prefixed with + * `schemaTest_`. An error will occur if [schema] was not set, but this can be avoided by only + * using the method name prefix for test methods that set [schema]. This reporting is intended + * as an aid to diagnosing test failures. + */ + @Suppress("unused") + @RegisterExtension + val printValidationErrorsOnFailure: AfterTestExecutionCallback = + object : AfterTestExecutionCallback { + @Throws(Exception::class) + override fun afterTestExecution(context: ExtensionContext) { + if ( + context.displayName.startsWith("schemaTest_") && + (VERBOSE_MODE || context.executionException.isPresent) + ) { + // Test failed. + println("Schema: ${schema.toPrettyString()}\n") + println("$validator\n") + } + } + } + + // NOTE: In most of these tests, it is assumed that the schema is generated as expected; it is + // not examined in fine detail if the validator succeeds or fails with the expected errors. + + @Test + fun schemaTest_minimalSchema() { + class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_minimalListSchema() { + val s: List = listOf() + + schema = extractSchema(s.javaClass) + validator.validate(schema) + + // FIXME: Currently, the generated schema looks like this: + // { + // "$schema" : "https://json-schema.org/draft/2020-12/schema", + // "type" : "array", + // "items" : { } + // } + // That causes an error, as the `"items"` object is empty when it should be a valid + // sub-schema. Something like this is what is expected: + // { + // "$schema" : "https://json-schema.org/draft/2020-12/schema", + // "type" : "array", + // "items" : { + // "type" : "string" + // } + // } + // It might be presumed that type erasure is the cause of the missing field. However, the + // `schemaTest_listFieldSchema` method (below) seems to be able to produce the expected + // `"items"` object when it is defined as a class property, so, well ... huh? + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_listFieldSchema() { + @Suppress("unused") class X(val s: List) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + // This gives a root schema with `"type" : "string"` and `"const" : "HELLO"` + // Unfortunately, an "enum class" cannot be defined within a function or within a class within + // a function. + @Suppress("unused") + enum class MinimalEnum1 { + HELLO + } + + @Test + fun schemaTest_minimalEnumSchema1() { + schema = extractSchema(MinimalEnum1::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + // This gives a root schema with `"type" : "string"` and `"enum" : [ "HELLO", "WORLD" ]` + @Suppress("unused") + enum class MinimalEnum2 { + HELLO, + WORLD, + } + + @Test + fun schemaTest_minimalEnumSchema2() { + schema = extractSchema(MinimalEnum2::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_nonStringEnum() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "integer", + "enum" : [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchema() { + @Suppress("unused") class X(val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalString() { + // Using an `Optional` will result in this JSON: `"type" : [ "string", "null" ]`. + // That is supported by the OpenAI Structured Outputs API spec, as long as the field is also + // marked as required. Though required, it is still allowed for the field to be explicitly + // set to `"null"`. + @Suppress("unused") class X(val s: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalBoolean() { + @Suppress("unused") class X(val b: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalInteger() { + @Suppress("unused") class X(val i: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalNumber() { + @Suppress("unused") class X(val n: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_arraySchemaFromOptional() { + @Suppress("unused") class X(val s: Optional>) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_arrayTypeMissingItems() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "array" + } + """ + ) + validator.validate(schema) + + // Check once here that "validator.isValid()" returns "false" when there is an error. In + // the other tests, there is no need to repeat this assertion, as it would be redundant. + assertThat(validator.isValid()).isFalse + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'items' field is missing or is not an object.") + } + + @Test + fun schemaTest_arrayTypeWithWrongItemsType() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "array", + "items" : [ "should_not_be_an_array" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'items' field is missing or is not an object.") + } + + @Test + @Suppress("unused") + fun schemaTest_objectSubSchemaFromOptional() { + class X(val s: Optional) + class Y(val x: Optional) + + schema = extractSchema(Y::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_badOptionalTypeNotArray() { + // Testing more for code coverage than for anything expected to go wrong in practice. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : { "type" : "string" } + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'type' field is not a type name or array of type names.") + } + + @Test + fun schemaTest_badOptionalTypeNoNull1() { + // Testing more for code coverage than for anything expected to go wrong in practice. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeNoNull2() { + // If "type" is an array, one of the two "type" values must be "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", "number" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected one type name and one \"null\".") + } + + @Test + fun schemaTest_badOptionalTypeNoNull3() { + // If "type" is an array, there must be two type values only, one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", "number", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeNoStringTypeNames() { + // If "type" is an array, there must be two type values only, one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", null ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeAllNull() { + // If "type" is an array, there must be two type values only, and only one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "null", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected one type name and one \"null\".") + } + + @Test + fun schemaTest_badOptionalTypeUnknown() { + // If "type" is an array, there must be two type values only, and only one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "unknown", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#/type: Unsupported 'type' value: 'unknown'.") + } + + @Test + fun schemaTest_goodOptionalTypeNullFirst() { + // The validator should be lenient about the order of the null/not-null types in the array. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "null", "string" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinyRecursiveSchema() { + @Suppress("unused") class X(val s: String, val x: X) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_unsupportedKeywords() { + // OpenAI lists a set of keywords that are not allowed, but the set is not exhaustive. Check + // that everything named in that set is identified as not allowed, as that is the minimum + // level of validation expected. Check at the root schema and a sub-schema. There is no need + // to match the keywords to their expected schema types or be concerned about the values of + // the keyword fields, which makes testing easier. + val keywordsNotAllowed = + listOf( + "minLength", + "maxLength", + "pattern", + "format", + "minimum", + "maximum", + "multipleOf", + "patternProperties", + "unevaluatedProperties", + "propertyNames", + "minProperties", + "maxProperties", + "unevaluatedItems", + "contains", + "minContains", + "maxContains", + "minItems", + "maxItems", + "uniqueItems", + ) + val notAllowedUses = keywordsNotAllowed.joinToString(", ") { "\"$it\" : \"\"" } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "x" : { + "type" : "string", + $notAllowedUses + } + }, + $notAllowedUses, + "additionalProperties" : false, + "required" : [ "x" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(keywordsNotAllowed.size * 2) + keywordsNotAllowed.forEachIndexed { index, keyword -> + assertThat(validator.errors()[index]) + .isEqualTo("#: Use of '$keyword' is not supported here.") + assertThat(validator.errors()[index + keywordsNotAllowed.size]) + .isEqualTo("#/properties/x: Use of '$keyword' is not supported here.") + } + } + + @Test + fun schemaTest_propertyNotMarkedRequired() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_requiredArrayNull() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : null + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_requiredArrayMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_additionalPropertiesMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "required" : [ "name" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'additionalProperties' field is missing or is not set to 'false'.") + } + + @Test + fun schemaTest_additionalPropertiesTrue() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : true, + "required" : [ "name" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'additionalProperties' field is missing or is not set to 'false'.") + } + + @Test + fun schemaTest_objectPropertiesMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + // TODO: Decide if this is the expected behavior, i.e., that it is OK for an "object" schema + // to have no "properties". + assertThat(validator.isValid()).isTrue() + } + + @Test + fun schemaTest_objectPropertiesNotObject() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : [ "name", "age" ], + "additionalProperties" : false, + "required" : [ "name", "age" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is not a non-empty object.") + } + + @Test + fun schemaTest_objectPropertiesEmpty() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { }, + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is not a non-empty object.") + } + + @Test + fun schemaTest_anyOfInRootSchema() { + // OpenAI does not allow `"anyOf"` to appear at the root level of a schema. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "anyOf" : [ { + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : ["name"] + }, { + "type" : "array", + "items" : { + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : ["name"] + } + } ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#: Root schema contains 'anyOf' field.") + } + + @Test + fun schemaTest_anyOfNotArray() { + // Unlikely that this can occur in a generated schema, so this is more about code coverage. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "name" : { + "anyOf" : { + "type" : "string" + } + } + }, + "additionalProperties" : false, + "required" : ["name"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/properties/name: 'anyOf' field is not a non-empty array.") + } + + @Test + fun schemaTest_anyOfIsEmptyArray() { + // Unlikely that this can occur in a generated schema, so this is more about code coverage. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "name" : { + "anyOf" : [ ] + } + }, + "additionalProperties" : false, + "required" : ["name"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/properties/name: 'anyOf' field is not a non-empty array.") + } + + @Test + fun schemaTest_anyOfInSubSchemaArray() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "value" : { + "anyOf" : [ + { "type" : "string" }, + { "type" : "number" } + ] + } + }, + "additionalProperties" : false, + "required" : ["value"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_noSchemaFieldRootSchema() { + @Suppress("unused") class X(val s: String) + + schema = extractSchema(X::class.java) + (schema as ObjectNode).remove(SCHEMA) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#: Root schema missing '$SCHEMA' field.") + } + + @Test + @Suppress("unused") + fun schemaTest_deepNestingAtLimit() { + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + + schema = extractSchema(Y::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + @Suppress("unused") + fun schemaTest_deepNestingBeyondLimit() { + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + schema = extractSchema(Z::class.java) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).contains("Current nesting depth is 6, but maximum is 5.") + } + + @Test + fun schemaTest_stringEnum250ValueOverSizeLimit() { + // OpenAI specification: "For a single enum property with string values, the total string + // length of all enum values cannot exceed 7,500 characters when there are more than 250 + // enum values." + + // This test creates an enum with exactly 250 string values with more than 7,500 characters + // in total (31 characters per value for a total of 7,750 characters). No error is expected. + val values = (1..250).joinToString(", ") { "\"%s%03d\"".format("x".repeat(28), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_stringEnum251ValueUnderSizeLimit() { + // This test creates an enum with exactly 251 string values with fewer than 7,500 characters + // in total (29 characters per value for a total of 7,279 characters). No error is expected. + val values = (1..251).joinToString(", ") { "\"%s%03d\"".format("x".repeat(26), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_stringEnum251ValueOverSizeLimit() { + // This test creates an enum with exactly 251 string values with fewer than 7,500 characters + // in total (30 characters per value for a total of 7,530 characters). An error is expected. + val values = (1..251).joinToString(", ") { "\"%s%03d\"".format("x".repeat(27), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo( + "#/enum: Total string length (7530) of values of an enum " + + "with 251 values exceeds limit of 7500." + ) + } + + @Test + fun schemaTest_totalEnumValuesAtLimit() { + // OpenAI specification: "A schema may have up to 500 enum values across all enum + // properties." + + // This test creates two enums with a total of 500 values. The total string length of the + // values is well within the limits (2,000 characters). + val valuesA = (1..250).joinToString(", ") { "\"a%03d\"".format(it) } + val valuesB = (1..250).joinToString(", ") { "\"b%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "string", + "enum" : [ $valuesA ] + }, + "b" : { + "type" : "string", + "enum" : [ $valuesB ] + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_totalEnumValuesOverLimit() { + // This test creates two enums with a total of 501 values. The total string length of the + // values is well within the limits (2,004 characters). + val valuesA = (1..250).joinToString(", ") { "\"a%03d\"".format(it) } + val valuesB = (1..251).joinToString(", ") { "\"b%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "string", + "enum" : [ $valuesA ] + }, + "b" : { + "type" : "string", + "enum" : [ $valuesB ] + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total number of enum values (501) exceeds limit of 500.") + } + + @Test + fun schemaTest_maxObjectPropertiesAtLimit() { + // This test creates two object schemas with a total of 100 object properties. OpenAI does + // not support more than 100 properties total in the whole schema. Two objects are used to + // ensure that counting is not done per object, but across all objects. Note that each + // object schema is itself a property, so there are two properties at the top level and 49 + // properties each at the next level. No error is expected, as the limit is not exceeded. + val propUses = + (1..49).joinToString(", ") { "\"x%02d\" : { \"type\" : \"string\" }".format(it) } + val propNames = (1..49).joinToString(", ") { "\"x%02d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + }, + "b" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_maxObjectPropertiesOverLimit() { + // This test creates two object schemas with a total of 101 object properties. OpenAI does + // not support more than 100 properties total in the whole schema. Expect an error. + val propUses = + (1..49).joinToString(", ") { "\"x_%02d\" : { \"type\" : \"string\" }".format(it) } + val propNames = (1..49).joinToString(", ") { "\"x_%02d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + }, + "b" : { + "type" : "object", + "properties" : { + $propUses, + "property_101" : { "type" : "string" } + }, + "required" : [ $propNames, "property_101" ], + "additionalProperties" : false + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total number of object properties (101) exceeds limit of 100.") + } + + @Test + fun schemaTest_maxStringLengthAtLimit() { + // OpenAI specification: "In a schema, total string length of all property names, definition + // names, enum values, and const values cannot exceed 15,000 characters." + // + // This test creates a schema with many property names, definition names, enum values, and + // const values calculated to have a total string length of 15,000 characters. No error is + // expected. + // + // The test creates a schema that looks like the following, with the numbers adjusted to + // achieve a total of 15,000 characters for the relevant elements. + // + // { + // "$schema" : "...", + // "$defs" : { + // "d_001" : { + // "type" : "string", + // "const" : "c_001" + // }, + // ..., + // "d_nnn" : { + // "type" : "string", + // "const" : "c_nnn" + // } + // }, + // "type" : "object", + // "properties" : { + // "p_001" : { + // "type" : "string", + // "enum" : [ "eeeee..._001", ..., "eeeee..._nnn" ] + // }, + // ..., + // "p_nnn" : { + // "type" : "string", + // "enum" : [ "eeeee..._001", ..., "eeeee..._nnn" ] + // } + // }, + // "required" : [ "p_001", ..., "p_nnn" ], + // "additionalProperties" : false + // } + + val numDefs = 65 // Each also has one "const" value. + val numProps = 70 // Each also has "numEnumValues" enum values. + val nameLen = 5 // Length of names of definitions, properties and const values. + val numEnumValues = 5 // numProps * numEnumValues <= 500 limit (OpenAI) + val enumValueLen = 40 // Length of enum values. + val expectedTotalStringLength = + nameLen * (numProps + numDefs * 2) + numProps * enumValueLen * numEnumValues + + val enumValues = + (1..numEnumValues).joinToString(", ") { "\"%s_%03d\"".format("e".repeat(36), it) } + val defs = + (1..numDefs).joinToString(", ") { + "\"d_%03d\" : { \"type\" : \"string\", \"const\" : \"c_%03d\" }".format(it, it) + } + val props = + (1..numProps).joinToString(", ") { + "\"p_%03d\" : { \"type\" : \"string\", \"enum\" : [ $enumValues ] }".format(it) + } + val propNames = (1..numProps).joinToString(", ") { "\"p_%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { $defs }, + "type" : "object", + "properties" : { $props }, + "required" : [ $propNames ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(expectedTotalStringLength).isEqualTo(15_000) // Exactly on the limit. + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_maxStringLengthOverLimit() { + // OpenAI specification: "In a schema, total string length of all property names, definition + // names, enum values, and const values cannot exceed 15,000 characters." + // + // This test creates a schema with many property names, definition names, enum values, and + // const values calculated to have a total string length of just over 15,000 characters. An + // error is expected. + + val numDefs = 66 // Each also has one "const" value. + val numProps = 70 // Each also has "numEnumValues" enum values. + val numEnumValues = 5 // numProps * numEnumValues <= 500 limit (OpenAI) + val nameLen = 5 // Length of names of definitions, properties and const values. + val enumValueLen = 40 // Length of enum values. + val expectedTotalStringLength = + nameLen * (numProps + numDefs * 2) + numProps * enumValueLen * numEnumValues + + val enumValues = + (1..numEnumValues).joinToString(", ") { "\"%s_%03d\"".format("e".repeat(36), it) } + val defs = + (1..numDefs).joinToString(", ") { + "\"d_%03d\" : { \"type\" : \"string\", \"const\" : \"c_%03d\" }".format(it, it) + } + val props = + (1..numProps).joinToString(", ") { + "\"p_%03d\" : { \"type\" : \"string\", \"enum\" : [ $enumValues ] }".format(it) + } + val propNames = (1..numProps).joinToString(", ") { "\"p_%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { $defs }, + "type" : "object", + "properties" : { $props }, + "required" : [ $propNames ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(expectedTotalStringLength).isGreaterThan(15_000) + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total string length of all values (15010) exceeds limit of 15000.") + } + + @Test + fun schemaTest_annotatedWithJsonClassDescription() { + // Add a "description" to the root schema using an annotation. + @JsonClassDescription("A simple schema.") class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val desc = schema.get("description") + + assertThat(validator.isValid()).isTrue + assertThat(desc).isNotNull + assertThat(desc.isTextual).isTrue + assertThat(desc.asText()).isEqualTo("A simple schema.") + } + + @Test + fun schemaTest_annotatedWithJsonPropertyDescription() { + // Add a "description" to the property using an annotation. + @Suppress("unused") class X(@get:JsonPropertyDescription("A string value.") val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("s") + val desc = stringProperty.get("description") + + assertThat(validator.isValid()).isTrue + assertThat(desc).isNotNull + assertThat(desc.isTextual).isTrue + assertThat(desc.asText()).isEqualTo("A string value.") + } + + @Test + fun schemaTest_annotatedWithJsonProperty() { + // Override the default name of the property using the annotation. + @Suppress("unused") class X(@get:JsonProperty("a_string") val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("a_string") + + assertThat(validator.isValid()).isTrue + assertThat(stringProperty).isNotNull + } + + @Test + fun schemaTest_annotatedWithJsonPropertyRejectDefaultValue() { + // Set a default value for the property. It should be ignored when the schema is generated, + // as default property values are not supported in OpenAI JSON schemas. (The Victools docs + // have examples of how to add support for this default values via annotations or initial + // values, should support for default values be needed in the future.) + // + // Lack of support is not mentioned in the specification, but see the evidence at: + // https://engineering.fractional.ai/openai-structured-output-fixes + @Suppress("unused") + class X( + @get:JsonProperty(defaultValue = "default_value_1") val s: String = "default_value_2" + ) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("s") + + assertThat(validator.isValid()).isTrue + assertThat(stringProperty).isNotNull + assertThat(stringProperty.get("default")).isNull() + } + + @Test + fun schemaTest_annotatedWithJsonIgnore() { + // Override the default name of the property using the annotation. + @Suppress("unused") class X(@get:JsonIgnore val s1: String, val s2: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val s1Property = properties.get("s1") + val s2Property = properties.get("s2") + + assertThat(validator.isValid()).isTrue + assertThat(s1Property).isNull() + assertThat(s2Property).isNotNull + } + + @Test + fun schemaTest_emptyDefinitions() { + // Be lenient about empty definitions. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "type" : "string" + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_referenceMissingReferent() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "$REF" : "#/$DEFS/Person" + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/$REF: Invalid or unsupported reference: '#/$DEFS/Person'.") + } + + @Test + fun schemaTest_referenceFieldIsNotTextual() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "$REF" : 42 + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#/$REF: '$REF' field is not a text value.") + } + + @Test + fun validatorBeforeValidation() { + assertThat(validator.errors()).isEmpty() + assertThat(validator.isValid()).isFalse + } + + @Test + fun validatorReused() { + class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Should fail if an attempt is made to reuse the validator. + assertThatThrownBy { validator.validate(schema) } + .isExactlyInstanceOf(IllegalStateException::class.java) + .hasMessageContaining("Validation already complete.") + } + + @Test + @Suppress("unused") + fun schemaTest_largeLaureatesSchema() { + // This covers many cases: large and complex "$defs", resolution of references, recursive + // references, etc. The output is assumed to be good (it has been checked by eye) and the + // test just shows that the validator can handle the complexity without crashing or emitting + // spurious errors. + class Name(val givenName: String, val familyName: String) + + class Person( + @get:JsonPropertyDescription("The name of the person.") val name: Name, + @get:JsonProperty(value = "date_of_birth", defaultValue = "unknown_1") + @get:JsonPropertyDescription("The date of birth of the person.") + var dateOfBirth: String, + @get:JsonPropertyDescription("The country of citizenship of the person.") + var nationality: String, + // A child being a `Person` results in a recursive schema. + @get:JsonPropertyDescription("The children (if any) of the person.") + val children: List, + ) { + @get:JsonPropertyDescription("The other name of the person.") + var otherName: Name = Name("Bob", "Smith") + } + + class Laureate( + val laureate: Person, + val majorContribution: String, + val yearOfWinning: String, + @get:JsonIgnore val favoriteColor: String, + ) + + class Laureates( + // Two lists results in a `Laureate` definition that is referenced in the schema. + var laureates1901to1950: List, + var laureates1951to2025: List, + ) + + schema = extractSchema(Laureates::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + private fun parseJson(schemaString: String) = ObjectMapper().readTree(schemaString) +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt index e47aebd6..fb52ffc6 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt @@ -347,4 +347,36 @@ internal class ChatCompletionCreateParamsTest { ) assertThat(body.model()).isEqualTo(ChatModel.GPT_4_1) } + + @Test + fun structuredOutputsBuilder() { + class X(val s: String) + + // Only interested in a few things: + // - Does the `Builder` type change when `responseFormat(Class)` is called? + // - Are values already set on the "old" `Builder` preserved in the change-over? + // - Can new values be set on the "new" `Builder` alongside the "old" values? + val params = + ChatCompletionCreateParams.builder() + .addDeveloperMessage("dev message") + .model(ChatModel.GPT_4_1) + .responseFormat(X::class.java) // Creates and return a new builder. + .addSystemMessage("sys message") + .build() + + val body = params.rawParams._body() + + assertThat(params).isInstanceOf(StructuredChatCompletionCreateParams::class.java) + assertThat(params.responseFormat).isEqualTo(X::class.java) + assertThat(body.messages()) + .containsExactly( + ChatCompletionMessageParam.ofDeveloper( + ChatCompletionDeveloperMessageParam.builder().content("dev message").build() + ), + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content("sys message").build() + ), + ) + assertThat(body.model()).isEqualTo(ChatModel.GPT_4_1) + } } diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt new file mode 100644 index 00000000..4abd66b6 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt @@ -0,0 +1,499 @@ +package com.openai.models.chat.completions + +import com.openai.core.fromClass +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.models.ChatModel +import com.openai.models.FunctionDefinition +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_FIELD +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_VALUE +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.MESSAGE +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.OPTIONAL +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.STRING +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.X +import java.lang.reflect.Method +import kotlin.collections.plus +import kotlin.reflect.full.declaredFunctions +import kotlin.reflect.jvm.javaMethod +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.fail +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletionCreateParams] class (delegator) and its delegation of + * most functions to a wrapped [ChatCompletionCreateParams] (delegate). It is the `Builder` class of + * each main class that is involved in the delegation. The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionCreateParamsTest { + companion object { + private fun checkOneDelegationWrite( + delegator: Any, + mockDelegate: Any, + testCase: DelegationWriteTestCase, + ) { + invokeMethod(findDelegationMethod(delegator, testCase), delegator, testCase) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockDelegate, times(1)).apply { + invokeMethod(findDelegationMethod(mockDelegate, testCase), mockDelegate, testCase) + } + verifyNoMoreInteractions(mockDelegate) + } + + private fun invokeMethod(method: Method, target: Any, testCase: DelegationWriteTestCase) { + val numParams = testCase.inputValues.size + val inputValue1 = testCase.inputValues[0] + val inputValue2 = testCase.inputValues.getOrNull(1) + + when (numParams) { + 1 -> method.invoke(target, inputValue1) + 2 -> method.invoke(target, inputValue1, inputValue2) + else -> fail { "Unexpected number of function parameters ($numParams)." } + } + } + + /** + * Finds the java method matching the test case's function name and parameter types in the + * delegator or delegate `target`. + */ + private fun findDelegationMethod(target: Any, testCase: DelegationWriteTestCase): Method { + val numParams = testCase.inputValues.size + val inputValue1: Any? = testCase.inputValues[0] + val inputValue2 = if (numParams > 1) testCase.inputValues[1] else null + + val method = + when (numParams) { + 1 -> + if (inputValue1 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + ) + } else { + // Only the first parameter may be nullable and only if it is the only + // parameter. If the first parameter is nullable, it will be the only + // function of the same name with a nullable first parameter. To handle + // the potentially nullable first parameter, Kotlin reflection is + // needed. This allows a function `f(Boolean)` to be distinguished from + // `f(Boolean?)`. For the tests, if the parameter type is nullable, the + // parameter value will always be `null` (if not, the function with the + // nullable parameter would not be matched). + // + // Using Kotlin reflection, the first parameter (zero index) is `this` + // object, so start matching from the second parameter onwards. + target::class + .declaredFunctions + .find { + it.name == testCase.functionName && + it.parameters[1].type.isMarkedNullable + } + ?.javaMethod + } + 2 -> + if (inputValue1 != null && inputValue2 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + toJavaType(inputValue2.javaClass), + ) + } else { + // There are no instances where there are two parameters and one of them + // is nullable. + fail { "Function $testCase second parameter must not be null." } + } + else -> fail { "Function $testCase has unsupported number of parameters." } + } + + // Using `if` and `fail`, so the compiler knows the code will not continue and can infer + // that `delegationMethod` is not null. It cannot do this for `assertThat...isNotNull`. + if (method == null) { + fail { "Function $testCase cannot be found in $target." } + } + + return method + } + + private fun findJavaMethod( + clazz: Class<*>, + methodName: String, + vararg parameterTypes: Class<*>, + ): Method? = + clazz.declaredMethods.firstOrNull { method -> + method.name == methodName && + method.parameterTypes.size == parameterTypes.size && + method.parameterTypes.indices.all { index -> + (parameterTypes[index].isPrimitive && + method.parameterTypes[index] == parameterTypes[index]) || + method.parameterTypes[index].isAssignableFrom(parameterTypes[index]) + } + } + + /** + * Returns the Java type to use when matching type parameters for a Java method. The type is + * the type of the input value that will be used when the method is invoked. For most types, + * the given type is returned. However, if the type represents a Kotlin primitive, it will + * be converted to a Java primitive. This allows matching of methods with parameter types + * that are non-nullable Kotlin primitives. If not translated, methods with parameter types + * that are nullable Kotlin primitives would always be matched instead. + */ + private fun toJavaType(type: Class<*>) = + when (type) { + // This only needs to cover the types used in the test cases. + java.lang.Long::class.java -> java.lang.Long.TYPE + java.lang.Boolean::class.java -> java.lang.Boolean.TYPE + java.lang.Double::class.java -> java.lang.Double.TYPE + else -> type + } + + private val NULLABLE = null + private const val BOOLEAN: Boolean = true + private val NULLABLE_BOOLEAN: Boolean? = null + private const val LONG: Long = 42L + private val NULLABLE_LONG: Long? = null + private const val DOUBLE: Double = 42.0 + private val NULLABLE_DOUBLE: Double? = null + private val LIST = listOf(STRING) + private val SET = setOf(STRING) + private val MAP = mapOf(STRING to STRING) + + private val CHAT_MODEL = ChatModel.GPT_4 + + private val USER_MESSAGE_PARAM = + ChatCompletionUserMessageParam.builder().content(STRING).build() + private val DEV_MESSAGE_PARAM = + ChatCompletionDeveloperMessageParam.builder().content(STRING).build() + private val SYS_MESSAGE_PARAM = + ChatCompletionSystemMessageParam.builder().content(STRING).build() + private val ASSIST_MESSAGE_PARAM = + ChatCompletionAssistantMessageParam.builder().content(STRING).build() + private val TOOL_MESSAGE_PARAM = + ChatCompletionToolMessageParam.builder().content(STRING).toolCallId(STRING).build() + private val FUNC_MESSAGE_PARAM = + ChatCompletionFunctionMessageParam.builder().content(STRING).name(STRING).build() + private val MESSAGE_PARAM = ChatCompletionMessageParam.ofUser(USER_MESSAGE_PARAM) + + private val DEV_MESSAGE_PARAM_CONTENT = + ChatCompletionDeveloperMessageParam.Content.ofText(STRING) + private val SYS_MESSAGE_PARAM_CONTENT = + ChatCompletionSystemMessageParam.Content.ofText(STRING) + private val USER_MESSAGE_PARAM_CONTENT = + ChatCompletionUserMessageParam.Content.ofText(STRING) + + private val PARAMS_BODY = + ChatCompletionCreateParams.Body.builder() + .messages(listOf(MESSAGE_PARAM)) + .model(CHAT_MODEL) + .build() + private val WEB_SEARCH_OPTIONS = + ChatCompletionCreateParams.WebSearchOptions.builder().build() + + private val FUNCTION_CALL_MODE = + ChatCompletionCreateParams.FunctionCall.FunctionCallMode.AUTO + private val FUNCTION_CALL_OPTION = + ChatCompletionFunctionCallOption.builder().name(STRING).build() + private val FUNCTION_CALL = + ChatCompletionCreateParams.FunctionCall.ofFunctionCallOption(FUNCTION_CALL_OPTION) + + private val FUNCTION = ChatCompletionCreateParams.Function.builder().name(STRING).build() + private val METADATA = ChatCompletionCreateParams.Metadata.builder().build() + private val MODALITY = ChatCompletionCreateParams.Modality.TEXT + private val FUNCTION_DEFINITION = FunctionDefinition.builder().name(STRING).build() + private val TOOL = ChatCompletionTool.builder().function(FUNCTION_DEFINITION).build() + + private val NAMED_TOOL_CHOICE_FUNCTION = + ChatCompletionNamedToolChoice.Function.builder().name(STRING).build() + private val NAMED_TOOL_CHOICE = + ChatCompletionNamedToolChoice.builder().function(NAMED_TOOL_CHOICE_FUNCTION).build() + private val TOOL_CHOICE_OPTION_AUTO = ChatCompletionToolChoiceOption.Auto.AUTO + private val TOOL_CHOICE_OPTION = + ChatCompletionToolChoiceOption.ofAuto(TOOL_CHOICE_OPTION_AUTO) + + private val HEADERS = Headers.builder().build() + private val QUERY_PARAMS = QueryParams.builder().build() + + // Want `vararg`, so cannot use `data class`. Need a custom `toString`, anyway. + class DelegationWriteTestCase(val functionName: String, vararg val inputValues: Any?) { + /** + * Gets the string representation that identifies the test function when running JUnit. + */ + override fun toString(): String = + "$functionName(${inputValues.joinToString(", ") { + it?.javaClass?.simpleName ?: "null" + }})" + } + + // The list order follows the declaration order in `ChatCompletionCreateParams.Builder` for + // easier maintenance. + @JvmStatic + fun builderDelegationTestCases() = + listOf( + DelegationWriteTestCase("body", PARAMS_BODY), + DelegationWriteTestCase("messages", LIST), + DelegationWriteTestCase("messages", JSON_FIELD), + DelegationWriteTestCase("addMessage", MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", DEV_MESSAGE_PARAM), + DelegationWriteTestCase("addDeveloperMessage", DEV_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addDeveloperMessage", STRING), + DelegationWriteTestCase("addDeveloperMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", SYS_MESSAGE_PARAM), + DelegationWriteTestCase("addSystemMessage", SYS_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addSystemMessage", STRING), + DelegationWriteTestCase("addSystemMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", USER_MESSAGE_PARAM), + DelegationWriteTestCase("addUserMessage", USER_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addUserMessage", STRING), + DelegationWriteTestCase("addUserMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", ASSIST_MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", MESSAGE), + DelegationWriteTestCase("addMessage", TOOL_MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", FUNC_MESSAGE_PARAM), + DelegationWriteTestCase("model", CHAT_MODEL), + DelegationWriteTestCase("model", JSON_FIELD), + DelegationWriteTestCase("model", STRING), + DelegationWriteTestCase("audio", NULLABLE), + DelegationWriteTestCase("audio", OPTIONAL), + DelegationWriteTestCase("audio", JSON_FIELD), + DelegationWriteTestCase("frequencyPenalty", NULLABLE_DOUBLE), + DelegationWriteTestCase("frequencyPenalty", DOUBLE), + DelegationWriteTestCase("frequencyPenalty", OPTIONAL), + DelegationWriteTestCase("frequencyPenalty", JSON_FIELD), + DelegationWriteTestCase("functionCall", FUNCTION_CALL), + DelegationWriteTestCase("functionCall", JSON_FIELD), + DelegationWriteTestCase("functionCall", FUNCTION_CALL_MODE), + DelegationWriteTestCase("functionCall", FUNCTION_CALL_OPTION), + DelegationWriteTestCase("functions", LIST), + DelegationWriteTestCase("functions", JSON_FIELD), + DelegationWriteTestCase("addFunction", FUNCTION), + DelegationWriteTestCase("logitBias", NULLABLE), + DelegationWriteTestCase("logitBias", OPTIONAL), + DelegationWriteTestCase("logitBias", JSON_FIELD), + DelegationWriteTestCase("logprobs", NULLABLE_BOOLEAN), + DelegationWriteTestCase("logprobs", BOOLEAN), + DelegationWriteTestCase("logprobs", OPTIONAL), + DelegationWriteTestCase("logprobs", JSON_FIELD), + DelegationWriteTestCase("maxCompletionTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxCompletionTokens", LONG), + DelegationWriteTestCase("maxCompletionTokens", OPTIONAL), + DelegationWriteTestCase("maxCompletionTokens", JSON_FIELD), + DelegationWriteTestCase("maxTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxTokens", LONG), + DelegationWriteTestCase("maxTokens", OPTIONAL), + DelegationWriteTestCase("maxTokens", JSON_FIELD), + DelegationWriteTestCase("metadata", METADATA), + DelegationWriteTestCase("metadata", OPTIONAL), + DelegationWriteTestCase("metadata", JSON_FIELD), + DelegationWriteTestCase("modalities", LIST), + DelegationWriteTestCase("modalities", OPTIONAL), + DelegationWriteTestCase("modalities", JSON_FIELD), + DelegationWriteTestCase("addModality", MODALITY), + DelegationWriteTestCase("n", NULLABLE_LONG), + DelegationWriteTestCase("n", LONG), + DelegationWriteTestCase("n", OPTIONAL), + DelegationWriteTestCase("n", JSON_FIELD), + DelegationWriteTestCase("parallelToolCalls", BOOLEAN), + DelegationWriteTestCase("parallelToolCalls", JSON_FIELD), + DelegationWriteTestCase("prediction", NULLABLE), + DelegationWriteTestCase("prediction", OPTIONAL), + DelegationWriteTestCase("prediction", JSON_FIELD), + DelegationWriteTestCase("presencePenalty", NULLABLE_DOUBLE), + DelegationWriteTestCase("presencePenalty", DOUBLE), + DelegationWriteTestCase("presencePenalty", OPTIONAL), + DelegationWriteTestCase("presencePenalty", JSON_FIELD), + DelegationWriteTestCase("reasoningEffort", NULLABLE), + DelegationWriteTestCase("reasoningEffort", OPTIONAL), + DelegationWriteTestCase("reasoningEffort", JSON_FIELD), + // `responseFormat()` is a special case and has its own unit test. + DelegationWriteTestCase("seed", NULLABLE_LONG), + DelegationWriteTestCase("seed", LONG), + DelegationWriteTestCase("seed", OPTIONAL), + DelegationWriteTestCase("seed", JSON_FIELD), + DelegationWriteTestCase("serviceTier", NULLABLE), + DelegationWriteTestCase("serviceTier", OPTIONAL), + DelegationWriteTestCase("serviceTier", JSON_FIELD), + DelegationWriteTestCase("stop", NULLABLE), + DelegationWriteTestCase("stop", OPTIONAL), + DelegationWriteTestCase("stop", JSON_FIELD), + DelegationWriteTestCase("stop", STRING), + DelegationWriteTestCase("stopOfStrings", LIST), + DelegationWriteTestCase("store", NULLABLE_BOOLEAN), + DelegationWriteTestCase("store", BOOLEAN), + DelegationWriteTestCase("store", OPTIONAL), + DelegationWriteTestCase("store", JSON_FIELD), + DelegationWriteTestCase("streamOptions", NULLABLE), + DelegationWriteTestCase("streamOptions", OPTIONAL), + DelegationWriteTestCase("streamOptions", JSON_FIELD), + DelegationWriteTestCase("temperature", NULLABLE_DOUBLE), + DelegationWriteTestCase("temperature", DOUBLE), + DelegationWriteTestCase("temperature", OPTIONAL), + DelegationWriteTestCase("temperature", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTION), + DelegationWriteTestCase("toolChoice", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTION_AUTO), + DelegationWriteTestCase("toolChoice", NAMED_TOOL_CHOICE), + DelegationWriteTestCase("tools", LIST), + DelegationWriteTestCase("tools", JSON_FIELD), + DelegationWriteTestCase("addTool", TOOL), + DelegationWriteTestCase("topLogprobs", NULLABLE_LONG), + DelegationWriteTestCase("topLogprobs", LONG), + DelegationWriteTestCase("topLogprobs", OPTIONAL), + DelegationWriteTestCase("topLogprobs", JSON_FIELD), + DelegationWriteTestCase("topP", NULLABLE_DOUBLE), + DelegationWriteTestCase("topP", DOUBLE), + DelegationWriteTestCase("topP", OPTIONAL), + DelegationWriteTestCase("topP", JSON_FIELD), + DelegationWriteTestCase("user", STRING), + DelegationWriteTestCase("user", JSON_FIELD), + DelegationWriteTestCase("webSearchOptions", WEB_SEARCH_OPTIONS), + DelegationWriteTestCase("webSearchOptions", JSON_FIELD), + DelegationWriteTestCase("additionalBodyProperties", MAP), + DelegationWriteTestCase("putAdditionalBodyProperty", STRING, JSON_VALUE), + DelegationWriteTestCase("putAllAdditionalBodyProperties", MAP), + DelegationWriteTestCase("removeAdditionalBodyProperty", STRING), + DelegationWriteTestCase("removeAllAdditionalBodyProperties", SET), + DelegationWriteTestCase("additionalHeaders", HEADERS), + DelegationWriteTestCase("additionalHeaders", MAP), + DelegationWriteTestCase("putAdditionalHeader", STRING, STRING), + DelegationWriteTestCase("putAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("putAllAdditionalHeaders", MAP), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("replaceAllAdditionalHeaders", MAP), + DelegationWriteTestCase("removeAdditionalHeaders", STRING), + DelegationWriteTestCase("removeAllAdditionalHeaders", SET), + DelegationWriteTestCase("additionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("additionalQueryParams", MAP), + DelegationWriteTestCase("putAdditionalQueryParam", STRING, STRING), + DelegationWriteTestCase("putAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("putAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("removeAdditionalQueryParams", STRING), + DelegationWriteTestCase("removeAllAdditionalQueryParams", SET), + ) + } + + // New instances of the `mockBuilderDelegate` and `builderDelegator` are required for each test + // case (each test case runs in its own instance of the test class). + val mockBuilderDelegate: ChatCompletionCreateParams.Builder = + mock(ChatCompletionCreateParams.Builder::class.java) + val builderDelegator = + StructuredChatCompletionCreateParams.builder().inject(mockBuilderDelegate) + + @Test + fun allBuilderDelegateFunctionsExistInDelegator() { + // The delegator class does not implement the various `responseFormat` functions of the + // delegate class. + StructuredChatCompletionTest.checkAllDelegation( + ChatCompletionCreateParams.Builder::class, + StructuredChatCompletionCreateParams.Builder::class, + "responseFormat", + ) + } + + @Test + fun allBuilderDelegatorFunctionsExistInDelegate() { + // The delegator implements a different `responseFormat` function from those overloads in + // the delegate class. + StructuredChatCompletionTest.checkAllDelegation( + StructuredChatCompletionCreateParams.Builder::class, + ChatCompletionCreateParams.Builder::class, + "responseFormat", + ) + } + + @Test + fun allBuilderDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. There are many overloaded functions, so the + // approach here is to build a list (_not_ a set) of all function names and then "subtract" + // those for which tests are defined and see what remains. For example, there are (at this + // time) eight `addMessage` functions, so there must be eight tests defined for functions + // named `addMessage` that will be subtracted from the list of functions matching that name. + // Parameter types are not checked, as that is awkward and probably overkill. Therefore, + // this scheme is not reliable if a function is tested more than once. + val exceptionalTestedFns = listOf("responseFormat") + val testedFns = + (builderDelegationTestCases().map { it.functionName } + exceptionalTestedFns) + .toMutableList() + val nonDelegatingFns = listOf("build", "wrap", "inject") + + val delegatorFns = + StructuredChatCompletionCreateParams.Builder::class.declaredFunctions.toMutableList() + + // Making concurrent modifications to the list, so using an `Iterator`. + val i = delegatorFns.iterator() + + while (i.hasNext()) { + val functionName = i.next().name + + if (functionName in testedFns) { + testedFns.remove(functionName) + i.remove() + } + if (functionName in nonDelegatingFns) { + i.remove() + } + } + + // If there are function names remaining in `delegatorFns`, then there are tests missing. + // Only report the names of the functions not tested: parameters are not matched, so any + // signatures could be misleading. + assertThat(delegatorFns) + .describedAs { + "Delegation is not tested for functions ${delegatorFns.map { it.name }}." + } + .isEmpty() + + // If there are function names remaining in `testedFns`, then there are more tests than + // there should be. Functions might be tested twice, or there may be tests for functions + // that have since been removed from the delegate (though those tests probably failed). + assertThat(testedFns) + .describedAs { "Unexpected or redundant tests for functions $testedFns." } + .isEmpty() + } + + @ParameterizedTest + @MethodSource("builderDelegationTestCases") + fun `delegation of Builder write functions`(testCase: DelegationWriteTestCase) { + checkOneDelegationWrite(builderDelegator, mockBuilderDelegate, testCase) + } + + @Test + fun `delegation of responseFormat`() { + // Special unit test case as the delegator method signature does not match that of the + // delegate method. + val delegatorTestCase = DelegationWriteTestCase("responseFormat", X::class.java) + val delegatorMethod = findDelegationMethod(builderDelegator, delegatorTestCase) + val mockDelegateTestCase = + DelegationWriteTestCase("responseFormat", fromClass(X::class.java)) + val mockDelegateMethod = findDelegationMethod(mockBuilderDelegate, mockDelegateTestCase) + + delegatorMethod.invoke(builderDelegator, delegatorTestCase.inputValues[0]) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockBuilderDelegate, times(1)).apply { + mockDelegateMethod.invoke(mockBuilderDelegate, mockDelegateTestCase.inputValues[0]) + } + verifyNoMoreInteractions(mockBuilderDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt new file mode 100644 index 00000000..347788a3 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt @@ -0,0 +1,141 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.DelegationReadTestCase +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_FIELD +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_VALUE +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.MESSAGE +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.OPTIONAL +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.X +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.checkOneDelegationRead +import java.util.Optional +import kotlin.reflect.full.declaredFunctions +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletionMessage] class (delegator) and its delegation of most + * functions to a wrapped [ChatCompletionMessage] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionMessageTest { + companion object { + // The list order follows the declaration order in `StructuredChatCompletionMessage` for + // easier maintenance. See `StructuredChatCompletionTest` for details on the values used. + @JvmStatic + fun delegationTestCases() = + listOf( + // `content()` is a special case and has its own test function. + DelegationReadTestCase("refusal", OPTIONAL), + DelegationReadTestCase("_role", JSON_VALUE), + DelegationReadTestCase("annotations", OPTIONAL), + DelegationReadTestCase("audio", OPTIONAL), + DelegationReadTestCase("functionCall", OPTIONAL), + DelegationReadTestCase("toolCalls", OPTIONAL), + // `_content()` is a special case and has its own test function. + DelegationReadTestCase("_refusal", JSON_FIELD), + DelegationReadTestCase("_annotations", JSON_FIELD), + DelegationReadTestCase("_audio", JSON_FIELD), + DelegationReadTestCase("_functionCall", JSON_FIELD), + DelegationReadTestCase("_toolCalls", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + DelegationReadTestCase("validate", MESSAGE), + // For this boolean function, call with both possible values to ensure that any + // hard-coding or default value will not result in a false positive test. + DelegationReadTestCase("isValid", true), + DelegationReadTestCase("isValid", false), + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + val mockDelegate: ChatCompletionMessage = mock(ChatCompletionMessage::class.java) + val delegator = StructuredChatCompletionMessage(X::class.java, mockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + StructuredChatCompletionTest.checkAllDelegation( + ChatCompletionMessage::class, + StructuredChatCompletionMessage::class, + "toBuilder", + "toParam", + ) + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + StructuredChatCompletionTest.checkAllDelegation( + StructuredChatCompletionMessage::class, + ChatCompletionMessage::class, + ) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. + val exceptionalTestedFns = setOf("content", "_content") + val testedFns = delegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns + // A few delegator functions do not delegate, so no test function is necessary. + val nonDelegatingFns = listOf("equals", "hashCode", "toString") + + val delegatorFunctions = StructuredChatCompletionMessage::class.declaredFunctions + + for (delegatorFunction in delegatorFunctions) { + assertThat( + delegatorFunction.name in testedFns || + delegatorFunction.name in nonDelegatingFns + ) + .describedAs("Delegation is not tested for function '${delegatorFunction.name}.") + .isTrue + } + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @Test + fun `delegation of content`() { + // Input and output are different types, so this test is an exceptional case. + // `content()` (without an underscore) delegates to `_content()` (with an underscore) + // indirectly via the `content` field initializer. + val input = JsonField.of("{\"s\" : \"hello\"}") + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator.content() // Without an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isEqualTo(Optional.of(X("hello"))) + } + + @Test + fun `delegation of _content`() { + // Input and output are different types, so this test is an exceptional case. + // `_content()` delegates to `_content()` indirectly via the `content` field initializer. + val input = JsonField.of("{\"s\" : \"hello\"}") + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator._content() // With an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isEqualTo(JsonField.of(X("hello"))) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt new file mode 100644 index 00000000..af380bbf --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt @@ -0,0 +1,405 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import java.util.Optional +import kotlin.reflect.KClass +import kotlin.reflect.KVisibility +import kotlin.reflect.full.declaredFunctions +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletion] class (delegator) and its delegation of most + * functions to a wrapped [ChatCompletion] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionTest { + companion object { + internal fun checkAllDelegation( + delegateClass: KClass<*>, + delegatorClass: KClass<*>, + vararg exceptFunctionNames: String, + ) { + assertThat(delegateClass != delegatorClass) + .describedAs { "Delegate and delegator classes should not be the same." } + .isTrue + + val delegateFunctions = delegateClass.declaredFunctions + + for (delegateFunction in delegateFunctions) { + if (delegateFunction.visibility != KVisibility.PUBLIC) { + // Non-public methods are just implementation details of each class. + continue + } + + if (delegateFunction.name in exceptFunctionNames) { + // Ignore functions that are known exceptions (e.g., `toBuilder`). + continue + } + + // Drop the first parameter from each function, as it is the implicit "this" object + // and has the type of the class declaring the function, which will never match. + val delegatorFunction = + delegatorClass.declaredFunctions.find { + it.name == delegateFunction.name && + it.parameters.drop(1).map { it.type } == + delegateFunction.parameters.drop(1).map { it.type } + } + + assertThat(delegatorFunction != null) + .describedAs { + "Function $delegateFunction is not found in ${delegatorClass.simpleName}." + } + .isTrue + } + } + + internal fun checkOneDelegationRead( + delegator: Any, + mockDelegate: Any, + testCase: DelegationReadTestCase, + ) { + // Stub the method in the mock delegate using reflection + val delegateMethod = mockDelegate::class.java.getMethod(testCase.functionName) + `when`(delegateMethod.invoke(mockDelegate)).thenReturn(testCase.expectedValue) + + // Call the corresponding method on the delegator using reflection + val delegatorMethod = delegator::class.java.getMethod(testCase.functionName) + val result = delegatorMethod.invoke(delegator) + + // Verify that the corresponding method on the mock delegate was called exactly once + verify(mockDelegate, times(1)).apply { delegateMethod.invoke(mockDelegate) } + verifyNoMoreInteractions(mockDelegate) + + // Assert that the result matches the expected value + assertThat(result).isEqualTo(testCase.expectedValue) + } + + // Where a function returns `Optional`, `JsonField` or `JsonValue` There is no need to + // provide a value that matches the type ``, a simple `String` value of `"a-string"` will + // work OK with the test. Constants have been provided for this purpose. + internal const val STRING = "a-string" + + internal val OPTIONAL = Optional.of(STRING) + internal val JSON_FIELD = JsonField.of(STRING) + internal val JSON_VALUE = JsonValue.from(STRING) + internal val MESSAGE = + ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() + private val FINISH_REASON = ChatCompletion.Choice.FinishReason.STOP + private val CHOICE = + ChatCompletion.Choice.builder() + .message(MESSAGE) + .index(0L) + .finishReason(FINISH_REASON) + .logprobs( + ChatCompletion.Choice.Logprobs.builder().content(null).refusal(null).build() + ) + .build() + + data class DelegationReadTestCase(val functionName: String, val expectedValue: Any) + + // The list order follows the declaration order in `StructuredChatCompletionMessage` for + // easier maintenance. + @JvmStatic + fun delegationTestCases() = + listOf( + DelegationReadTestCase("id", STRING), + // `choices()` is a special case and has its own test function. + DelegationReadTestCase("created", 123L), + DelegationReadTestCase("model", STRING), + DelegationReadTestCase("_object_", JSON_VALUE), + DelegationReadTestCase("serviceTier", OPTIONAL), + DelegationReadTestCase("systemFingerprint", OPTIONAL), + DelegationReadTestCase("usage", OPTIONAL), + DelegationReadTestCase("_id", JSON_FIELD), + // `_choices()` is a special case and has its own test function. + DelegationReadTestCase("_created", JSON_FIELD), + DelegationReadTestCase("_model", JSON_FIELD), + DelegationReadTestCase("_serviceTier", JSON_FIELD), + DelegationReadTestCase("_systemFingerprint", JSON_FIELD), + DelegationReadTestCase("_usage", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + + @JvmStatic + fun choiceDelegationTestCases() = + listOf( + DelegationReadTestCase("finishReason", FINISH_REASON), + DelegationReadTestCase("index", 123L), + DelegationReadTestCase("logprobs", OPTIONAL), + DelegationReadTestCase("_finishReason", JSON_FIELD), + // `message()` is a special case and has its own test function. + DelegationReadTestCase("_index", JSON_FIELD), + DelegationReadTestCase("_logprobs", JSON_FIELD), + // `_message()` is a special case and has its own test function. + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + + /** A basic class used as the generic type when testing. */ + internal class X(val s: String) { + override fun equals(other: Any?) = other is X && other.s == s + + override fun hashCode() = s.hashCode() + } + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + val mockDelegate: ChatCompletion = mock(ChatCompletion::class.java) + val delegator = StructuredChatCompletion(X::class.java, mockDelegate) + + val mockChoiceDelegate: ChatCompletion.Choice = mock(ChatCompletion.Choice::class.java) + val choiceDelegator = StructuredChatCompletion.Choice(X::class.java, mockChoiceDelegate) + + @Test + fun allChatCompletionDelegateFunctionsExistInDelegator() { + checkAllDelegation(ChatCompletion::class, StructuredChatCompletion::class, "toBuilder") + } + + @Test + fun allChatCompletionDelegatorFunctionsExistInDelegate() { + checkAllDelegation(StructuredChatCompletion::class, ChatCompletion::class) + } + + @Test + fun allChoiceDelegateFunctionsExistInDelegator() { + checkAllDelegation( + ChatCompletion.Choice::class, + StructuredChatCompletion.Choice::class, + "toBuilder", + ) + } + + @Test + fun allChoiceDelegatorFunctionsExistInDelegate() { + checkAllDelegation(StructuredChatCompletion.Choice::class, ChatCompletion.Choice::class) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. + val exceptionalTestedFns = setOf("choices", "_choices", "validate", "isValid") + val testedFns = delegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns + // A few delegator functions do not delegate, so no test function is necessary. + val nonDelegatingFns = listOf("equals", "hashCode", "toString") + + val delegatorFunctions = StructuredChatCompletion::class.declaredFunctions + + for (delegatorFunction in delegatorFunctions) { + assertThat( + delegatorFunction.name in testedFns || + delegatorFunction.name in nonDelegatingFns + ) + .describedAs("Delegation is not tested for function '${delegatorFunction.name}.") + .isTrue + } + } + + @Test + fun allChoiceDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. + val exceptionalTestedFns = setOf("message", "_message", "validate", "isValid") + val testedFns = + choiceDelegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns + // A few delegator functions do not delegate, so no test function is necessary. + val nonDelegatingFns = listOf("equals", "hashCode", "toString") + + val delegatorFunctions = StructuredChatCompletion.Choice::class.declaredFunctions + + for (delegatorFunction in delegatorFunctions) { + assertThat( + delegatorFunction.name in testedFns || + delegatorFunction.name in nonDelegatingFns + ) + .describedAs( + "Delegation is not tested for function 'Choice.${delegatorFunction.name}." + ) + .isTrue + } + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @ParameterizedTest + @MethodSource("choiceDelegationTestCases") + fun `delegation of Choice functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(choiceDelegator, mockChoiceDelegate, testCase) + } + + @Test + fun `delegation of choices`() { + // Input and output are different types, so this test is an exceptional case. + // `choices()` (without an underscore) delegates to `_choices()` (with an underscore) + // indirectly via the `choices` field initializer. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.choices() // Without an underscore. + + verify(mockDelegate, times(1))._choices() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output[0].choice).isEqualTo(CHOICE) + } + + @Test + fun `delegation of _choices`() { + // Input and output are different types, so this test is an exceptional case. + // `_choices()` delegates to `_choices()` indirectly via the `choices` field initializer. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator._choices() // With an underscore. + + verify(mockDelegate, times(1))._choices() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.getRequired("_choices")[0].choice).isEqualTo(CHOICE) + } + + @Test + fun `delegation of validate`() { + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.validate() + + // `validate()` calls `choices()` on the delegator which triggers the lazy initializer which + // calls `_choices()` on the delegate before `validate()` also calls `validate()` on the + // delegate. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isSameAs(delegator) + } + + @Test + fun `delegation of isValid when true`() { + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of isValid when false`() { + // Try with a `false` value to make sure `isValid()` is not just hard-coded to `true`. Do + // this by making `validate()` on the delegate throw an exception. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + `when`(mockDelegate.validate()).thenThrow(OpenAIInvalidDataException("test")) + val output = delegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isFalse + } + + @Test + fun `delegation of Choice-message`() { + // Input and output are different types, so this test is an exceptional case. + // `message()` (without an underscore) delegates to `_message()` (with an underscore) + // indirectly via the `message` field initializer. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.message() // Without an underscore. + + verify(mockChoiceDelegate, times(1))._message() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output.chatCompletionMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of Choice-_message`() { + // Input and output are different types, so this test is an exceptional case. + // `_message()` delegates to `_message()` indirectly via the `message` field initializer. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator._message() // With an underscore. + + verify(mockChoiceDelegate, times(1))._message() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output.getRequired("_message").chatCompletionMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of Choice-validate`() { + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.validate() + + // `validate()` calls `message()` on the delegator which triggers the lazy initializer which + // calls `_message()` on the delegate before `validate()` also calls `validate()` on the + // delegate. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isSameAs(choiceDelegator) + } + + @Test + fun `delegation of Choice-isValid when true`() { + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of Choice-isValid when false`() { + // Try with a `false` value to make sure `isValid()` is not just hard-coded to `true`. Do + // this by making `validate()` on the delegate throw an exception. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + `when`(mockChoiceDelegate.validate()).thenThrow(OpenAIInvalidDataException("test")) + val output = choiceDelegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isFalse + } +} diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java new file mode 100644 index 00000000..bcc46a80 --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java @@ -0,0 +1,73 @@ +package com.openai.example; + +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; +import java.util.List; + +public final class StructuredOutputsClassExample { + + public static class Person { + public String firstName; + public String surname; + + @JsonPropertyDescription("The date of birth of the person.") + public String dateOfBirth; + + @Override + public String toString() { + return "Person{firstName=" + firstName + ", surname=" + surname + ", dateOfBirth=" + dateOfBirth + '}'; + } + } + + public static class Laureate { + public Person person; + public String majorAchievement; + public int yearWon; + + @JsonPropertyDescription("The share of the prize money won by the Nobel Laureate.") + public double prizeMoney; + + @Override + public String toString() { + return "Laureate{person=" + + person + ", majorAchievement=" + + majorAchievement + ", yearWon=" + + yearWon + ", prizeMoney=" + + prizeMoney + '}'; + } + } + + public static class Laureates { + @JsonPropertyDescription("A list of winners of a Nobel Prize.") + public List laureates; + + @Override + public String toString() { + return "Laureates{laureates=" + laureates + '}'; + } + } + + private StructuredOutputsClassExample() {} + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + + StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .maxCompletionTokens(2048) + .responseFormat(Laureates.class) + .addUserMessage("List some winners of the Nobel Prize in Physics since 2000.") + .build(); + + client.chat().completions().create(createParams).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .forEach(System.out::println); + } +} From bc3c47926e7c598d48e1eb85fcfacf85ccfb4c9a Mon Sep 17 00:00:00 2001 From: D Gardner Date: Fri, 2 May 2025 16:40:52 +0100 Subject: [PATCH 2/9] structured-outputs: repair after bad merge. --- .../com/openai/core/JsonSchemaValidator.kt | 670 ++++++++++++++++++ .../openai/core/JsonSchemaValidatorTest.kt | 3 +- 2 files changed, 672 insertions(+), 1 deletion(-) create mode 100644 openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt new file mode 100644 index 00000000..6af40929 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt @@ -0,0 +1,670 @@ +package com.openai.core + +import com.fasterxml.jackson.databind.JsonNode +import com.openai.core.JsonSchemaValidator.Companion.MAX_ENUM_TOTAL_STRING_LENGTH +import com.openai.core.JsonSchemaValidator.Companion.UNRESTRICTED_ENUM_VALUES_LIMIT + +/** + * A validator that ensures that a JSON schema complies with the rules and restrictions imposed by + * the OpenAI API specification for the input schemas used to define structured outputs. Only a + * subset of the JSON Schema language is supported. The purpose of this validator is to perform a + * quick check of a schema so that it can be determined to be likely to be accepted when passed in + * the request for an AI inference. + * + * This validator assumes that the JSON schema represents the structure of Java/Kotlin classes; it + * is not a general-purpose JSON schema validator. Assumptions are also made that the generator will + * be well-behaved, so the validation is not a check for strict conformance to the JSON Schema + * specification, but to the OpenAI API specification's restrictions on JSON schemas. + */ +internal class JsonSchemaValidator private constructor() { + + companion object { + // The names of the supported schema keywords. All other keywords will be rejected. + private const val SCHEMA = "\$schema" + private const val ID = "\$id" + private const val DEFS = "\$defs" + private const val REF = "\$ref" + private const val PROPS = "properties" + private const val ANY_OF = "anyOf" + private const val TYPE = "type" + private const val REQUIRED = "required" + private const val DESC = "description" + private const val TITLE = "title" + private const val ITEMS = "items" + private const val CONST = "const" + private const val ENUM = "enum" + private const val ADDITIONAL_PROPS = "additionalProperties" + + // The names of the supported schema data types. + // + // JSON Schema does not define an "integer" type, only a "number" type, but it allows any + // schema to define its own "vocabulary" of type names. "integer" is supported by OpenAI. + private const val TYPE_ARRAY = "array" + private const val TYPE_OBJECT = "object" + private const val TYPE_BOOLEAN = "boolean" + private const val TYPE_STRING = "string" + private const val TYPE_NUMBER = "number" + private const val TYPE_INTEGER = "integer" + private const val TYPE_NULL = "null" + + // The validator checks that unsupported type-specific keywords are not present in a + // property node. The OpenAI API specification states: + // + // "Notable keywords not supported include: + // + // - For strings: `minLength`, `maxLength`, `pattern`, `format` + // - For numbers: `minimum`, `maximum`, `multipleOf` + // - For objects: `patternProperties`, `unevaluatedProperties`, `propertyNames`, + // `minProperties`, `maxProperties` + // - For arrays: `unevaluatedItems`, `contains`, `minContains`, `maxContains`, `minItems`, + // `maxItems`, `uniqueItems`" + // + // As that list is not exhaustive, and no keywords are explicitly named as supported, this + // validation allows _no_ type-specific keywords. The following sets define the allowed + // keywords in different contexts and all others are rejected. + + /** + * The set of allowed keywords in the root schema only, not including the keywords that are + * also allowed in a sub-schema. + */ + private val ALLOWED_KEYWORDS_ROOT_SCHEMA_ONLY = setOf(SCHEMA, ID, DEFS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"anyOf"` field is + * present. OpenAI allows the `"anyOf"` field in sub-schemas, but not in the root schema. + */ + private val ALLOWED_KEYWORDS_ANY_OF_SUB_SCHEMA = setOf(ANY_OF, TITLE, DESC) + + /** + * The set of allowed keywords when defining sub-schemas when the `"$ref"` field is present. + */ + private val ALLOWED_KEYWORDS_REF_SUB_SCHEMA = setOf(REF, TITLE, DESC) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"object"`. + */ + private val ALLOWED_KEYWORDS_OBJECT_SUB_SCHEMA = + setOf(TYPE, REQUIRED, ADDITIONAL_PROPS, TITLE, DESC, PROPS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"array"`. + */ + private val ALLOWED_KEYWORDS_ARRAY_SUB_SCHEMA = setOf(TYPE, TITLE, DESC, ITEMS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"boolean"`, `"integer"`, `"number"`, or `"string"`. + */ + private val ALLOWED_KEYWORDS_SIMPLE_SUB_SCHEMA = setOf(TYPE, TITLE, DESC, ENUM, CONST) + + /** + * The maximum total length of all strings used in the schema for property names, definition + * names, enum values and const values. The OpenAI specification states: + * > In a schema, total string length of all property names, definition names, enum values, + * > and const values cannot exceed 15,000 characters. + */ + private const val MAX_TOTAL_STRING_LENGTH = 15_000 + + /** The maximum number of object properties allowed in a schema. */ + private const val MAX_OBJECT_PROPERTIES = 100 + + /** The maximum number of enum values across all enums in the schema. */ + private const val MAX_ENUM_VALUES = 500 + + /** + * The number of enum values in any one enum with string values beyond which a limit of + * [MAX_ENUM_TOTAL_STRING_LENGTH] is imposed on the total length of all the string values of + * that one enum. + */ + private const val UNRESTRICTED_ENUM_VALUES_LIMIT = 250 + + /** + * The maximum total length of all string values of a single enum where the number of values + * exceeds [UNRESTRICTED_ENUM_VALUES_LIMIT]. + */ + private const val MAX_ENUM_TOTAL_STRING_LENGTH = 7_500 + + /** The maximum depth (number of levels) of nesting allowed in a schema. */ + private const val MAX_NESTING_DEPTH = 5 + + /** The depth value that corresponds to the root level of the schema. */ + private const val ROOT_DEPTH = 0 + + /** + * The path string that identifies the root node in the schema when appearing in error + * messages or references. + */ + private const val ROOT_PATH = "#" + + /** + * Creates a new [JsonSchemaValidator]. After calling [validate], the validator instance + * holds information about the errors that occurred during validation (if any). A validator + * instance can be used only once to validate a schema; to validate another schema, create + * another validator. + */ + fun create() = JsonSchemaValidator() + } + + /** + * The total length of all strings used in the schema for property names, definition names, enum + * values and const values. + */ + private var totalStringLength: Int = 0 + + /** The total number of values across all enums in the schema. */ + private var totalEnumValues: Int = 0 + + /** The total number of object properties found in the schema, including in definitions. */ + private var totalObjectProperties: Int = 0 + + /** + * The set of valid references that may appear in the schema. This set includes the root schema + * and any definitions within the root schema. This is used to verify that references elsewhere + * in the schema are valid. This will always contain the root schema, but that may be the only + * member. + */ + private var validReferences: MutableSet = mutableSetOf(ROOT_PATH) + + /** The list of error messages accumulated during the validation process. */ + private val errors: MutableList = mutableListOf() + + /** + * Indicates if this validator has validated a schema or not. If a schema has been validated, + * this validator cannot be used again. + */ + private var isValidationComplete = false + + /** + * Gets the list of errors that were recorded during the validation pass. + * + * @return The list of errors. The list may be empty if no errors were recorded. In that case, + * the schema was found to be valid, or has not yet been validated by calling [validate]. + */ + fun errors(): List = errors.toImmutable() + + /** + * Indicates if a validated schema is valid or not. + * + * @return `true` if a schema has been validated by calling [validate] and no errors were + * reported; or `false` if errors were reported or if a schema has not yet been validated. + */ + fun isValid(): Boolean = isValidationComplete && errors.isEmpty() + + /** + * Validates a schema with respect to the OpenAI API specifications. + * + * @param rootSchema The root node of the tree representing the JSON schema definition. + * @return This schema validator for convenience, such as when chaining calls. + * @throws IllegalStateException If called a second time. Create a new validator to validate + * each new schema. + */ + fun validate(rootSchema: JsonNode): JsonSchemaValidator { + if (isValidationComplete) { + throw IllegalStateException("Validation already complete.") + } + isValidationComplete = true + + validateSchema(rootSchema, ROOT_PATH, ROOT_DEPTH) + + // Verify total counts/lengths. These are not localized to a specific element in the schema, + // as no one element is the cause of the error; it is the combination of all elements that + // exceed the limits. Therefore, the root path is used in the error messages. + verify(totalEnumValues <= MAX_ENUM_VALUES, ROOT_PATH) { + "Total number of enum values ($totalEnumValues) exceeds limit of $MAX_ENUM_VALUES." + } + verify(totalStringLength <= MAX_TOTAL_STRING_LENGTH, ROOT_PATH) { + "Total string length of all values ($totalStringLength) exceeds " + + "limit of $MAX_TOTAL_STRING_LENGTH." + } + verify(totalObjectProperties <= MAX_OBJECT_PROPERTIES, ROOT_PATH) { + "Total number of object properties ($totalObjectProperties) exceeds " + + "limit of $MAX_OBJECT_PROPERTIES." + } + + return this + } + + /** + * Validates a schema. This may be the root schema or a sub-schema. Some validations are + * specific to the root schema, which is identified by the [depth] being equal to zero. + * + * This method is recursive: it will validate the given schema and any sub-schemas that it + * contains at any depth. References to other schemas (either the root schema or definition + * sub-schemas) do not increase the depth of nesting, as those references are not followed + * recursively, only checked to be valid internal schema references. + * + * @param schema The schema to be validated. This may be the root schema or any sub-schema. + * @param path The path that identifies the location of this schema within the JSON schema. For + * example, for the root schema, this will be `"#"`; for a definition sub-schema of a `Person` + * object, this will be `"#/$defs/Person"`. + * @param depth The current depth of nesting. The OpenAI API specification places a maximum + * limit on the depth of nesting, which will result in an error if it is exceeded. The nesting + * depth increases with each recursion into a nested sub-schema. For the root schema, the + * nesting depth is zero; all other sub-schemas will have a nesting depth greater than zero. + */ + private fun validateSchema(schema: JsonNode, path: String, depth: Int) { + verify(depth <= MAX_NESTING_DEPTH, path) { + "Current nesting depth is $depth, but maximum is $MAX_NESTING_DEPTH." + } + + verify(schema.isObject, path, { "Schema or sub-schema is not an object." }) { + // If the schema is not an object, perform no further validations. + return + } + + verify(!schema.isEmpty, path) { "Schema or sub-schema is empty." } + + if (depth == ROOT_DEPTH) { + // Sanity check for the presence of the "$schema" field, as this makes it more likely + // that the schema with `depth == 0` is actually the root node of a JSON schema, not + // just a generic JSON document that is being validated in error. + verify(schema.get(SCHEMA) != null, path) { "Root schema missing '$SCHEMA' field." } + } + + // Before sub-schemas can be validated, the list of definitions must be recorded to ensure + // that "$ref" references can be checked for validity. Definitions are optional and only + // appear in the root schema. + validateDefinitions(schema.get(DEFS), "$path/$DEFS", depth) + + val anyOf = schema.get(ANY_OF) + val type = schema.get(TYPE) + val ref = schema.get(REF) + + verify( + (anyOf != null).xor(type != null).xor(ref != null), + path, + { "Expected exactly one of '$TYPE' or '$ANY_OF' or '$REF'." }, + ) { + // Validation cannot continue if none are set, or if more than one is set. + return + } + + validateAnyOfSchema(schema, path, depth) + validateTypeSchema(schema, path, depth) + validateRefSchema(schema, path, depth) + } + + /** + * Validates a schema if it has an `"anyOf"` field. OpenAI does not support the use of `"anyOf"` + * at the root of a JSON schema. The value is the field is expected to be an array of valid + * sub-schemas. If the schema has no `"anyOf"` field, no action is taken. + */ + private fun validateAnyOfSchema(schema: JsonNode, path: String, depth: Int) { + val anyOf = schema.get(ANY_OF) + + if (anyOf == null) return + + validateKeywords(schema, ALLOWED_KEYWORDS_ANY_OF_SUB_SCHEMA, path, depth) + + verify( + anyOf.isArray && !anyOf.isEmpty, + path, + { "'$ANY_OF' field is not a non-empty array." }, + ) { + return + } + + // Validates that the root schema does not contain an `anyOf` field. This is a restriction + // imposed by the OpenAI API specification. `anyOf` fields _can_ appear at other depths. + verify(depth != ROOT_DEPTH, path) { "Root schema contains '$ANY_OF' field." } + + // Each entry must be a valid sub-schema. + anyOf.forEachIndexed { index, subSchema -> + validateSchema(subSchema, "$path/$ANY_OF[$index]", depth + 1) + } + } + + /** + * Validates a schema if it has a `"$ref"` field. The reference is checked to ensure it + * corresponds to a valid definition, or is a reference to the root schema. Recursive references + * are allowed. If no `"$ref"` field is found in the schema, no action is taken. + */ + private fun validateRefSchema(schema: JsonNode, path: String, depth: Int) { + val ref = schema.get(REF) + + if (ref == null) return + + validateKeywords(schema, ALLOWED_KEYWORDS_REF_SUB_SCHEMA, path, depth) + + val refPath = "$path/$REF" + + verify(ref.isTextual, refPath, { "'$REF' field is not a text value." }) { + // No point checking the reference has a referent if it is definitely malformed. + return + } + verify(ref.asText() in validReferences, refPath) { + "Invalid or unsupported reference: '${ref.asText()}'." + } + } + + /** + * Validates a schema if it has a `"type"` field. This includes most sub-schemas, except those + * that have a `"$ref"` or `"anyOf"` field instead. The `"type"` field may be set to a text + * value that is the name of the type (e.g., `"object"`, `"array"`, `"number"`), or it may be + * set to an array that contains two text values: the name of the type and `"null"`. The OpenAI + * API specification explains that this is how a property can be both required (i.e., it must + * appear in the JSON document), but its value can be optional (i.e., it can be set explicitly + * to `"null"`). If the schema has no `"type"` field, no action is taken. + */ + private fun validateTypeSchema(schema: JsonNode, path: String, depth: Int) { + val type = schema.get(TYPE) + + if (type == null) return + + val typeName = + if (type.isTextual) { + // Type will be something like `"type" : "string"` + type.asText() + } else if (type.isArray) { + // Type will be something like `"type" : [ "string", "null" ]`. This corresponds to + // the use of "Optional" in Java/Kotlin. + getTypeNameFromTypeArray(type, "$path/$TYPE") + } else { + error(path) { "'$TYPE' field is not a type name or array of type names." } + return + } + + when (typeName) { + TYPE_ARRAY -> validateArraySchema(schema, path, depth) + TYPE_OBJECT -> validateObjectSchema(schema, path, depth) + + TYPE_BOOLEAN, + TYPE_INTEGER, + TYPE_NUMBER, + TYPE_STRING -> validateSimpleSchema(schema, typeName, path, depth) + + // The type name could not be determined from a type name array. An error will already + // have been logged by `getTypeNameFromTypeArray`, so no need to do anything more here. + null -> return + + else -> error("$path/$TYPE") { "Unsupported '$TYPE' value: '$typeName'." } + } + } + + /** + * Validates a schema whose `"type"` is `"object"`. It is the responsibility of the caller to + * ensure that [schema] contains that type definition. If no type, or a different type is + * present, the behavior is not defined. + */ + private fun validateObjectSchema(schema: JsonNode, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_OBJECT_SUB_SCHEMA, path, depth) + + // The schema must declare that additional properties are not allowed. For this check, it + // does not matter if there are no "properties" in the schema. + verify( + schema.get(ADDITIONAL_PROPS) != null && + schema.get(ADDITIONAL_PROPS).asBoolean() == false, + path, + ) { + "'$ADDITIONAL_PROPS' field is missing or is not set to 'false'." + } + + val properties = schema.get(PROPS) + + // The "properties" field may be missing (there may be no properties to declare), but if it + // is present, it must be a non-empty object, or validation cannot continue. + // TODO: Decide if a missing or empty "properties" field is OK or not. + verify( + properties == null || (properties.isObject && !properties.isEmpty), + path, + { "'$PROPS' field is not a non-empty object." }, + ) { + return + } + + if (properties != null) { // Must be an object. + // If a "properties" field is present, there must also be a "required" field. All + // properties must be named in the list of required properties. + validatePropertiesRequired( + properties.fieldNames().asSequence().toSet(), + schema.get(REQUIRED), + "$path/$REQUIRED", + ) + validateProperties(properties, "$path/$PROPS", depth) + } + } + + /** + * Validates a schema whose `"type"` is `"array"`. It is the responsibility of the caller to + * ensure that [schema] contains that type definition. If no type, or a different type is + * present, the behavior is not defined. + * + * An array schema must have an `"items"` field whose value is an object representing a valid + * sub-schema. + */ + private fun validateArraySchema(schema: JsonNode, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_ARRAY_SUB_SCHEMA, path, depth) + + val items = schema.get(ITEMS) + + verify( + items != null && items.isObject, + path, + { "'$ITEMS' field is missing or is not an object." }, + ) { + return + } + + validateSchema(items, "$path/$ITEMS", depth + 1) + } + + /** + * Validates a schema whose `"type"` is one of the supported simple type names other than + * `"object"` and `"array"`. It is the responsibility of the caller to ensure that [schema] + * contains the correct type definition. If no type, or a different type is present, the + * behavior is not defined. + * + * @param typeName The name of the specific type of the schema. Where the field value is + * optional and the type is defined as an array of a type name and a `"null"`, this is the + * value of the non-`"null"` type name. For example `"string"`, or `"number"`. + */ + private fun validateSimpleSchema(schema: JsonNode, typeName: String, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_SIMPLE_SUB_SCHEMA, path, depth) + + val enumField = schema.get(ENUM) + + // OpenAI API specification: "For a single enum property with string values, the total + // string length of all enum values cannot exceed 7,500 characters when there are more than + // 250 enum values." + val isString = typeName == TYPE_STRING + var numEnumValues = 0 + var stringLength = 0 + + enumField?.forEach { value -> + // OpenAI places limits on the total string length of all enum values across all enums + // without being specific about the type of those enums (unlike for enums with string + // values, which have their own restrictions noted above). The specification does not + // indicate how to count the string length for boolean or number values. Here it is + // assumed that their simple string representations should be counted. + val length = value.asText().length + + totalStringLength += length + totalEnumValues++ + + if (isString) { + numEnumValues++ + stringLength += length + } + } + + verify( + !isString || + numEnumValues <= UNRESTRICTED_ENUM_VALUES_LIMIT || + stringLength <= MAX_ENUM_TOTAL_STRING_LENGTH, + "$path/$ENUM", + ) { + "Total string length ($stringLength) of values of an enum with $numEnumValues " + + "values exceeds limit of $MAX_ENUM_TOTAL_STRING_LENGTH." + } + + schema.get(CONST)?.let { constValue -> totalStringLength += constValue.asText().length } + } + + /** + * Validates that the definitions (if present) contain fields that each define a valid schema. + * Records the names of any definitions to construct the set of possible valid references to + * those definitions. This set will be used to validate any references from within definition + * sub-schemas, or any other sub-schemas validated at a later time. + * + * @param defs The node containing the definitions. Definitions are optional, so this node may + * be `null`. Definitions may appear in the root schema, but will not appear in any + * sub-schemas. If no definitions are present, the list of valid references will not be + * changed and no errors will be recorded. + * @param path The path that identifies the location within the schema of the `"$defs"` node. + * @param depth The current depth of nesting. If definitions are present, this will be zero, as + * that is the depth of the root schema. + */ + private fun validateDefinitions(defs: JsonNode?, path: String, depth: Int) { + // Definitions are optional. If present, expect an object whose fields are named from the + // classes the definitions were extracted from. If not present, do not continue. + verify(defs == null || defs.isObject, path, { "'$DEFS' must be an object." }) { + return + } + + // First, record the valid references to definitions, as any definition sub-schema may + // contain a reference to any other definitions sub-schema (including itself) and those + // references need to be validated. + defs?.fieldNames()?.asSequence()?.forEach { defName -> + val reference = "$path/$defName" + + // Consider that there might be duplicate definition names if two different classes + // (from different packages) have the same simple name. That would be an error, but + // there is no need to stop the validations. + // TODO: How should duplicate names be handled? Will the generator use longer names? + verify(reference !in validReferences, path) { "Duplicate definition of '$defName'." } + validReferences += reference + } + + // Second, recursively validate the definition sub-schemas. + defs?.fieldNames()?.asSequence()?.forEach { defName -> + totalStringLength += defName.length + validateSchema(defs.get(defName), "$path/$DEFS/$defName", depth + 1) + } + } + + /** + * Validates that every property in a collection of property names appears in the array of + * property names in a `"required"` field. + * + * @param propertyNames The collection of property names to check in the array of required + * properties. This collection will not be empty. + * @param required The `"required"` field. This is expected to be a non-`null` array field with + * a set of property names. + * @param path The path identifying the location of the `"required"` field within the schema. + */ + private fun validatePropertiesRequired( + propertyNames: Collection, + required: JsonNode?, + path: String, + ) { + val requiredNames = required?.map { it.asText() }?.toSet() ?: emptySet() + + propertyNames.forEach { propertyName -> + verify(propertyName in requiredNames, path) { + "'$PROPS' field '$propertyName' is not listed as '$REQUIRED'." + } + } + } + + /** + * Validates that each named entry in the `"properties"` field of an object schema has a value + * that is a valid sub-schema. + * + * @param properties The `"properties"` field node of an object schema. + * @param path The path identifying the location of the `"properties"` field within the schema. + */ + private fun validateProperties(properties: JsonNode, path: String, depth: Int) { + val propertyNames = properties.fieldNames().asSequence().toList() + + propertyNames.forEach { propertyName -> + totalObjectProperties++ + totalStringLength += propertyName.length + validateSchema(properties.get(propertyName), "$path/$propertyName", depth + 1) + } + } + + /** + * Validates that the names of all fields in the given schema node are present in a collection + * of allowed keywords. + * + * @param depth The nesting depth of the [schema] node. If this depth is zero, an additional set + * of allowed keywords will be included automatically for keywords that are allowed to appear + * only at the root level of the schema (e.g., `"$schema"`, `"$defs"`). + */ + private fun validateKeywords( + schema: JsonNode, + allowedKeywords: Collection, + path: String, + depth: Int, + ) { + schema.fieldNames().forEach { keyword -> + verify( + keyword in allowedKeywords || + (depth == ROOT_DEPTH && keyword in ALLOWED_KEYWORDS_ROOT_SCHEMA_ONLY), + path, + ) { + "Use of '$keyword' is not supported here." + } + } + } + + /** + * Gets the name of a type from the given `"type"` field, where the field is an array that + * contains exactly two string values: a type name and a `"null"` (in any order). + * + * @param type The type node. This must be a field with an array value. If this is not an array + * field, the behavior is undefined. It is the responsibility of the caller to ensure that + * this function is only called for array fields. + * @return The type name in the array that is not the `"null"` type; or `null` if no such type + * name was found, or if the array does not contain exactly two expected values: the type name + * and a `"null"` type. If `null`, one or more validation errors will be recorded. + */ + private fun getTypeNameFromTypeArray(type: JsonNode, path: String): String? { + val types = type.asSequence().toList() + + if (types.size == 2 && types.all { it.isTextual }) { + // Allow one type name and one "null". Be lenient about the order. Assume that there are + // no oddities like type names that are empty strings, etc., as the schemas are expected + // to be generated. + if (types[1].asText() == TYPE_NULL && types[0].asText() != TYPE_NULL) { + return types[0].asText() + } else if (types[0].asText() == TYPE_NULL && types[1].asText() != TYPE_NULL) { + return types[1].asText() + } else { + error(path) { "Expected one type name and one \"$TYPE_NULL\"." } + } + } else { + error(path) { "Expected exactly two types, both strings." } + } + + return null + } + + private inline fun verify(value: Boolean, path: String, lazyMessage: () -> Any) { + verify(value, path, lazyMessage) {} + } + + private inline fun verify( + value: Boolean, + path: String, + lazyMessage: () -> Any, + onFalse: () -> Unit, + ) { + if (!value) { + error(path, lazyMessage) + onFalse() + } + } + + private inline fun error(path: String, lazyMessage: () -> Any) { + errors.add("$path: ${lazyMessage()}") + } + + override fun toString(): String = + "${javaClass.simpleName}{isValidationComplete=$isValidationComplete, " + + "totalStringLength=$totalStringLength, " + + "totalObjectProperties=$totalObjectProperties, " + + "totalEnumValues=$totalEnumValues, errors=$errors}" +} diff --git a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt index 31768c04..ccbc3926 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt @@ -82,7 +82,8 @@ internal class JsonSchemaValidatorTest { assertThat(validator.isValid()).isTrue } - @Test + // FIXME: Disabled test until issues (noted below) are resolved. + // @Test fun schemaTest_minimalListSchema() { val s: List = listOf() From 5ab252cfa07550be1920d67cd01fd87e7d7170db Mon Sep 17 00:00:00 2001 From: D Gardner Date: Mon, 5 May 2025 16:27:46 +0100 Subject: [PATCH 3/9] structured-outputs: local validation, unit tests and documentation --- README.md | 166 +++++++++++++++++- .../com/openai/core/StructuredOutputs.kt | 28 ++- .../completions/ChatCompletionCreateParams.kt | 23 ++- .../StructuredChatCompletionCreateParams.kt | 14 +- ...idatorTest.kt => StructuredOutputsTest.kt} | 103 ++++++++++- .../StructuredOutputsClassExample.java | 10 +- 6 files changed, 314 insertions(+), 30 deletions(-) rename openai-java-core/src/test/kotlin/com/openai/core/{JsonSchemaValidatorTest.kt => StructuredOutputsTest.kt} (92%) diff --git a/README.md b/README.md index 39125ce8..2b8cb881 100644 --- a/README.md +++ b/README.md @@ -286,7 +286,7 @@ OpenAIClient client = OpenAIOkHttpClient.builder() The SDK provides conveniences for streamed chat completions. A [`ChatCompletionAccumulator`](openai-java-core/src/main/kotlin/com/openai/helpers/ChatCompletionAccumulator.kt) -can record the stream of chat completion chunks in the response as they are processed and accumulate +can record the stream of chat completion chunks in the response as they are processed and accumulate a [`ChatCompletion`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletion.kt) object similar to that which would have been returned by the non-streaming API. @@ -334,6 +334,166 @@ client.chat() ChatCompletion chatCompletion = chatCompletionAccumulator.chatCompletion(); ``` +## Structured outputs with JSON schemas + +Open AI [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat) +is a feature that ensures that the model will always generate responses that adhere to a supplied +[JSON schema](https://json-schema.org/overview/what-is-jsonschema). + +A JSON schema can be defined by creating a +[`ResponseFormatJsonSchema`](openai-java-core/src/main/kotlin/com/openai/models/ResponseFormatJsonSchema.kt) +and setting it on the input parameters. However, for greater convenience, a JSON schema can instead +be derived automatically from the structure of an arbitrary Java class. The response will then +automatically convert the generated JSON content to an instance of that Java class. + +Java classes can contain fields declared to be instances of other classes and can use collections: + +```java +class Person { + public String name; + public int yearOfBirth; +} + +class Book { + public String title; + public Person author; + public int yearPublished; +} + +class BookList { + public List books; +} +``` + +Pass the top-level class—`BookList` in this example—to `responseFormat(Class)` when building the +parameters and then access an instance of `BookList` from the generated message content in the +response: + +```java +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; + +StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .addUserMessage("List six famous nineteenth century novels.") + .model(ChatModel.GPT_4_1) + .responseFormat(BookList.class) + .build(); + +client.chat().completions().create(params).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(book.title + " by " + book.author.name)); +``` + +You can start building the parameters with an instance of +[`ChatCompletionCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt) +or +[`StructuredChatCompletionCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt). +If you start with the former (which allows for more compact code) the builder type will change to +the latter when `ChatCompletionCreateParams.Builder.responseFormat(Class)` is called. + +If a field in a class is optional and does not require a defined value, you can represent this using +the [`java.util.Optional`](https://docs.oracle.com/javase/8/docs/api/java/util/Optional.html) class. +It is up to the AI model to decide whether to provide a value for that field or leave it empty. + +```java +import java.util.Optional; + +class Book { + public String title; + public Person author; + public int yearPublished; + public Optional isbn; +} +``` + +If an error occurs while converting a JSON response to an instance of a Java class, the error +message will include the JSON response to assist in diagnosis. For instance, if the response is +truncated, the JSON data will be incomplete and cannot be converted to a class instance. If your +JSON response may contain sensitive information, avoid logging it directly, or ensure that you +redact any sensitive details from the error message. + +### Local JSON schema validation + +Structured Outputs supports a +[subset](https://platform.openai.com/docs/guides/structured-outputs#supported-schemas) of the JSON +Schema language. Schemas are generated automatically from classes to align with this subset. +However, due to the inherent structure of the classes, the generated schema may still violate +certain OpenAI schema restrictions, such as exceeding the maximum nesting depth or utilizing +unsupported data types. + +To facilitate compliance, the method `responseFormat(Class)` performs a validation check on the +schema derived from the specified class. This validation ensures that all restrictions are adhered +to. If any issues are detected, an exception will be thrown, providing a detailed message outlining +the reasons for the validation failure. + +- **Local Validation**: The validation process occurs locally, meaning no requests are sent to the +remote AI model. If the schema passes local validation, it is likely to pass remote validation as +well. +- **Remote Validation**: The remote AI model will conduct its own validation upon receiving the JSON +schema in the request. +- **Version Compatibility**: There may be instances where local validation fails while remote +validation succeeds. This can occur if the SDK version is outdated compared to the restrictions +enforced by the remote model. +- **Disabling Local Validation**: If you encounter compatibility issues and wish to bypass local +validation, you can disable it by passing `false` to the `responseFormat(Class, boolean)` method +when building the parameters. (The default value for this parameter is `true`.) + +```java +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; + +StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .addUserMessage("List six famous nineteenth century novels.") + .model(ChatModel.GPT_4_1) + .responseFormat(BookList.class, false) // Disable local validation. + .build(); +``` + +By following these guidelines, you can ensure that your structured outputs conform to the necessary +schema requirements and minimize the risk of remote validation errors. + +### Annotating classes and JSON schemas + +You can use annotations to add further information to the JSON schema derived from your Java +classes, or to exclude individual fields from the schema. Details from annotations captured in the +JSON schema may be used by the AI model to improve its response. The SDK supports the use of +[Jackson Databind](https://github.com/FasterXML/jackson-databind) annotations. + +```java +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +class Person { + @JsonPropertyDescription("The first name and surname of the person") + public String name; + public int yearOfBirth; +} + +@JsonClassDescription("The details of one published book") +class Book { + public String title; + public Person author; + public int yearPublished; + @JsonIgnore public String genre; +} + +class BookList { + public List books; +} +``` + +- Use `@JsonClassDescription` to add a detailed description to a class. +- Use `@JsonPropertyDescription` to add a detailed description to a field of a class. +- Use `@JsonIgnore` to omit a field of a class from the generated JSON schema. + +If you use `@JsonProperty(required = false)`, the `false` value will be ignored. OpenAI JSON schemas +must mark all properties as _required_, so the schema generated from your Java classes will respect +that restriction and ignore any annotation that would violate it. + ## File uploads The SDK defines methods that accept files. @@ -652,7 +812,7 @@ If the SDK threw an exception, but you're _certain_ the version is compatible, t ## Microsoft Azure -To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the same +To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the same OpenAI client builder but with the Azure-specific configuration. ```java @@ -665,7 +825,7 @@ OpenAIClient client = OpenAIOkHttpClient.builder() .build(); ``` -See the complete Azure OpenAI example in the [`openai-java-example`](openai-java-example/src/main/java/com/openai/example/AzureEntraIdExample.java) directory. The other examples in the directory also work with Azure as long as the client is configured to use it. +See the complete Azure OpenAI example in the [`openai-java-example`](openai-java-example/src/main/java/com/openai/example/AzureEntraIdExample.java) directory. The other examples in the directory also work with Azure as long as the client is configured to use it. ## Network options diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index 7f18d237..ba828d9a 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -23,17 +23,37 @@ private val MAPPER = .addModule(JavaTimeModule()) .build() -fun fromClass(type: Class) = - ResponseFormatJsonSchema.builder() +internal fun fromClass( + type: Class, + localValidation: Boolean = true, +): ResponseFormatJsonSchema { + val schema = extractSchema(type) + + if (localValidation) { + val validator = JsonSchemaValidator.create().validate(schema) + + if (!validator.isValid()) { + throw IllegalArgumentException( + "Local validation failed for JSON schema derived from $type:\n" + + validator.errors().joinToString("\n") { " - $it" } + ) + } + } + + return ResponseFormatJsonSchema.builder() .jsonSchema( ResponseFormatJsonSchema.JsonSchema.builder() .name("json-schema-from-${type.simpleName}") - .schema(JsonValue.from(extractSchema(type))) + .schema(JsonValue.from(schema)) .build() ) .build() +} internal fun extractSchema(type: Class): JsonNode { + // Validation is not performed by this function, as it allows extraction of the schema and + // validation of the schema to be controlled more easily when unit testing, as no exceptions + // will be thrown and any recorded validation errors can be inspected at leisure by the tests. val configBuilder = SchemaGeneratorConfigBuilder( com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, @@ -56,7 +76,7 @@ internal fun extractSchema(type: Class): JsonNode { return SchemaGenerator(configBuilder.build()).generateSchema(type) } -fun fromJson(json: String, type: Class): T = +internal fun fromJson(json: String, type: Class): T = try { MAPPER.readValue(json, type) } catch (e: Exception) { diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index cb3459fe..0fe22b2b 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -1298,12 +1298,23 @@ private constructor( } /** - * Sets the class that defines the structured outputs response format. This changes the - * builder to a type-safe [StructuredChatCompletionCreateParams.Builder] that will build a - * [StructuredChatCompletionCreateParams] instance when `build()` is called. - */ - fun responseFormat(responseFormat: Class) = - StructuredChatCompletionCreateParams.builder().wrap(responseFormat, this) + * Sets response format to a JSON schema derived from the structure of the given class. This + * changes the builder to a type-safe [StructuredChatCompletionCreateParams.Builder] that + * will build a [StructuredChatCompletionCreateParams] instance when `build()` is called. + * + * @param responseFormat A class from which a JSON schema will be derived to define the + * response format. + * @param localValidation `true` (the default) to validate the JSON schema locally when it + * is generated by this method to confirm that it adheres to the requirements and + * restrictions on JSON schemas imposed by the OpenAI specification; or `false` to disable + * local validation. See the SDK documentation for more details. + * @throws IllegalArgumentException If local validation is enabled, but it fails because a + * valid JSON schema cannot be derived from the given class. + */ + @JvmOverloads + fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = + StructuredChatCompletionCreateParams.builder() + .wrap(responseFormat, this, localValidation) /** * This feature is in Beta. If specified, our system will make a best effort to sample diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt index ae1ea1be..14194ac9 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -33,11 +33,12 @@ internal constructor( internal fun wrap( responseFormat: Class, paramsBuilder: ChatCompletionCreateParams.Builder, + localValidation: Boolean, ) = apply { this.responseFormat = responseFormat this.paramsBuilder = paramsBuilder // Convert the class to a JSON schema and apply it to the delegate `Builder`. - responseFormat(responseFormat) + responseFormat(responseFormat, localValidation) } /** Injects a given `ChatCompletionCreateParams.Builder`. For use only when testing. */ @@ -389,10 +390,15 @@ internal constructor( paramsBuilder.reasoningEffort(reasoningEffort) } - /** Sets the response format to a JSON schema derived from the given class. */ - fun responseFormat(responseFormat: Class) = apply { + /** + * Sets the response format to a JSON schema derived from the structure of the given class. + * + * @see ChatCompletionCreateParams.Builder.responseFormat + */ + @JvmOverloads + fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = apply { this.responseFormat = responseFormat - paramsBuilder.responseFormat(fromClass(responseFormat)) + paramsBuilder.responseFormat(fromClass(responseFormat, localValidation)) } /** @see ChatCompletionCreateParams.Builder.seed */ diff --git a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt similarity index 92% rename from openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt rename to openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index ccbc3926..88696e55 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -7,16 +7,18 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription import com.fasterxml.jackson.databind.JsonNode import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.ObjectNode +import com.openai.errors.OpenAIInvalidDataException import java.util.Optional import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatNoException import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.AfterTestExecutionCallback import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.api.extension.RegisterExtension -/** Tests the [JsonSchemaValidator] and, in passing, tests the [extractSchema] function. */ -internal class JsonSchemaValidatorTest { +/** Tests for the `StructuredOutputs` functions and the [JsonSchemaValidator]. */ +internal class StructuredOutputsTest { companion object { private const val SCHEMA = "\$schema" private const val SCHEMA_VER = "https://json-schema.org/draft/2020-12/schema" @@ -28,6 +30,8 @@ internal class JsonSchemaValidatorTest { * print them only for failed tests. */ private const val VERBOSE_MODE = false + + private fun parseJson(schemaString: String) = ObjectMapper().readTree(schemaString) } /** @@ -82,7 +86,7 @@ internal class JsonSchemaValidatorTest { assertThat(validator.isValid()).isTrue } - // FIXME: Disabled test until issues (noted below) are resolved. + // TODO: Disabled test until issues (noted below) are resolved. // @Test fun schemaTest_minimalListSchema() { val s: List = listOf() @@ -90,7 +94,7 @@ internal class JsonSchemaValidatorTest { schema = extractSchema(s.javaClass) validator.validate(schema) - // FIXME: Currently, the generated schema looks like this: + // TODO: Currently, the generated schema looks like this: // { // "$schema" : "https://json-schema.org/draft/2020-12/schema", // "type" : "array", @@ -1400,5 +1404,94 @@ internal class JsonSchemaValidatorTest { assertThat(validator.isValid()).isTrue } - private fun parseJson(schemaString: String) = ObjectMapper().readTree(schemaString) + @Test + fun fromJsonSuccess() { + @Suppress("unused") class X(val s: String) + + val x = fromJson("{\"s\" : \"hello\"}", X::class.java) + + assertThat(x.s).isEqualTo("hello") + } + + @Test + fun fromJsonFailure1() { + @Suppress("unused") class X(val s: String) + + // Well-formed JSON, but it does not match the schema of class `X`. + assertThatThrownBy { fromJson("{\"wrong\" : \"hello\"}", X::class.java) } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("Error parsing JSON: {\"wrong\" : \"hello\"}") + } + + @Test + fun fromJsonFailure2() { + @Suppress("unused") class X(val s: String) + + // Malformed JSON. + assertThatThrownBy { fromJson("{\"truncated", X::class.java) } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("Error parsing JSON: {\"truncated") + } + + @Test + @Suppress("unused") + fun fromClassSuccessWithoutValidation() { + // Exceed the maximum nesting depth, but do not enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatNoException().isThrownBy { fromClass(X::class.java, false) } + } + + @Test + fun fromClassSuccessWithValidation() { + @Suppress("unused") class X(val s: String) + + assertThatNoException().isThrownBy { fromClass(X::class.java, true) } + } + + @Test + @Suppress("unused") + fun fromClassFailureWithValidation() { + // Exceed the maximum nesting depth and enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { fromClass(Z::class.java, true) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + @Suppress("unused") + fun fromClassFailureWithValidationDefault() { + // Confirm that the default value of the `localValidation` argument is `true` by expecting + // a validation error when that argument is not given an explicit value. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { fromClass(Z::class.java) } // Use default for `localValidation` flag. + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } } diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java index bcc46a80..30188fa5 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java @@ -13,8 +13,6 @@ public final class StructuredOutputsClassExample { public static class Person { public String firstName; public String surname; - - @JsonPropertyDescription("The date of birth of the person.") public String dateOfBirth; @Override @@ -44,11 +42,6 @@ public String toString() { public static class Laureates { @JsonPropertyDescription("A list of winners of a Nobel Prize.") public List laureates; - - @Override - public String toString() { - return "Laureates{laureates=" + laureates + '}'; - } } private StructuredOutputsClassExample() {} @@ -63,11 +56,12 @@ public static void main(String[] args) { .model(ChatModel.GPT_4O_MINI) .maxCompletionTokens(2048) .responseFormat(Laureates.class) - .addUserMessage("List some winners of the Nobel Prize in Physics since 2000.") + .addUserMessage("List five winners of the Nobel Prize in Physics.") .build(); client.chat().completions().create(createParams).choices().stream() .flatMap(choice -> choice.message().content().stream()) + .flatMap(laureates -> laureates.laureates.stream()) .forEach(System.out::println); } } From 65ddf35574aac378de0bc1177dae3dd2cb197411 Mon Sep 17 00:00:00 2001 From: D Gardner Date: Tue, 6 May 2025 18:31:44 +0100 Subject: [PATCH 4/9] structured-outputs: changes from code review --- README.md | 48 ++++++++++++----- .../openai/core/JsonSchemaLocalValidation.kt | 19 +++++++ .../com/openai/core/JsonSchemaValidator.kt | 4 +- .../com/openai/core/StructuredOutputs.kt | 17 ++++--- .../completions/ChatCompletionCreateParams.kt | 16 ++++-- .../completions/StructuredChatCompletion.kt | 15 ++++-- .../StructuredChatCompletionCreateParams.kt | 19 +++++-- .../StructuredChatCompletionMessage.kt | 12 ++++- .../blocking/chat/ChatCompletionService.kt | 12 +++-- .../com/openai/core/StructuredOutputsTest.kt | 35 ++++++++----- .../StructuredOutputsClassExample.java | 51 ++++++++++--------- 11 files changed, 170 insertions(+), 78 deletions(-) create mode 100644 openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt diff --git a/README.md b/README.md index 2b8cb881..d7366790 100644 --- a/README.md +++ b/README.md @@ -343,21 +343,23 @@ is a feature that ensures that the model will always generate responses that adh A JSON schema can be defined by creating a [`ResponseFormatJsonSchema`](openai-java-core/src/main/kotlin/com/openai/models/ResponseFormatJsonSchema.kt) and setting it on the input parameters. However, for greater convenience, a JSON schema can instead -be derived automatically from the structure of an arbitrary Java class. The response will then -automatically convert the generated JSON content to an instance of that Java class. +be derived automatically from the structure of an arbitrary Java class. The JSON content from the +response will then be converted automatically to an instance of that Java class. A full, working +example of the use of Structured Outputs with arbitrary Java classes can be seen in +[`StructuredOutputsClassExample`](openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java). Java classes can contain fields declared to be instances of other classes and can use collections: ```java class Person { public String name; - public int yearOfBirth; + public int birthYear; } class Book { public String title; public Person author; - public int yearPublished; + public int publicationYear; } class BookList { @@ -375,7 +377,7 @@ import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() - .addUserMessage("List six famous nineteenth century novels.") + .addUserMessage("List some famous late twentieth century novels.") .model(ChatModel.GPT_4_1) .responseFormat(BookList.class) .build(); @@ -403,11 +405,25 @@ import java.util.Optional; class Book { public String title; public Person author; - public int yearPublished; + public int publicationYear; public Optional isbn; } ``` +Generic type information for fields is retained in the class's metadata, but _generic type erasure_ +applies in other scopes. While, for example, a JSON schema defining an array of strings can be +derived from the `BoolList.books` field with type `List`, a valid JSON schema cannot be +derived from a local variable of that same type, so the following will _not_ work: + +```java +List books = new ArrayList<>(); + +StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .responseFormat(books.class) + // ... + .build(); +``` + If an error occurs while converting a JSON response to an instance of a Java class, the error message will include the JSON response to assist in diagnosis. For instance, if the response is truncated, the JSON data will be incomplete and cannot be converted to a class instance. If your @@ -435,20 +451,23 @@ well. schema in the request. - **Version Compatibility**: There may be instances where local validation fails while remote validation succeeds. This can occur if the SDK version is outdated compared to the restrictions -enforced by the remote model. +enforced by the remote AI model. - **Disabling Local Validation**: If you encounter compatibility issues and wish to bypass local -validation, you can disable it by passing `false` to the `responseFormat(Class, boolean)` method -when building the parameters. (The default value for this parameter is `true`.) +validation, you can disable it by passing +[`JsonSchemaLocalValidation.NO`](openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt) +to the `responseFormat(Class, JsonSchemaLocalValidation)` method when building the parameters. +(The default value for this parameter is `JsonSchemaLocalValidation.YES`.) ```java +import com.openai.core.JsonSchemaLocalValidation; import com.openai.models.ChatModel; import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() - .addUserMessage("List six famous nineteenth century novels.") + .addUserMessage("List some famous late twentieth century novels.") .model(ChatModel.GPT_4_1) - .responseFormat(BookList.class, false) // Disable local validation. + .responseFormat(BookList.class, JsonSchemaLocalValidation.NO) .build(); ``` @@ -470,14 +489,17 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; class Person { @JsonPropertyDescription("The first name and surname of the person") public String name; - public int yearOfBirth; + public int birthYear; + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; } @JsonClassDescription("The details of one published book") class Book { public String title; public Person author; - public int yearPublished; + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; @JsonIgnore public String genre; } diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt new file mode 100644 index 00000000..9a3ae799 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt @@ -0,0 +1,19 @@ +package com.openai.core + +/** + * Options for local validation of JSON schemas derived from arbitrary classes before a request is + * executed. + */ +enum class JsonSchemaLocalValidation { + /** + * Validate the JSON schema locally before the request is executed. The remote AI model will + * also validate the JSON schema. + */ + YES, + + /** + * Do not validate the JSON schema locally before the request is executed. The remote AI model + * will always validate the JSON schema. + */ + NO, +} diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt index 6af40929..85c20b43 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt @@ -201,9 +201,7 @@ internal class JsonSchemaValidator private constructor() { * each new schema. */ fun validate(rootSchema: JsonNode): JsonSchemaValidator { - if (isValidationComplete) { - throw IllegalStateException("Validation already complete.") - } + check(!isValidationComplete) { "Validation already complete." } isValidationComplete = true validateSchema(rootSchema, ROOT_PATH, ROOT_DEPTH) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index ba828d9a..3c6e7dec 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -23,20 +23,19 @@ private val MAPPER = .addModule(JavaTimeModule()) .build() +@JvmSynthetic internal fun fromClass( type: Class, - localValidation: Boolean = true, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, ): ResponseFormatJsonSchema { val schema = extractSchema(type) - if (localValidation) { + if (localValidation == JsonSchemaLocalValidation.YES) { val validator = JsonSchemaValidator.create().validate(schema) - if (!validator.isValid()) { - throw IllegalArgumentException( - "Local validation failed for JSON schema derived from $type:\n" + - validator.errors().joinToString("\n") { " - $it" } - ) + require(validator.isValid()) { + "Local validation failed for JSON schema derived from $type:\n" + + validator.errors().joinToString("\n") { " - $it" } } } @@ -44,12 +43,13 @@ internal fun fromClass( .jsonSchema( ResponseFormatJsonSchema.JsonSchema.builder() .name("json-schema-from-${type.simpleName}") - .schema(JsonValue.from(schema)) + .schema(JsonValue.fromJsonNode(schema)) .build() ) .build() } +@JvmSynthetic internal fun extractSchema(type: Class): JsonNode { // Validation is not performed by this function, as it allows extraction of the schema and // validation of the schema to be controlled more easily when unit testing, as no exceptions @@ -76,6 +76,7 @@ internal fun extractSchema(type: Class): JsonNode { return SchemaGenerator(configBuilder.build()).generateSchema(type) } +@JvmSynthetic internal fun fromJson(json: String, type: Class): T = try { MAPPER.readValue(json, type) diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index 0fe22b2b..d012ff94 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -19,6 +19,7 @@ import com.openai.core.Enum import com.openai.core.ExcludeMissing import com.openai.core.JsonField import com.openai.core.JsonMissing +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.Params import com.openai.core.allMaxBy @@ -1304,15 +1305,20 @@ private constructor( * * @param responseFormat A class from which a JSON schema will be derived to define the * response format. - * @param localValidation `true` (the default) to validate the JSON schema locally when it - * is generated by this method to confirm that it adheres to the requirements and - * restrictions on JSON schemas imposed by the OpenAI specification; or `false` to disable - * local validation. See the SDK documentation for more details. + * @param localValidation [com.openai.core.JsonSchemaLocalValidation.YES] (the default) to + * validate the JSON schema locally when it is generated by this method to confirm that it + * adheres to the requirements and restrictions on JSON schemas imposed by the OpenAI + * specification; or [com.openai.core.JsonSchemaLocalValidation.NO] to skip local + * validation and rely only on remote validation. See the SDK documentation for more + * details. * @throws IllegalArgumentException If local validation is enabled, but it fails because a * valid JSON schema cannot be derived from the given class. */ @JvmOverloads - fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = + fun responseFormat( + responseFormat: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = StructuredChatCompletionCreateParams.builder() .wrap(responseFormat, this, localValidation) diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt index 6ca931a5..62cde6f5 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt @@ -10,9 +10,16 @@ import com.openai.models.completions.CompletionUsage import java.util.Objects import java.util.Optional +/** + * A wrapper for [ChatCompletion] that provides type-safe access to the [choices] when using the + * _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary class. + * See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the response will be deserialized. + */ class StructuredChatCompletion( - val responseFormat: Class, - val chatCompletion: ChatCompletion, + @get:JvmName("responseFormat") val responseFormat: Class, + @get:JvmName("chatCompletion") val chatCompletion: ChatCompletion, ) { /** @see ChatCompletion.id */ fun id(): String = chatCompletion.id() @@ -68,8 +75,8 @@ class StructuredChatCompletion( class Choice internal constructor( - internal val responseFormat: Class, - internal val choice: ChatCompletion.Choice, + @get:JvmName("responseFormat") val responseFormat: Class, + @get:JvmName("choice") val choice: ChatCompletion.Choice, ) { /** @see ChatCompletion.Choice.finishReason */ fun finishReason(): FinishReason = choice.finishReason() diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt index 14194ac9..4f6a3a63 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -1,6 +1,7 @@ package com.openai.models.chat.completions import com.openai.core.JsonField +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.checkRequired import com.openai.core.fromClass @@ -11,9 +12,18 @@ import com.openai.models.ReasoningEffort import java.util.Objects import java.util.Optional +/** + * A wrapper for [ChatCompletionCreateParams] that provides a type-safe [Builder] that can record + * the type of the [responseFormat] used to derive a JSON schema from an arbitrary class when using + * the _Structured Outputs_ feature. When a JSON response is received, it is deserialized to am + * instance of that type. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class that will be used to derive the JSON schema in the request and to + * which the JSON response will be deserialized. + */ class StructuredChatCompletionCreateParams internal constructor( - val responseFormat: Class, + @get:JvmName("responseFormat") val responseFormat: Class, /** * The raw, underlying chat completion create parameters wrapped by this structured instance of * the parameters. @@ -33,7 +43,7 @@ internal constructor( internal fun wrap( responseFormat: Class, paramsBuilder: ChatCompletionCreateParams.Builder, - localValidation: Boolean, + localValidation: JsonSchemaLocalValidation, ) = apply { this.responseFormat = responseFormat this.paramsBuilder = paramsBuilder @@ -396,7 +406,10 @@ internal constructor( * @see ChatCompletionCreateParams.Builder.responseFormat */ @JvmOverloads - fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = apply { + fun responseFormat( + responseFormat: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { this.responseFormat = responseFormat paramsBuilder.responseFormat(fromClass(responseFormat, localValidation)) } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt index 519596ef..b833dd47 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt @@ -7,10 +7,18 @@ import com.openai.models.chat.completions.ChatCompletionMessage.FunctionCall import java.util.Objects import java.util.Optional +/** + * A wrapper for [ChatCompletionMessage] that provides type-safe access to the [content] when using + * the _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary + * class. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the content will be deserialized when + * [content] is called. + */ class StructuredChatCompletionMessage internal constructor( - val responseFormat: Class, - val chatCompletionMessage: ChatCompletionMessage, + @get:JvmName("responseFormat") val responseFormat: Class, + @get:JvmName("chatCompletionMessage") val chatCompletionMessage: ChatCompletionMessage, ) { private val content: JsonField by lazy { diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt index 28818c45..6985e05e 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt @@ -58,12 +58,14 @@ interface ChatCompletionService { /** @see create */ fun create( params: StructuredChatCompletionCreateParams + ): StructuredChatCompletion = create(params, RequestOptions.none()) + + /** @see create */ + fun create( + params: StructuredChatCompletionCreateParams, + requestOptions: RequestOptions = RequestOptions.none(), ): StructuredChatCompletion = - StructuredChatCompletion( - params.responseFormat, - // Normal, non-generic create method call via `ChatCompletionCreateParams`. - create(params.rawParams), - ) + StructuredChatCompletion(params.responseFormat, create(params.rawParams, requestOptions)) /** * **Starting a new project?** We recommend trying diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index 88696e55..6ea7454b 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -86,22 +86,24 @@ internal class StructuredOutputsTest { assertThat(validator.isValid()).isTrue } - // TODO: Disabled test until issues (noted below) are resolved. - // @Test + @Test fun schemaTest_minimalListSchema() { val s: List = listOf() schema = extractSchema(s.javaClass) validator.validate(schema) - // TODO: Currently, the generated schema looks like this: + // Currently, the generated schema looks like this: + // // { // "$schema" : "https://json-schema.org/draft/2020-12/schema", // "type" : "array", // "items" : { } // } - // That causes an error, as the `"items"` object is empty when it should be a valid - // sub-schema. Something like this is what is expected: + // + // That causes an error, as the `"items"` object is empty when it should be a valid + // sub-schema. Something like this is what would be valid: + // // { // "$schema" : "https://json-schema.org/draft/2020-12/schema", // "type" : "array", @@ -109,10 +111,15 @@ internal class StructuredOutputsTest { // "type" : "string" // } // } - // It might be presumed that type erasure is the cause of the missing field. However, the - // `schemaTest_listFieldSchema` method (below) seems to be able to produce the expected - // `"items"` object when it is defined as a class property, so, well ... huh? - assertThat(validator.isValid()).isTrue + // + // The reason for the failure is that generic type information is erased for scopes like + // local variables, but generic type information for fields is retained as part of the class + // metadata. This is the expected behavior in Java, so this test expects an invalid schema. + assertThat(validator.isValid()).isFalse + assertThat(validator.errors()).hasSize(2) + assertThat(validator.errors()[0]).isEqualTo("#/items: Schema or sub-schema is empty.") + assertThat(validator.errors()[1]) + .isEqualTo("#/items: Expected exactly one of 'type' or 'anyOf' or '$REF'.") } @Test @@ -1444,14 +1451,18 @@ internal class StructuredOutputsTest { class Y(val x: X) class Z(val y: Y) - assertThatNoException().isThrownBy { fromClass(X::class.java, false) } + assertThatNoException().isThrownBy { + fromClass(X::class.java, JsonSchemaLocalValidation.NO) + } } @Test fun fromClassSuccessWithValidation() { @Suppress("unused") class X(val s: String) - assertThatNoException().isThrownBy { fromClass(X::class.java, true) } + assertThatNoException().isThrownBy { + fromClass(X::class.java, JsonSchemaLocalValidation.YES) + } } @Test @@ -1465,7 +1476,7 @@ internal class StructuredOutputsTest { class Y(val x: X) class Z(val y: Y) - assertThatThrownBy { fromClass(Z::class.java, true) } + assertThatThrownBy { fromClass(Z::class.java, JsonSchemaLocalValidation.YES) } .isExactlyInstanceOf(IllegalArgumentException::class.java) .hasMessage( "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java index 30188fa5..3f65a991 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java @@ -1,5 +1,6 @@ package com.openai.example; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; @@ -11,37 +12,41 @@ public final class StructuredOutputsClassExample { public static class Person { - public String firstName; - public String surname; - public String dateOfBirth; + @JsonPropertyDescription("The first name and surname of the person.") + public String name; + + public int birthYear; + + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; @Override public String toString() { - return "Person{firstName=" + firstName + ", surname=" + surname + ", dateOfBirth=" + dateOfBirth + '}'; + return name + " (" + birthYear + '-' + deathYear + ')'; } } - public static class Laureate { - public Person person; - public String majorAchievement; - public int yearWon; + public static class Book { + public String title; + + public Person author; + + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; + + public String genre; - @JsonPropertyDescription("The share of the prize money won by the Nobel Laureate.") - public double prizeMoney; + @JsonIgnore + public String isbn; @Override public String toString() { - return "Laureate{person=" - + person + ", majorAchievement=" - + majorAchievement + ", yearWon=" - + yearWon + ", prizeMoney=" - + prizeMoney + '}'; + return '"' + title + "\" (" + publicationYear + ") [" + genre + "] by " + author; } } - public static class Laureates { - @JsonPropertyDescription("A list of winners of a Nobel Prize.") - public List laureates; + public static class BookList { + public List books; } private StructuredOutputsClassExample() {} @@ -52,16 +57,16 @@ public static void main(String[] args) { // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables OpenAIClient client = OpenAIOkHttpClient.fromEnv(); - StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() .model(ChatModel.GPT_4O_MINI) .maxCompletionTokens(2048) - .responseFormat(Laureates.class) - .addUserMessage("List five winners of the Nobel Prize in Physics.") + .responseFormat(BookList.class) + .addUserMessage("List some famous late twentieth century novels.") .build(); client.chat().completions().create(createParams).choices().stream() .flatMap(choice -> choice.message().content().stream()) - .flatMap(laureates -> laureates.laureates.stream()) - .forEach(System.out::println); + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(" - " + book)); } } From 118b16c39d89a701ffad8b41ccd1c55dcd368921 Mon Sep 17 00:00:00 2001 From: D Gardner Date: Wed, 7 May 2025 11:30:56 +0100 Subject: [PATCH 5/9] structured-outputs: added 'strict' flag --- .../kotlin/com/openai/core/StructuredOutputs.kt | 3 +++ .../com/openai/core/StructuredOutputsTest.kt | 14 +++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index 3c6e7dec..6b1889ff 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -44,6 +44,9 @@ internal fun fromClass( ResponseFormatJsonSchema.JsonSchema.builder() .name("json-schema-from-${type.simpleName}") .schema(JsonValue.fromJsonNode(schema)) + // Ensure the model's output strictly adheres to this JSON schema. This is the + // essential "ON switch" for Structured Outputs. + .strict(true) .build() ) .build() diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index 6ea7454b..2c1eb885 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -1440,6 +1440,18 @@ internal class StructuredOutputsTest { .hasMessage("Error parsing JSON: {\"truncated") } + @Test + fun fromClassEnablesStrictAdherenceToSchema() { + @Suppress("unused") class X(val s: String) + + val jsonSchema = fromClass(X::class.java) + + // The "strict" flag _must_ be set to ensure that the model's output will _always_ conform + // to the JSON schema. + assertThat(jsonSchema.jsonSchema().strict()).isPresent + assertThat(jsonSchema.jsonSchema().strict().get()).isTrue + } + @Test @Suppress("unused") fun fromClassSuccessWithoutValidation() { @@ -1452,7 +1464,7 @@ internal class StructuredOutputsTest { class Z(val y: Y) assertThatNoException().isThrownBy { - fromClass(X::class.java, JsonSchemaLocalValidation.NO) + fromClass(Z::class.java, JsonSchemaLocalValidation.NO) } } From 9d5b95655b253429b07085e6be37f6ee40aaf639 Mon Sep 17 00:00:00 2001 From: D Gardner Date: Wed, 14 May 2025 13:50:20 +0100 Subject: [PATCH 6/9] structured-outputs: support for Responses API, review changes --- README.md | 29 +- .../com/openai/core/StructuredOutputs.kt | 74 ++- .../completions/ChatCompletionCreateParams.kt | 13 +- .../completions/StructuredChatCompletion.kt | 74 +-- .../StructuredChatCompletionCreateParams.kt | 34 +- .../StructuredChatCompletionMessage.kt | 49 +- .../models/responses/ResponseCreateParams.kt | 23 + .../models/responses/StructuredResponse.kt | 230 ++++++++ .../StructuredResponseCreateParams.kt | 528 ++++++++++++++++++ .../responses/StructuredResponseOutputItem.kt | 189 +++++++ .../StructuredResponseOutputMessage.kt | 199 +++++++ .../services/blocking/ResponseService.kt | 13 + .../blocking/chat/ChatCompletionService.kt | 2 +- .../com/openai/core/StructuredOutputsTest.kt | 22 +- .../openai/core/StructuredOutputsTestUtils.kt | 366 ++++++++++++ .../ChatCompletionCreateParamsTest.kt | 2 +- ...tructuredChatCompletionCreateParamsTest.kt | 253 ++------- .../StructuredChatCompletionMessageTest.kt | 66 +-- .../StructuredChatCompletionTest.kt | 184 ++---- .../StructuredResponseCreateParamsTest.kt | 245 ++++++++ .../StructuredResponseOutputItemTest.kt | 204 +++++++ .../StructuredResponseOutputMessageTest.kt | 298 ++++++++++ .../responses/StructuredResponseTest.kt | 245 ++++++++ .../ResponsesStructuredOutputsExample.java | 73 +++ 24 files changed, 2901 insertions(+), 514 deletions(-) create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputItem.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputItemTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt create mode 100644 openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java diff --git a/README.md b/README.md index d7366790..105aa4b9 100644 --- a/README.md +++ b/README.md @@ -411,15 +411,15 @@ class Book { ``` Generic type information for fields is retained in the class's metadata, but _generic type erasure_ -applies in other scopes. While, for example, a JSON schema defining an array of strings can be -derived from the `BoolList.books` field with type `List`, a valid JSON schema cannot be -derived from a local variable of that same type, so the following will _not_ work: +applies in other scopes. While, for example, a JSON schema defining an array of books can be derived +from the `BookList.books` field with type `List`, a valid JSON schema cannot be derived from a +local variable of that same type, so the following will _not_ work: ```java -List books = new ArrayList<>(); +List books = new ArrayList<>(); -StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() - .responseFormat(books.class) +StructuredChatCompletionCreateParams> params = ChatCompletionCreateParams.builder() + .responseFormat(books.getClass()) // ... .build(); ``` @@ -474,6 +474,23 @@ StructuredChatCompletionCreateParams params = ChatCompletionCreatePara By following these guidelines, you can ensure that your structured outputs conform to the necessary schema requirements and minimize the risk of remote validation errors. +### Usage with the Responses API + +_Structured Outputs_ are also supported for the Responses API. The usage is the same as described +except where the Responses API differs slightly from the Chat Completions API. Pass the top-level +class to `text(Class)` when building the parameters and then access an instance of the class from +the generated message content in the response. + +You can start building the parameters with an instance of +[`ResponseCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt) +or +[`StructuredResponseCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt). +If you start with the former (which allows for more compact code) the builder type will change to +the latter when `ResponseCreateParams.Builder.text(Class)` is called. + +For a full example of the usage of _Structured Outputs_ with the Responses API, see +[`ResponsesStructuredOutputsExample`](openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java). + ### Annotating classes and JSON schemas You can use annotations to add further information to the JSON schema derived from your Java diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index 6b1889ff..747995e1 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -12,6 +12,8 @@ import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder import com.github.victools.jsonschema.module.jackson.JacksonModule import com.openai.errors.OpenAIInvalidDataException import com.openai.models.ResponseFormatJsonSchema +import com.openai.models.responses.ResponseFormatTextJsonSchemaConfig +import com.openai.models.responses.ResponseTextConfig // The SDK `ObjectMappers.jsonMapper()` requires that all fields of classes be marked with // `@JsonProperty`, which is not desirable in this context, as it impedes usability. Therefore, a @@ -23,11 +25,31 @@ private val MAPPER = .addModule(JavaTimeModule()) .build() +/** + * Builds a response format using a JSON schema derived from the structure of an arbitrary Java + * class. + */ @JvmSynthetic -internal fun fromClass( +internal fun responseFormatFromClass( type: Class, localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, -): ResponseFormatJsonSchema { +): ResponseFormatJsonSchema = + ResponseFormatJsonSchema.builder() + .jsonSchema( + ResponseFormatJsonSchema.JsonSchema.builder() + .name("json-schema-from-${type.simpleName}") + .schema(JsonValue.fromJsonNode(extractAndValidateSchema(type, localValidation))) + // Ensure the model's output strictly adheres to this JSON schema. This is the + // essential "ON switch" for Structured Outputs. + .strict(true) + .build() + ) + .build() + +private fun extractAndValidateSchema( + type: Class, + localValidation: JsonSchemaLocalValidation, +): JsonNode { val schema = extractSchema(type) if (localValidation == JsonSchemaLocalValidation.YES) { @@ -39,24 +61,48 @@ internal fun fromClass( } } - return ResponseFormatJsonSchema.builder() - .jsonSchema( - ResponseFormatJsonSchema.JsonSchema.builder() + return schema +} + +/** + * Builds a text configuration with its format set to a JSON schema derived from the structure of an + * arbitrary Java class. + */ +@JvmSynthetic +internal fun textConfigFromClass( + type: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, +): ResponseTextConfig = + ResponseTextConfig.builder() + .format( + ResponseFormatTextJsonSchemaConfig.builder() .name("json-schema-from-${type.simpleName}") - .schema(JsonValue.fromJsonNode(schema)) + .schema( + ResponseFormatTextJsonSchemaConfig.Schema.builder() + .additionalProperties( + extractAndValidateSchema(type, localValidation) + .fields() + .asSequence() + .associate { it.key to JsonValue.fromJsonNode(it.value) } + ) + .build() + ) // Ensure the model's output strictly adheres to this JSON schema. This is the // essential "ON switch" for Structured Outputs. .strict(true) .build() ) .build() -} +/** + * Derives a JSON schema from the structure of an arbitrary Java class. + * + * Validation is not performed by this function, as it allows extraction of the schema and + * validation of the schema to be controlled more easily when unit testing, as no exceptions will be + * thrown and any recorded validation errors can be inspected at leisure by the tests. + */ @JvmSynthetic internal fun extractSchema(type: Class): JsonNode { - // Validation is not performed by this function, as it allows extraction of the schema and - // validation of the schema to be controlled more easily when unit testing, as no exceptions - // will be thrown and any recorded validation errors can be inspected at leisure by the tests. val configBuilder = SchemaGeneratorConfigBuilder( com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, @@ -79,10 +125,14 @@ internal fun extractSchema(type: Class): JsonNode { return SchemaGenerator(configBuilder.build()).generateSchema(type) } +/** + * Creates an instance of a Java class using data from a JSON. The JSON data should conform to the + * JSON schema previously extracted from the Java class. + */ @JvmSynthetic -internal fun fromJson(json: String, type: Class): T = +internal fun responseTypeFromJson(json: String, responseType: Class): T = try { - MAPPER.readValue(json, type) + MAPPER.readValue(json, responseType) } catch (e: Exception) { // The JSON document is included in the exception message to aid diagnosis of the problem. // It is the responsibility of the SDK user to ensure that exceptions that may contain diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index d012ff94..85d2582f 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -1299,11 +1299,12 @@ private constructor( } /** - * Sets response format to a JSON schema derived from the structure of the given class. This - * changes the builder to a type-safe [StructuredChatCompletionCreateParams.Builder] that - * will build a [StructuredChatCompletionCreateParams] instance when `build()` is called. + * Sets the response format to a JSON schema derived from the structure of the given class. + * This changes the builder to a type-safe [StructuredChatCompletionCreateParams.Builder] + * that will build a [StructuredChatCompletionCreateParams] instance when `build()` is + * called. * - * @param responseFormat A class from which a JSON schema will be derived to define the + * @param responseType A class from which a JSON schema will be derived to define the * response format. * @param localValidation [com.openai.core.JsonSchemaLocalValidation.YES] (the default) to * validate the JSON schema locally when it is generated by this method to confirm that it @@ -1316,11 +1317,11 @@ private constructor( */ @JvmOverloads fun responseFormat( - responseFormat: Class, + responseType: Class, localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, ) = StructuredChatCompletionCreateParams.builder() - .wrap(responseFormat, this, localValidation) + .wrap(responseType, this, localValidation) /** * This feature is in Beta. If specified, our system will make a best effort to sample diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt index 62cde6f5..872135a4 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt @@ -18,101 +18,101 @@ import java.util.Optional * @param T The type of the class to which the JSON data in the response will be deserialized. */ class StructuredChatCompletion( - @get:JvmName("responseFormat") val responseFormat: Class, - @get:JvmName("chatCompletion") val chatCompletion: ChatCompletion, + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawChatCompletion") val rawChatCompletion: ChatCompletion, ) { /** @see ChatCompletion.id */ - fun id(): String = chatCompletion.id() + fun id(): String = rawChatCompletion.id() private val choices by lazy { - chatCompletion._choices().map { choices -> choices.map { Choice(responseFormat, it) } } + rawChatCompletion._choices().map { choices -> choices.map { Choice(responseType, it) } } } /** @see ChatCompletion.choices */ fun choices(): List> = choices.getRequired("choices") /** @see ChatCompletion.created */ - fun created(): Long = chatCompletion.created() + fun created(): Long = rawChatCompletion.created() /** @see ChatCompletion.model */ - fun model(): String = chatCompletion.model() + fun model(): String = rawChatCompletion.model() /** @see ChatCompletion._object_ */ - fun _object_(): JsonValue = chatCompletion._object_() + fun _object_(): JsonValue = rawChatCompletion._object_() /** @see ChatCompletion.serviceTier */ - fun serviceTier(): Optional = chatCompletion.serviceTier() + fun serviceTier(): Optional = rawChatCompletion.serviceTier() /** @see ChatCompletion.systemFingerprint */ - fun systemFingerprint(): Optional = chatCompletion.systemFingerprint() + fun systemFingerprint(): Optional = rawChatCompletion.systemFingerprint() /** @see ChatCompletion.usage */ - fun usage(): Optional = chatCompletion.usage() + fun usage(): Optional = rawChatCompletion.usage() /** @see ChatCompletion._id */ - fun _id(): JsonField = chatCompletion._id() + fun _id(): JsonField = rawChatCompletion._id() /** @see ChatCompletion._choices */ fun _choices(): JsonField>> = choices /** @see ChatCompletion._created */ - fun _created(): JsonField = chatCompletion._created() + fun _created(): JsonField = rawChatCompletion._created() /** @see ChatCompletion._model */ - fun _model(): JsonField = chatCompletion._model() + fun _model(): JsonField = rawChatCompletion._model() /** @see ChatCompletion._serviceTier */ - fun _serviceTier(): JsonField = chatCompletion._serviceTier() + fun _serviceTier(): JsonField = rawChatCompletion._serviceTier() /** @see ChatCompletion._systemFingerprint */ - fun _systemFingerprint(): JsonField = chatCompletion._systemFingerprint() + fun _systemFingerprint(): JsonField = rawChatCompletion._systemFingerprint() /** @see ChatCompletion._usage */ - fun _usage(): JsonField = chatCompletion._usage() + fun _usage(): JsonField = rawChatCompletion._usage() /** @see ChatCompletion._additionalProperties */ - fun _additionalProperties(): Map = chatCompletion._additionalProperties() + fun _additionalProperties(): Map = rawChatCompletion._additionalProperties() class Choice internal constructor( - @get:JvmName("responseFormat") val responseFormat: Class, - @get:JvmName("choice") val choice: ChatCompletion.Choice, + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawChoice") val rawChoice: ChatCompletion.Choice, ) { /** @see ChatCompletion.Choice.finishReason */ - fun finishReason(): FinishReason = choice.finishReason() + fun finishReason(): FinishReason = rawChoice.finishReason() /** @see ChatCompletion.Choice.index */ - fun index(): Long = choice.index() + fun index(): Long = rawChoice.index() /** @see ChatCompletion.Choice.logprobs */ - fun logprobs(): Optional = choice.logprobs() + fun logprobs(): Optional = rawChoice.logprobs() /** @see ChatCompletion.Choice._finishReason */ - fun _finishReason(): JsonField = choice._finishReason() + fun _finishReason(): JsonField = rawChoice._finishReason() private val message by lazy { - choice._message().map { StructuredChatCompletionMessage(responseFormat, it) } + rawChoice._message().map { StructuredChatCompletionMessage(responseType, it) } } /** @see ChatCompletion.Choice.message */ fun message(): StructuredChatCompletionMessage = message.getRequired("message") /** @see ChatCompletion.Choice._index */ - fun _index(): JsonField = choice._index() + fun _index(): JsonField = rawChoice._index() /** @see ChatCompletion.Choice._logprobs */ - fun _logprobs(): JsonField = choice._logprobs() + fun _logprobs(): JsonField = rawChoice._logprobs() /** @see ChatCompletion.Choice._message */ fun _message(): JsonField> = message /** @see ChatCompletion.Choice._additionalProperties */ - fun _additionalProperties(): Map = choice._additionalProperties() + fun _additionalProperties(): Map = rawChoice._additionalProperties() /** @see ChatCompletion.Choice.validate */ fun validate(): Choice = apply { message().validate() - choice.validate() + rawChoice.validate() } /** @see ChatCompletion.Choice.isValid */ @@ -130,22 +130,22 @@ class StructuredChatCompletion( } return other is Choice<*> && - responseFormat == other.responseFormat && - choice == other.choice + responseType == other.responseType && + rawChoice == other.rawChoice } - private val hashCode: Int by lazy { Objects.hash(responseFormat, choice) } + private val hashCode: Int by lazy { Objects.hash(responseType, rawChoice) } override fun hashCode(): Int = hashCode override fun toString() = - "${javaClass.simpleName}{responseFormat=$responseFormat, choice=$choice}" + "${javaClass.simpleName}{responseType=$responseType, rawChoice=$rawChoice}" } /** @see ChatCompletion.validate */ fun validate() = apply { choices().forEach { it.validate() } - chatCompletion.validate() + rawChatCompletion.validate() } /** @see ChatCompletion.isValid */ @@ -163,14 +163,14 @@ class StructuredChatCompletion( } return other is StructuredChatCompletion<*> && - responseFormat == other.responseFormat && - chatCompletion == other.chatCompletion + responseType == other.responseType && + rawChatCompletion == other.rawChatCompletion } - private val hashCode: Int by lazy { Objects.hash(responseFormat, chatCompletion) } + private val hashCode: Int by lazy { Objects.hash(responseType, rawChatCompletion) } override fun hashCode(): Int = hashCode override fun toString() = - "${javaClass.simpleName}{responseFormat=$responseFormat, chatCompletion=$chatCompletion}" + "${javaClass.simpleName}{responseType=$responseType, rawChatCompletion=$rawChatCompletion}" } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt index 4f6a3a63..73741b58 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -4,9 +4,9 @@ import com.openai.core.JsonField import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.checkRequired -import com.openai.core.fromClass import com.openai.core.http.Headers import com.openai.core.http.QueryParams +import com.openai.core.responseFormatFromClass import com.openai.models.ChatModel import com.openai.models.ReasoningEffort import java.util.Objects @@ -14,16 +14,16 @@ import java.util.Optional /** * A wrapper for [ChatCompletionCreateParams] that provides a type-safe [Builder] that can record - * the type of the [responseFormat] used to derive a JSON schema from an arbitrary class when using - * the _Structured Outputs_ feature. When a JSON response is received, it is deserialized to am - * instance of that type. See the SDK documentation for more details on _Structured Outputs_. + * the [responseType] used to derive a JSON schema from an arbitrary class when using the + * _Structured Outputs_ feature. When a JSON response is received, it is deserialized to am instance + * of that type. See the SDK documentation for more details on _Structured Outputs_. * * @param T The type of the class that will be used to derive the JSON schema in the request and to * which the JSON response will be deserialized. */ class StructuredChatCompletionCreateParams internal constructor( - @get:JvmName("responseFormat") val responseFormat: Class, + @get:JvmName("responseType") val responseType: Class, /** * The raw, underlying chat completion create parameters wrapped by this structured instance of * the parameters. @@ -36,19 +36,19 @@ internal constructor( } class Builder internal constructor() { - private var responseFormat: Class? = null + private var responseType: Class? = null private var paramsBuilder = ChatCompletionCreateParams.builder() @JvmSynthetic internal fun wrap( - responseFormat: Class, + responseType: Class, paramsBuilder: ChatCompletionCreateParams.Builder, localValidation: JsonSchemaLocalValidation, ) = apply { - this.responseFormat = responseFormat + this.responseType = responseType this.paramsBuilder = paramsBuilder // Convert the class to a JSON schema and apply it to the delegate `Builder`. - responseFormat(responseFormat, localValidation) + responseFormat(responseType, localValidation) } /** Injects a given `ChatCompletionCreateParams.Builder`. For use only when testing. */ @@ -407,11 +407,11 @@ internal constructor( */ @JvmOverloads fun responseFormat( - responseFormat: Class, + responseType: Class, localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, ) = apply { - this.responseFormat = responseFormat - paramsBuilder.responseFormat(fromClass(responseFormat, localValidation)) + this.responseType = responseType + paramsBuilder.responseFormat(responseFormatFromClass(responseType, localValidation)) } /** @see ChatCompletionCreateParams.Builder.seed */ @@ -741,7 +741,7 @@ internal constructor( */ fun build() = StructuredChatCompletionCreateParams( - checkRequired("responseFormat", responseFormat), + checkRequired("responseType", responseType), paramsBuilder.build(), ) } @@ -752,12 +752,14 @@ internal constructor( } return other is StructuredChatCompletionCreateParams<*> && - responseFormat == other.responseFormat && + responseType == other.responseType && rawParams == other.rawParams } - override fun hashCode(): Int = Objects.hash(responseFormat, rawParams) + private val hashCode: Int by lazy { Objects.hash(responseType, rawParams) } + + override fun hashCode(): Int = hashCode override fun toString() = - "${javaClass.simpleName}{responseFormat=$responseFormat, params=$rawParams}" + "${javaClass.simpleName}{responseType=$responseType, rawParams=$rawParams}" } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt index b833dd47..27e8a3c7 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt @@ -2,7 +2,7 @@ package com.openai.models.chat.completions import com.openai.core.JsonField import com.openai.core.JsonValue -import com.openai.core.fromJson +import com.openai.core.responseTypeFromJson import com.openai.models.chat.completions.ChatCompletionMessage.FunctionCall import java.util.Objects import java.util.Optional @@ -17,69 +17,64 @@ import java.util.Optional */ class StructuredChatCompletionMessage internal constructor( - @get:JvmName("responseFormat") val responseFormat: Class, - @get:JvmName("chatCompletionMessage") val chatCompletionMessage: ChatCompletionMessage, + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawMessage") val rawMessage: ChatCompletionMessage, ) { private val content: JsonField by lazy { - chatCompletionMessage._content().map { fromJson(it, responseFormat) } + rawMessage._content().map { responseTypeFromJson(it, responseType) } } /** @see ChatCompletionMessage.content */ fun content(): Optional = content.getOptional("content") /** @see ChatCompletionMessage.refusal */ - fun refusal(): Optional = chatCompletionMessage.refusal() + fun refusal(): Optional = rawMessage.refusal() /** @see ChatCompletionMessage._role */ - fun _role(): JsonValue = chatCompletionMessage._role() + fun _role(): JsonValue = rawMessage._role() /** @see ChatCompletionMessage.annotations */ - fun annotations(): Optional> = - chatCompletionMessage.annotations() + fun annotations(): Optional> = rawMessage.annotations() /** @see ChatCompletionMessage.audio */ - fun audio(): Optional = chatCompletionMessage.audio() + fun audio(): Optional = rawMessage.audio() /** @see ChatCompletionMessage.functionCall */ - @Deprecated("deprecated") - fun functionCall(): Optional = chatCompletionMessage.functionCall() + @Deprecated("deprecated") fun functionCall(): Optional = rawMessage.functionCall() /** @see ChatCompletionMessage.toolCalls */ - fun toolCalls(): Optional> = - chatCompletionMessage.toolCalls() + fun toolCalls(): Optional> = rawMessage.toolCalls() /** @see ChatCompletionMessage._content */ fun _content(): JsonField = content /** @see ChatCompletionMessage._refusal */ - fun _refusal(): JsonField = chatCompletionMessage._refusal() + fun _refusal(): JsonField = rawMessage._refusal() /** @see ChatCompletionMessage._annotations */ fun _annotations(): JsonField> = - chatCompletionMessage._annotations() + rawMessage._annotations() /** @see ChatCompletionMessage._audio */ - fun _audio(): JsonField = chatCompletionMessage._audio() + fun _audio(): JsonField = rawMessage._audio() /** @see ChatCompletionMessage._functionCall */ @Deprecated("deprecated") - fun _functionCall(): JsonField = chatCompletionMessage._functionCall() + fun _functionCall(): JsonField = rawMessage._functionCall() /** @see ChatCompletionMessage._toolCalls */ - fun _toolCalls(): JsonField> = - chatCompletionMessage._toolCalls() + fun _toolCalls(): JsonField> = rawMessage._toolCalls() /** @see ChatCompletionMessage._additionalProperties */ - fun _additionalProperties(): Map = - chatCompletionMessage._additionalProperties() + fun _additionalProperties(): Map = rawMessage._additionalProperties() /** @see ChatCompletionMessage.validate */ // `content()` is not included in the validation by the delegate method, so just call it. - fun validate(): ChatCompletionMessage = chatCompletionMessage.validate() + fun validate(): ChatCompletionMessage = rawMessage.validate() /** @see ChatCompletionMessage.isValid */ - fun isValid(): Boolean = chatCompletionMessage.isValid() + fun isValid(): Boolean = rawMessage.isValid() override fun equals(other: Any?): Boolean { if (this === other) { @@ -87,14 +82,14 @@ internal constructor( } return other is StructuredChatCompletionMessage<*> && - responseFormat == other.responseFormat && - chatCompletionMessage == other.chatCompletionMessage + responseType == other.responseType && + rawMessage == other.rawMessage } - private val hashCode: Int by lazy { Objects.hash(responseFormat, chatCompletionMessage) } + private val hashCode: Int by lazy { Objects.hash(responseType, rawMessage) } override fun hashCode(): Int = hashCode override fun toString() = - "${javaClass.simpleName}{responseFormat=$responseFormat, chatCompletionMessage=$chatCompletionMessage}" + "${javaClass.simpleName}{responseType=$responseType, rawMessage=$rawMessage}" } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt index 6b512ed9..345f6a93 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt @@ -19,6 +19,7 @@ import com.openai.core.Enum import com.openai.core.ExcludeMissing import com.openai.core.JsonField import com.openai.core.JsonMissing +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.Params import com.openai.core.allMaxBy @@ -784,6 +785,28 @@ private constructor( */ fun text(text: JsonField) = apply { body.text(text) } + /** + * Sets the text configuration's format to a JSON schema derived from the structure of the + * given class. This changes the builder to a type-safe + * [StructuredResponseCreateParams.Builder] that will build a + * [StructuredResponseCreateParams] instance when `build()` is called. + * + * @param responseType A class from which a JSON schema will be derived to define the text + * configuration's format. + * @param localValidation [JsonSchemaLocalValidation.YES] (the default) to validate the JSON + * schema locally when it is generated by this method to confirm that it adheres to the + * requirements and restrictions on JSON schemas imposed by the OpenAI specification; or + * [JsonSchemaLocalValidation.NO] to skip local validation and rely only on remote + * validation. See the SDK documentation for more details. + * @throws IllegalArgumentException If local validation is enabled, but it fails because a + * valid JSON schema cannot be derived from the given class. + */ + @JvmOverloads + fun text( + responseType: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = StructuredResponseCreateParams.builder().wrap(responseType, this, localValidation) + /** * How the model should select which tool (or tools) to use when generating a response. See * the `tools` parameter to see how to specify which tools the model can call. diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt new file mode 100644 index 00000000..d71450cf --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt @@ -0,0 +1,230 @@ +package com.openai.models.responses + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.Reasoning +import com.openai.models.ResponsesModel +import java.util.Objects +import java.util.Optional + +/** + * A wrapper for [Response] that provides type-safe access to the [output] when using the + * _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary class. + * See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the response will be deserialized. + */ +class StructuredResponse( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawResponse") val rawResponse: Response, +) { + /** @see Response.id */ + fun id(): String = rawResponse.id() + + /** @see Response.createdAt */ + fun createdAt(): Double = rawResponse.createdAt() + + /** @see Response.error */ + fun error(): Optional = rawResponse.error() + + /** @see Response.incompleteDetails */ + fun incompleteDetails(): Optional = rawResponse.incompleteDetails() + + /** @see Response.instructions */ + fun instructions(): Optional = rawResponse.instructions() + + /** @see Response.metadata */ + fun metadata(): Optional = rawResponse.metadata() + + /** @see Response.model */ + fun model(): ResponsesModel = rawResponse.model() + + /** @see Response._object_ */ + fun _object_(): JsonValue = rawResponse._object_() + + private val output by lazy { + rawResponse._output().map { outputs -> + outputs.map { StructuredResponseOutputItem(responseType, it) } + } + } + + /** @see Response.output */ + fun output(): List> = output.getRequired("output") + + /** @see Response.parallelToolCalls */ + fun parallelToolCalls(): Boolean = rawResponse.parallelToolCalls() + + /** @see Response.temperature */ + fun temperature(): Optional = rawResponse.temperature() + + /** @see Response.toolChoice */ + fun toolChoice(): Response.ToolChoice = rawResponse.toolChoice() + + /** @see Response.tools */ + fun tools(): List = rawResponse.tools() + + /** @see Response.topP */ + fun topP(): Optional = rawResponse.topP() + + /** @see Response.maxOutputTokens */ + fun maxOutputTokens(): Optional = rawResponse.maxOutputTokens() + + /** @see Response.previousResponseId */ + fun previousResponseId(): Optional = rawResponse.previousResponseId() + + /** @see Response.reasoning */ + fun reasoning(): Optional = rawResponse.reasoning() + + /** @see Response.serviceTier */ + fun serviceTier(): Optional = rawResponse.serviceTier() + + /** @see Response.status */ + fun status(): Optional = rawResponse.status() + + /** @see Response.text */ + fun text(): Optional = rawResponse.text() + + /** @see Response.truncation */ + fun truncation(): Optional = rawResponse.truncation() + + /** @see Response.usage */ + fun usage(): Optional = rawResponse.usage() + + /** @see Response.user */ + fun user(): Optional = rawResponse.user() + + /** @see Response._id */ + fun _id(): JsonField = rawResponse._id() + + /** @see Response._createdAt */ + fun _createdAt(): JsonField = rawResponse._createdAt() + + /** @see Response._error */ + fun _error(): JsonField = rawResponse._error() + + /** @see Response._incompleteDetails */ + fun _incompleteDetails(): JsonField = + rawResponse._incompleteDetails() + + /** @see Response._instructions */ + fun _instructions(): JsonField = rawResponse._instructions() + + /** @see Response._metadata */ + fun _metadata(): JsonField = rawResponse._metadata() + + /** @see Response._model */ + fun _model(): JsonField = rawResponse._model() + + /** @see Response._output */ + fun _output(): JsonField>> = output + + /** @see Response._parallelToolCalls */ + fun _parallelToolCalls(): JsonField = rawResponse._parallelToolCalls() + + /** @see Response._temperature */ + fun _temperature(): JsonField = rawResponse._temperature() + + /** @see Response._toolChoice */ + fun _toolChoice(): JsonField = rawResponse._toolChoice() + + /** @see Response._tools */ + fun _tools(): JsonField> = rawResponse._tools() + + /** @see Response._topP */ + fun _topP(): JsonField = rawResponse._topP() + + /** @see Response._maxOutputTokens */ + fun _maxOutputTokens(): JsonField = rawResponse._maxOutputTokens() + + /** @see Response._previousResponseId */ + fun _previousResponseId(): JsonField = rawResponse._previousResponseId() + + /** @see Response._reasoning */ + fun _reasoning(): JsonField = rawResponse._reasoning() + + /** @see Response._serviceTier */ + fun _serviceTier(): JsonField = rawResponse._serviceTier() + + /** @see Response._status */ + fun _status(): JsonField = rawResponse._status() + + /** @see Response._text */ + fun _text(): JsonField = rawResponse._text() + + /** @see Response._truncation */ + fun _truncation(): JsonField = rawResponse._truncation() + + /** @see Response._usage */ + fun _usage(): JsonField = rawResponse._usage() + + /** @see Response._user */ + fun _user(): JsonField = rawResponse._user() + + /** @see Response._additionalProperties */ + fun _additionalProperties(): Map = rawResponse._additionalProperties() + + private var validated: Boolean = false + + /** @see Response.validate */ + fun validate(): StructuredResponse = apply { + if (validated) { + return@apply + } + + id() + createdAt() + error().ifPresent { it.validate() } + incompleteDetails().ifPresent { it.validate() } + instructions() + metadata().ifPresent { it.validate() } + model().validate() + _object_().let { + if (it != JsonValue.from("response")) { + throw OpenAIInvalidDataException("'object_' is invalid, received $it") + } + } + // `output()` is a different type to that in the delegate class. + output().forEach { it.validate() } + parallelToolCalls() + temperature() + toolChoice().validate() + tools().forEach { it.validate() } + topP() + maxOutputTokens() + previousResponseId() + reasoning().ifPresent { it.validate() } + serviceTier().ifPresent { it.validate() } + status().ifPresent { it.validate() } + text().ifPresent { it.validate() } + truncation().ifPresent { it.validate() } + usage().ifPresent { it.validate() } + user() + validated = true + } + + /** @see Response.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (e: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + return other is StructuredResponse<*> && + responseType == other.responseType && + rawResponse == other.rawResponse + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawResponse) } + + override fun hashCode(): Int = hashCode + + override fun toString(): String = + "${javaClass.simpleName}{responseType=$responseType, rawResponse=$rawResponse}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt new file mode 100644 index 00000000..f2adad97 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt @@ -0,0 +1,528 @@ +package com.openai.models.responses + +import com.openai.core.JsonField +import com.openai.core.JsonSchemaLocalValidation +import com.openai.core.JsonValue +import com.openai.core.checkRequired +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.core.textConfigFromClass +import com.openai.models.ChatModel +import com.openai.models.Reasoning +import com.openai.models.ResponsesModel +import java.util.Objects +import java.util.Optional + +/** + * A wrapper for [ResponseCreateParams] that provides a type-safe [Builder] that can record the + * [responseType] used to derive a JSON schema from an arbitrary class when using the _Structured + * Outputs_ feature. When a JSON response is received, it is deserialized to am instance of that + * type. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class that will be used to derive the JSON schema in the request and to + * which the JSON response will be deserialized. + */ +class StructuredResponseCreateParams( + @get:JvmName("responseType") val responseType: Class, + /** + * The raw, underlying response create parameters wrapped by this structured instance of the + * parameters. + */ + @get:JvmName("rawParams") val rawParams: ResponseCreateParams, +) { + + companion object { + /** @see ResponseCreateParams.builder */ + @JvmStatic fun builder() = Builder() + } + + class Builder internal constructor() { + private var responseType: Class? = null + private var paramsBuilder = ResponseCreateParams.builder() + + @JvmSynthetic + internal fun wrap( + responseType: Class, + paramsBuilder: ResponseCreateParams.Builder, + localValidation: JsonSchemaLocalValidation, + ) = apply { + this.responseType = responseType + this.paramsBuilder = paramsBuilder + text(responseType, localValidation) + } + + /** Injects a given `ResponseCreateParams.Builder`. For use only when testing. */ + @JvmSynthetic + internal fun inject(paramsBuilder: ResponseCreateParams.Builder) = apply { + this.paramsBuilder = paramsBuilder + } + + // TODO: Probably not correct, as text config could be overwritten. + /** @see ResponseCreateParams.Builder.body */ + fun body(body: ResponseCreateParams.Body) = apply { paramsBuilder.body(body) } + + /** @see ResponseCreateParams.Builder.input */ + fun input(input: ResponseCreateParams.Input) = apply { paramsBuilder.input(input) } + + /** @see ResponseCreateParams.Builder.input */ + fun input(input: JsonField) = apply { + paramsBuilder.input(input) + } + + /** @see ResponseCreateParams.Builder.input */ + fun input(text: String) = apply { paramsBuilder.input(text) } + + /** @see ResponseCreateParams.Builder.inputOfResponse */ + fun inputOfResponse(response: List) = apply { + paramsBuilder.inputOfResponse(response) + } + + /** @see ResponseCreateParams.Builder.model */ + fun model(model: ResponsesModel) = apply { paramsBuilder.model(model) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(model: JsonField) = apply { paramsBuilder.model(model) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(string: String) = apply { paramsBuilder.model(string) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(chat: ChatModel) = apply { paramsBuilder.model(chat) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(only: ResponsesModel.ResponsesOnlyModel) = apply { paramsBuilder.model(only) } + + /** @see ResponseCreateParams.Builder.include */ + fun include(include: List?) = apply { paramsBuilder.include(include) } + + /** @see ResponseCreateParams.Builder.include */ + fun include(include: Optional>) = apply { + paramsBuilder.include(include) + } + + /** @see ResponseCreateParams.Builder.include */ + fun include(include: JsonField>) = apply { + paramsBuilder.include(include) + } + + /** @see ResponseCreateParams.Builder.addInclude */ + fun addInclude(include: ResponseIncludable) = apply { paramsBuilder.addInclude(include) } + + /** @see ResponseCreateParams.Builder.instructions */ + fun instructions(instructions: String?) = apply { paramsBuilder.instructions(instructions) } + + /** @see ResponseCreateParams.Builder.instructions */ + fun instructions(instructions: Optional) = apply { + paramsBuilder.instructions(instructions) + } + + /** @see ResponseCreateParams.Builder.instructions */ + fun instructions(instructions: JsonField) = apply { + paramsBuilder.instructions(instructions) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: Long?) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: Long) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: Optional) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: JsonField) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.metadata */ + fun metadata(metadata: ResponseCreateParams.Metadata?) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ResponseCreateParams.Builder.metadata */ + fun metadata(metadata: Optional) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ResponseCreateParams.Builder.metadata */ + fun metadata(metadata: JsonField) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Boolean?) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Boolean) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Optional) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: JsonField) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.previousResponseId */ + fun previousResponseId(previousResponseId: String?) = apply { + paramsBuilder.previousResponseId(previousResponseId) + } + + /** @see ResponseCreateParams.Builder.previousResponseId */ + fun previousResponseId(previousResponseId: Optional) = apply { + paramsBuilder.previousResponseId(previousResponseId) + } + + /** @see ResponseCreateParams.Builder.previousResponseId */ + fun previousResponseId(previousResponseId: JsonField) = apply { + paramsBuilder.previousResponseId(previousResponseId) + } + + /** @see ResponseCreateParams.Builder.reasoning */ + fun reasoning(reasoning: Reasoning?) = apply { paramsBuilder.reasoning(reasoning) } + + /** @see ResponseCreateParams.Builder.reasoning */ + fun reasoning(reasoning: Optional) = apply { paramsBuilder.reasoning(reasoning) } + + /** @see ResponseCreateParams.Builder.reasoning */ + fun reasoning(reasoning: JsonField) = apply { + paramsBuilder.reasoning(reasoning) + } + + /** @see ResponseCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: ResponseCreateParams.ServiceTier?) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ResponseCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: Optional) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ResponseCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: JsonField) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: Boolean?) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: Boolean) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: Optional) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: JsonField) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: Double?) = apply { paramsBuilder.temperature(temperature) } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: Double) = apply { paramsBuilder.temperature(temperature) } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: Optional) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: JsonField) = apply { + paramsBuilder.temperature(temperature) + } + + /** + * Sets the text configuration's format to a JSON schema derived from the structure of the + * given class. + * + * @see ResponseCreateParams.Builder.text + */ + @JvmOverloads + fun text( + responseType: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { + this.responseType = responseType + paramsBuilder.text(textConfigFromClass(responseType, localValidation)) + } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: ResponseCreateParams.ToolChoice) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: JsonField) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(options: ToolChoiceOptions) = apply { paramsBuilder.toolChoice(options) } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(types: ToolChoiceTypes) = apply { paramsBuilder.toolChoice(types) } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(function: ToolChoiceFunction) = apply { paramsBuilder.toolChoice(function) } + + /** @see ResponseCreateParams.Builder.tools */ + fun tools(tools: List) = apply { paramsBuilder.tools(tools) } + + /** @see ResponseCreateParams.Builder.tools */ + fun tools(tools: JsonField>) = apply { paramsBuilder.tools(tools) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(tool: Tool) = apply { paramsBuilder.addTool(tool) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(fileSearch: FileSearchTool) = apply { paramsBuilder.addTool(fileSearch) } + + /** @see ResponseCreateParams.Builder.addFileSearchTool */ + fun addFileSearchTool(vectorStoreIds: List) = apply { + paramsBuilder.addFileSearchTool(vectorStoreIds) + } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(function: FunctionTool) = apply { paramsBuilder.addTool(function) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(webSearch: WebSearchTool) = apply { paramsBuilder.addTool(webSearch) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(computerUsePreview: ComputerTool) = apply { + paramsBuilder.addTool(computerUsePreview) + } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: Double?) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: Double) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: Optional) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: JsonField) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.truncation */ + fun truncation(truncation: ResponseCreateParams.Truncation?) = apply { + paramsBuilder.truncation(truncation) + } + + /** @see ResponseCreateParams.Builder.truncation */ + fun truncation(truncation: Optional) = apply { + paramsBuilder.truncation(truncation) + } + + /** @see ResponseCreateParams.Builder.truncation */ + fun truncation(truncation: JsonField) = apply { + paramsBuilder.truncation(truncation) + } + + /** @see ResponseCreateParams.Builder.user */ + fun user(user: String) = apply { paramsBuilder.user(user) } + + /** @see ResponseCreateParams.Builder.user */ + fun user(user: JsonField) = apply { paramsBuilder.user(user) } + + /** @see ResponseCreateParams.Builder.additionalBodyProperties */ + fun additionalBodyProperties(additionalBodyProperties: Map) = apply { + paramsBuilder.additionalBodyProperties(additionalBodyProperties) + } + + /** @see ResponseCreateParams.Builder.putAdditionalBodyProperty */ + fun putAdditionalBodyProperty(key: String, value: JsonValue) = apply { + paramsBuilder.putAdditionalBodyProperty(key, value) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalBodyProperties */ + fun putAllAdditionalBodyProperties(additionalBodyProperties: Map) = + apply { + paramsBuilder.putAllAdditionalBodyProperties(additionalBodyProperties) + } + + /** @see ResponseCreateParams.Builder.removeAdditionalBodyProperty */ + fun removeAdditionalBodyProperty(key: String) = apply { + paramsBuilder.removeAdditionalBodyProperty(key) + } + + /** @see ResponseCreateParams.Builder.removeAllAdditionalBodyProperties */ + fun removeAllAdditionalBodyProperties(keys: Set) = apply { + paramsBuilder.removeAllAdditionalBodyProperties(keys) + } + + /** @see ResponseCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.putAdditionalHeader */ + fun putAdditionalHeader(name: String, value: String) = apply { + paramsBuilder.putAdditionalHeader(name, value) + } + + /** @see ResponseCreateParams.Builder.putAdditionalHeaders */ + fun putAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.putAdditionalHeaders(name, values) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, value: String) = apply { + paramsBuilder.replaceAdditionalHeaders(name, value) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalHeaders(name, values) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.removeAdditionalHeaders */ + fun removeAdditionalHeaders(name: String) = apply { + paramsBuilder.removeAdditionalHeaders(name) + } + + /** @see ResponseCreateParams.Builder.removeAllAdditionalHeaders */ + fun removeAllAdditionalHeaders(names: Set) = apply { + paramsBuilder.removeAllAdditionalHeaders(names) + } + + /** @see ResponseCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: Map>) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.putAdditionalQueryParam */ + fun putAdditionalQueryParam(key: String, value: String) = apply { + paramsBuilder.putAdditionalQueryParam(key, value) + } + + /** @see ResponseCreateParams.Builder.putAdditionalQueryParams */ + fun putAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.putAdditionalQueryParams(key, values) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, value: String) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, value) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, values) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.removeAdditionalQueryParams */ + fun removeAdditionalQueryParams(key: String) = apply { + paramsBuilder.removeAdditionalQueryParams(key) + } + + /** @see ResponseCreateParams.Builder.removeAllAdditionalQueryParams */ + fun removeAllAdditionalQueryParams(keys: Set) = apply { + paramsBuilder.removeAllAdditionalQueryParams(keys) + } + + /** + * Returns an immutable instance of [ResponseCreateParams]. + * + * Further updates to this [Builder] will not mutate the returned instance. + * + * The following fields are required: + * ```java + * .input() + * .model() + * .text() + * ``` + * + * @throws IllegalStateException if any required field is unset. + */ + fun build() = + StructuredResponseCreateParams( + checkRequired("responseType", responseType), + paramsBuilder.build(), + ) + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredResponseCreateParams<*> && + responseType == other.responseType && + rawParams == other.rawParams + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawParams) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawParams=$rawParams}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputItem.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputItem.kt new file mode 100644 index 00000000..39cd6b6b --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputItem.kt @@ -0,0 +1,189 @@ +package com.openai.models.responses + +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import java.util.Objects +import java.util.Optional +import kotlin.jvm.optionals.getOrElse +import kotlin.jvm.optionals.getOrNull + +/** + * A wrapper for [ResponseOutputItem] that provides type-safe access to the [message] when using the + * _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary class. + * See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the content will be deserialized when + * [message] is called. + */ +class StructuredResponseOutputItem( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawOutputItem") val rawOutputItem: ResponseOutputItem, +) { + private val message by lazy { + rawOutputItem.message().map { StructuredResponseOutputMessage(responseType, it) } + } + + /** @see ResponseOutputItem.message */ + fun message(): Optional> = message + + /** @see ResponseOutputItem.fileSearchCall */ + fun fileSearchCall(): Optional = rawOutputItem.fileSearchCall() + + /** @see ResponseOutputItem.functionCall */ + fun functionCall(): Optional = rawOutputItem.functionCall() + + /** @see ResponseOutputItem.webSearchCall */ + fun webSearchCall(): Optional = rawOutputItem.webSearchCall() + + /** @see ResponseOutputItem.computerCall */ + fun computerCall(): Optional = rawOutputItem.computerCall() + + /** @see ResponseOutputItem.reasoning */ + fun reasoning(): Optional = rawOutputItem.reasoning() + + /** @see ResponseOutputItem.isMessage */ + fun isMessage(): Boolean = message().isPresent + + /** @see ResponseOutputItem.isFileSearchCall */ + fun isFileSearchCall(): Boolean = rawOutputItem.isFileSearchCall() + + /** @see ResponseOutputItem.isFunctionCall */ + fun isFunctionCall(): Boolean = rawOutputItem.isFunctionCall() + + /** @see ResponseOutputItem.isWebSearchCall */ + fun isWebSearchCall(): Boolean = rawOutputItem.isWebSearchCall() + + /** @see ResponseOutputItem.isComputerCall */ + fun isComputerCall(): Boolean = rawOutputItem.isComputerCall() + + /** @see ResponseOutputItem.isReasoning */ + fun isReasoning(): Boolean = rawOutputItem.isReasoning() + + /** @see ResponseOutputItem.asMessage */ + fun asMessage(): StructuredResponseOutputMessage = + message.getOrElse { + // Same behavior as `com.openai.core.getOrThrow` used by the delegate class. + throw OpenAIInvalidDataException("`message` is not present") + } + + /** @see ResponseOutputItem.asFileSearchCall */ + fun asFileSearchCall(): ResponseFileSearchToolCall = rawOutputItem.asFileSearchCall() + + /** @see ResponseOutputItem.asFunctionCall */ + fun asFunctionCall(): ResponseFunctionToolCall = rawOutputItem.asFunctionCall() + + /** @see ResponseOutputItem.asWebSearchCall */ + fun asWebSearchCall(): ResponseFunctionWebSearch = rawOutputItem.asWebSearchCall() + + /** @see ResponseOutputItem.asComputerCall */ + fun asComputerCall(): ResponseComputerToolCall = rawOutputItem.asComputerCall() + + /** @see ResponseOutputItem.asReasoning */ + fun asReasoning(): ResponseReasoningItem = rawOutputItem.asReasoning() + + /** @see ResponseOutputItem._json */ + fun _json(): Optional = rawOutputItem._json() + + /** @see ResponseOutputItem.accept */ + fun accept(visitor: Visitor): R = + when { + isMessage() -> visitor.visitMessage(asMessage()) + isFileSearchCall() -> visitor.visitFileSearchCall(asFileSearchCall()) + isFunctionCall() -> visitor.visitFunctionCall(asFunctionCall()) + isWebSearchCall() -> visitor.visitWebSearchCall(asWebSearchCall()) + isComputerCall() -> visitor.visitComputerCall(asComputerCall()) + isReasoning() -> visitor.visitReasoning(asReasoning()) + else -> visitor.unknown(_json().getOrNull()) + } + + private var validated: Boolean = false + + /** @see ResponseOutputItem.validate */ + fun validate(): StructuredResponseOutputItem = apply { + if (validated) { + return@apply + } + + accept( + object : Visitor { + override fun visitMessage(message: StructuredResponseOutputMessage) { + message.validate() + } + + override fun visitFileSearchCall(fileSearchCall: ResponseFileSearchToolCall) { + fileSearchCall.validate() + } + + override fun visitFunctionCall(functionCall: ResponseFunctionToolCall) { + functionCall.validate() + } + + override fun visitWebSearchCall(webSearchCall: ResponseFunctionWebSearch) { + webSearchCall.validate() + } + + override fun visitComputerCall(computerCall: ResponseComputerToolCall) { + computerCall.validate() + } + + override fun visitReasoning(reasoning: ResponseReasoningItem) { + reasoning.validate() + } + } + ) + validated = true + } + + /** @see ResponseOutputItem.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredResponseOutputItem<*> && + responseType == other.responseType && + rawOutputItem == other.rawOutputItem + } + + override fun hashCode(): Int = Objects.hash(responseType, rawOutputItem) + + override fun toString(): String = + "${javaClass.simpleName}{responseType=$responseType, rawOutputItem=$rawOutputItem}" + + /** @see ResponseOutputItem.Visitor */ + // In keeping with the delegate's `Visitor`, `T` is used to refer to the return type of each + // function. `R` (for "Response") is used to refer to the response type, which is otherwise + // named `T` in the outer class, but confusion here is probably preferable to confusion there. + interface Visitor { + /** @see ResponseOutputItem.Visitor.visitMessage */ + fun visitMessage(message: StructuredResponseOutputMessage): T + + /** @see ResponseOutputItem.Visitor.visitFileSearchCall */ + fun visitFileSearchCall(fileSearchCall: ResponseFileSearchToolCall): T + + /** @see ResponseOutputItem.Visitor.visitFunctionCall */ + fun visitFunctionCall(functionCall: ResponseFunctionToolCall): T + + /** @see ResponseOutputItem.Visitor.visitWebSearchCall */ + fun visitWebSearchCall(webSearchCall: ResponseFunctionWebSearch): T + + /** @see ResponseOutputItem.Visitor.visitComputerCall */ + fun visitComputerCall(computerCall: ResponseComputerToolCall): T + + /** @see ResponseOutputItem.Visitor.visitReasoning */ + fun visitReasoning(reasoning: ResponseReasoningItem): T + + /** @see ResponseOutputItem.Visitor.unknown */ + fun unknown(json: JsonValue?): T { + throw OpenAIInvalidDataException("Unknown ResponseOutputItem: $json") + } + } +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt new file mode 100644 index 00000000..e131df2c --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt @@ -0,0 +1,199 @@ +package com.openai.models.responses + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.responseTypeFromJson +import com.openai.errors.OpenAIInvalidDataException +import java.util.Objects +import java.util.Optional +import kotlin.jvm.optionals.getOrElse +import kotlin.jvm.optionals.getOrNull + +/** + * A wrapper for [ResponseOutputMessage] that provides type-safe access to the [content] when using + * the _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary + * class. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the content will be deserialized when + * the output text of the [content] is retrieved. + */ +class StructuredResponseOutputMessage( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawMessage") val rawMessage: ResponseOutputMessage, +) { + /** @see ResponseOutputMessage.id */ + fun id(): String = rawMessage.id() + + private val content by lazy { + rawMessage._content().map { contents -> contents.map { Content(responseType, it) } } + } + + /** @see ResponseOutputMessage.content */ + fun content(): List> = content.getRequired("content") + + /** @see ResponseOutputMessage._role */ + fun _role(): JsonValue = rawMessage._role() + + /** @see ResponseOutputMessage.status */ + fun status(): ResponseOutputMessage.Status = rawMessage.status() + + /** @see ResponseOutputMessage._type */ + fun _type(): JsonValue = rawMessage._type() + + /** @see ResponseOutputMessage._id */ + fun _id(): JsonField = rawMessage._id() + + /** @see ResponseOutputMessage._content */ + fun _content(): JsonField>> = content + + /** @see ResponseOutputMessage._status */ + fun _status(): JsonField = rawMessage._status() + + /** @see ResponseOutputMessage._additionalProperties */ + fun _additionalProperties(): Map = rawMessage._additionalProperties() + + private var validated: Boolean = false + + /** @see ResponseOutputMessage.validate */ + fun validate(): StructuredResponseOutputMessage = apply { + if (validated) { + return@apply + } + + id() + // `content()` is a different type to that in the delegate class. + content().forEach { it.validate() } + _role().let { + if (it != JsonValue.from("assistant")) { + throw OpenAIInvalidDataException("'role' is invalid, received $it") + } + } + status().validate() + _type().let { + if (it != JsonValue.from("message")) { + throw OpenAIInvalidDataException("'type' is invalid, received $it") + } + } + validated = true + } + + /** @see ResponseOutputMessage.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + /** @see ResponseOutputMessage.Content */ + class Content( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawContent") val rawContent: ResponseOutputMessage.Content, + ) { + private val outputText by lazy { + rawContent.outputText().map { responseTypeFromJson(it.text(), responseType) } + } + + /** + * Gets the output text, but deserialized to an instance of the response type class. + * + * @see ResponseOutputMessage.Content.outputText + */ + fun outputText(): Optional = outputText + + /** @see ResponseOutputMessage.Content.refusal */ + fun refusal(): Optional = rawContent.refusal() + + /** @see ResponseOutputMessage.Content.isOutputText */ + // No need to check `outputText`; the delegate can just check the source value is present. + fun isOutputText(): Boolean = rawContent.isOutputText() + + /** @see ResponseOutputMessage.Content.isRefusal */ + fun isRefusal(): Boolean = rawContent.isRefusal() + + /** @see ResponseOutputMessage.Content.asOutputText */ + fun asOutputText(): T = + outputText.getOrElse { + // Same behavior as `com.openai.core.getOrThrow` used by the delegate class. + throw OpenAIInvalidDataException("`outputText` is not present") + } + + /** @see ResponseOutputMessage.Content.asRefusal */ + fun asRefusal(): ResponseOutputRefusal = rawContent.asRefusal() + + /** @see ResponseOutputMessage.Content._json */ + fun _json(): Optional = rawContent._json() + + /** @see ResponseOutputMessage.Content.accept */ + fun accept(visitor: Visitor): R = + when { + outputText.isPresent -> visitor.visitOutputText(outputText.get()) + refusal().isPresent -> visitor.visitRefusal(refusal().get()) + else -> visitor.unknown(_json().getOrNull()) + } + + /** @see ResponseOutputMessage.Content.validate */ + fun validate(): Content = apply { + // The `outputText` object, as it is a user-defined type that is unlikely to have a + // `validate()` function/method, so validate the underlying `ResponseOutputText` from + // which it is derived. That can be done by the delegate class. + rawContent.validate() + } + + /** @see ResponseOutputMessage.Content.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is Content<*> && + rawContent == other.rawContent && + responseType == other.responseType + } + + override fun hashCode(): Int = Objects.hash(rawContent, responseType) + + override fun toString(): String = + "${javaClass.simpleName}{responseType=$responseType, rawContent=$rawContent}" + + /** @see ResponseOutputMessage.Content.Visitor */ + interface Visitor { + /** @see ResponseOutputMessage.Content.Visitor.visitOutputText */ + fun visitOutputText(outputText: T): R + + /** @see ResponseOutputMessage.Content.Visitor.visitRefusal */ + fun visitRefusal(refusal: ResponseOutputRefusal): R + + /** @see ResponseOutputMessage.Content.Visitor.unknown */ + fun unknown(json: JsonValue?): R { + throw OpenAIInvalidDataException("Unknown Content: $json") + } + } + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredResponseOutputMessage<*> && + responseType == other.responseType && + rawMessage == other.rawMessage + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawMessage) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawMessage=$rawMessage}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt index b6e1f962..69570003 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt @@ -12,6 +12,8 @@ import com.openai.models.responses.ResponseCreateParams import com.openai.models.responses.ResponseDeleteParams import com.openai.models.responses.ResponseRetrieveParams import com.openai.models.responses.ResponseStreamEvent +import com.openai.models.responses.StructuredResponse +import com.openai.models.responses.StructuredResponseCreateParams import com.openai.services.blocking.responses.InputItemService interface ResponseService { @@ -42,6 +44,17 @@ interface ResponseService { requestOptions: RequestOptions = RequestOptions.none(), ): Response + /** @see create */ + fun create(params: StructuredResponseCreateParams): StructuredResponse = + create(params, RequestOptions.none()) + + /** @see create */ + fun create( + params: StructuredResponseCreateParams, + requestOptions: RequestOptions = RequestOptions.none(), + ): StructuredResponse = + StructuredResponse(params.responseType, create(params.rawParams, requestOptions)) + /** * Creates a model response. Provide [text](https://platform.openai.com/docs/guides/text) or * [image](https://platform.openai.com/docs/guides/images) inputs to generate diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt index 6985e05e..29792149 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt @@ -65,7 +65,7 @@ interface ChatCompletionService { params: StructuredChatCompletionCreateParams, requestOptions: RequestOptions = RequestOptions.none(), ): StructuredChatCompletion = - StructuredChatCompletion(params.responseFormat, create(params.rawParams, requestOptions)) + StructuredChatCompletion(params.responseType, create(params.rawParams, requestOptions)) /** * **Starting a new project?** We recommend trying diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index 2c1eb885..dd5cdd57 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -447,7 +447,7 @@ internal class StructuredOutputsTest { @Test fun schemaTest_tinyRecursiveSchema() { - @Suppress("unused") class X(val s: String, val x: X) + @Suppress("unused") class X(val s: String, val x: X?) schema = extractSchema(X::class.java) validator.validate(schema) @@ -633,8 +633,7 @@ internal class StructuredOutputsTest { ) validator.validate(schema) - // TODO: Decide if this is the expected behavior, i.e., that it is OK for an "object" schema - // to have no "properties". + // For now, allow that an object may have no properties. Update this if that is revised. assertThat(validator.isValid()).isTrue() } @@ -1415,7 +1414,7 @@ internal class StructuredOutputsTest { fun fromJsonSuccess() { @Suppress("unused") class X(val s: String) - val x = fromJson("{\"s\" : \"hello\"}", X::class.java) + val x = responseTypeFromJson("{\"s\" : \"hello\"}", X::class.java) assertThat(x.s).isEqualTo("hello") } @@ -1425,7 +1424,7 @@ internal class StructuredOutputsTest { @Suppress("unused") class X(val s: String) // Well-formed JSON, but it does not match the schema of class `X`. - assertThatThrownBy { fromJson("{\"wrong\" : \"hello\"}", X::class.java) } + assertThatThrownBy { responseTypeFromJson("{\"wrong\" : \"hello\"}", X::class.java) } .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) .hasMessage("Error parsing JSON: {\"wrong\" : \"hello\"}") } @@ -1435,7 +1434,7 @@ internal class StructuredOutputsTest { @Suppress("unused") class X(val s: String) // Malformed JSON. - assertThatThrownBy { fromJson("{\"truncated", X::class.java) } + assertThatThrownBy { responseTypeFromJson("{\"truncated", X::class.java) } .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) .hasMessage("Error parsing JSON: {\"truncated") } @@ -1444,7 +1443,7 @@ internal class StructuredOutputsTest { fun fromClassEnablesStrictAdherenceToSchema() { @Suppress("unused") class X(val s: String) - val jsonSchema = fromClass(X::class.java) + val jsonSchema = responseFormatFromClass(X::class.java) // The "strict" flag _must_ be set to ensure that the model's output will _always_ conform // to the JSON schema. @@ -1464,7 +1463,7 @@ internal class StructuredOutputsTest { class Z(val y: Y) assertThatNoException().isThrownBy { - fromClass(Z::class.java, JsonSchemaLocalValidation.NO) + responseFormatFromClass(Z::class.java, JsonSchemaLocalValidation.NO) } } @@ -1473,7 +1472,7 @@ internal class StructuredOutputsTest { @Suppress("unused") class X(val s: String) assertThatNoException().isThrownBy { - fromClass(X::class.java, JsonSchemaLocalValidation.YES) + responseFormatFromClass(X::class.java, JsonSchemaLocalValidation.YES) } } @@ -1488,7 +1487,7 @@ internal class StructuredOutputsTest { class Y(val x: X) class Z(val y: Y) - assertThatThrownBy { fromClass(Z::class.java, JsonSchemaLocalValidation.YES) } + assertThatThrownBy { responseFormatFromClass(Z::class.java, JsonSchemaLocalValidation.YES) } .isExactlyInstanceOf(IllegalArgumentException::class.java) .hasMessage( "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + @@ -1509,7 +1508,8 @@ internal class StructuredOutputsTest { class Y(val x: X) class Z(val y: Y) - assertThatThrownBy { fromClass(Z::class.java) } // Use default for `localValidation` flag. + // Use default for `localValidation` flag. + assertThatThrownBy { responseFormatFromClass(Z::class.java) } .isExactlyInstanceOf(IllegalArgumentException::class.java) .hasMessage( "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt new file mode 100644 index 00000000..c62eb8ed --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt @@ -0,0 +1,366 @@ +package com.openai.core + +import java.lang.reflect.Method +import java.util.Optional +import kotlin.reflect.KClass +import kotlin.reflect.KFunction +import kotlin.reflect.KVisibility +import kotlin.reflect.full.declaredFunctions +import kotlin.reflect.jvm.javaMethod +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.fail +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +// Constants for values that can be used in many of the tests as sample input or output values. +// +// Where a function returns `Optional`, `JsonField` or `JsonValue` There is no need to provide +// a value that matches the type ``, a simple `String` value of `"a-string"` will work OK. +internal const val STRING = "a-string" +internal val NULLABLE_STRING: String? = null +internal val OPTIONAL = Optional.of(STRING) +internal val JSON_FIELD = JsonField.of(STRING) +internal val JSON_VALUE = JsonValue.from(STRING) +internal val NULLABLE = null +internal const val BOOLEAN: Boolean = true +internal val NULLABLE_BOOLEAN: Boolean? = null +internal const val LONG: Long = 42L +internal val NULLABLE_LONG: Long? = null +internal const val DOUBLE: Double = 42.0 +internal val NULLABLE_DOUBLE: Double? = null +internal val LIST = listOf(STRING) +internal val SET = setOf(STRING) +internal val MAP = mapOf(STRING to STRING) + +/** + * Defines a test case where a function in a delegator returns a value from a corresponding function + * in a delegate. + */ +internal data class DelegationReadTestCase(val functionName: String, val expectedValue: Any) + +/** + * Defines a test case where a function in a delegator passes its parameters through to a + * corresponding function in a delegate. + */ +// Want `vararg`, so cannot use `data class`. Needs a custom `toString`, anyway. +internal class DelegationWriteTestCase( + val functionName: String, + /** + * The values to pass to the function being tested. If the first input value is `null`, it must + * be the only value. Only the first input value may be `null`, all others must be non-`null`. + * This is not enforced by this class, but is assumed by the related utility functions. + */ + vararg val inputValues: Any?, +) { + /** Gets the string representation that identifies the test function when running JUnit. */ + override fun toString(): String = + "$functionName(${inputValues.joinToString(", ") { + it?.javaClass?.simpleName ?: "null" + }})" +} + +/** A basic class used as the generic type when testing. */ +internal class X(val s: String) { + override fun equals(other: Any?) = other is X && other.s == s + + override fun hashCode() = s.hashCode() +} + +/** + * Checks that all functions in one class have a corresponding function with the same name and + * parameter types in another class. A list of function names that should be allowed as exceptions + * can be given. Non-public functions are ignored, as they are considered to be implementation + * details of each class. + * + * Call this function twice, changing the order of the two classes to ensure that both classes + * contain the same set of functions (barring exceptions), should that be the expectation. + * + * @param subsetClass The class whose functions should be a subset of the functions of the other + * class. + * @param supersetClass The class whose functions should be a superset of the functions of the other + * class. + */ +internal fun checkAllDelegation( + subsetClass: KClass<*>, + supersetClass: KClass<*>, + vararg exceptFunctionNames: String, +) { + assertThat(subsetClass != supersetClass) + .describedAs { "The two classes should not be the same." } + .isTrue + + val subsetFunctions = subsetClass.declaredFunctions + val missingFunctions = mutableListOf>() + + for (subsetFunction in subsetFunctions) { + if (subsetFunction.visibility != KVisibility.PUBLIC) { + continue + } + + if (subsetFunction.name in exceptFunctionNames) { + continue + } + + // Drop the first parameter from each function, as it is the implicit "this" object and has + // the type of the class declaring the function, which will never match. + val supersetFunction = + supersetClass.declaredFunctions.find { + it.name == subsetFunction.name && + it.parameters.drop(1).map { it.type } == + subsetFunction.parameters.drop(1).map { it.type } + } + + if (supersetFunction == null) { + missingFunctions.add(subsetFunction) + } + } + + assertThat(missingFunctions) + .describedAs { + "Function(s) not found in ${supersetClass.simpleName}:\n" + + missingFunctions.joinToString("\n") { " - $it" } + } + .isEmpty() +} + +/** + * Checks that the delegator function calls the corresponding delegate function and no other + * functions on the delegate. The test case defines the function name and the sample return value. + * All functions take no arguments. + */ +internal fun checkOneDelegationRead( + delegator: Any, + mockDelegate: Any, + testCase: DelegationReadTestCase, +) { + // Stub the method in the mock delegate using reflection + val delegateMethod = mockDelegate::class.java.getMethod(testCase.functionName) + `when`(delegateMethod.invoke(mockDelegate)).thenReturn(testCase.expectedValue) + + // Call the corresponding method on the delegator using reflection + val delegatorMethod = delegator::class.java.getMethod(testCase.functionName) + val result = delegatorMethod.invoke(delegator) + + // Verify that the corresponding method on the mock delegate was called exactly once + verify(mockDelegate, times(1)).apply { delegateMethod.invoke(mockDelegate) } + verifyNoMoreInteractions(mockDelegate) + + // Assert that the result matches the expected value + assertThat(result).isEqualTo(testCase.expectedValue) +} + +/** + * Checks that the delegator function calls the corresponding delegate function and no other + * functions on the delegate. The test case defines the function name and sample parameter values. + */ +internal fun checkOneDelegationWrite( + delegator: Any, + mockDelegate: Any, + testCase: DelegationWriteTestCase, +) { + invokeMethod(findDelegationMethod(delegator, testCase), delegator, testCase) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockDelegate, times(1)).apply { + invokeMethod(findDelegationMethod(mockDelegate, testCase), mockDelegate, testCase) + } + verifyNoMoreInteractions(mockDelegate) +} + +private fun invokeMethod(method: Method, target: Any, testCase: DelegationWriteTestCase) { + val numParams = testCase.inputValues.size + val inputValue1 = testCase.inputValues[0] + val inputValue2 = testCase.inputValues.getOrNull(1) + + when (numParams) { + 1 -> method.invoke(target, inputValue1) + 2 -> method.invoke(target, inputValue1, inputValue2) + else -> fail { "Unexpected number of function parameters ($numParams)." } + } +} + +/** + * Finds the java method matching the test case's function name and parameter types in the delegator + * or delegate `target`. + */ +internal fun findDelegationMethod(target: Any, testCase: DelegationWriteTestCase): Method { + val numParams = testCase.inputValues.size + val inputValue1: Any? = testCase.inputValues[0] + val inputValue2 = if (numParams > 1) testCase.inputValues[1] else null + + val method = + when (numParams) { + 1 -> + if (inputValue1 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + ) + } else { + // Only the first parameter may be nullable and only if it is the only + // parameter. If the first parameter is nullable, it will be the only function + // of the same name with a nullable first parameter. To handle the potentially + // nullable first parameter, Kotlin reflection is needed. This allows a function + // `f(Boolean)` to be distinguished from `f(Boolean?)`. For the tests, if the + // parameter type is nullable, the parameter value will always be `null` (if + // not, the function with the nullable parameter would not be matched). + // + // Using Kotlin reflection, the first parameter (zero index) is `this` object, + // so start matching from the second parameter onwards. + target::class + .declaredFunctions + .find { + it.name == testCase.functionName && + it.parameters[1].type.isMarkedNullable + } + ?.javaMethod + } + + 2 -> + if (inputValue1 != null && inputValue2 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + toJavaType(inputValue2.javaClass), + ) + } else { + // There are no instances where there are two parameters and one of them is + // nullable. + fail { "Function $testCase second parameter must not be null." } + } + + else -> fail { "Function $testCase has unsupported number of parameters." } + } ?: fail { "Function $testCase cannot be found in $target." } + + // Using `fail` conditionally above, so the compiler knows the code will not continue and can + // infer that `method` is not null. It cannot do that for `assertThat...isNotNull`. + return method +} + +/** Finds a Java method in a class that matches a method name and a list of parameter types. */ +private fun findJavaMethod( + clazz: Class<*>, + methodName: String, + vararg parameterTypes: Class<*>, +): Method? = + clazz.declaredMethods.firstOrNull { method -> + method.name == methodName && + method.parameterTypes.size == parameterTypes.size && + method.parameterTypes.indices.all { index -> + (parameterTypes[index].isPrimitive && + method.parameterTypes[index] == parameterTypes[index]) || + method.parameterTypes[index].isAssignableFrom(parameterTypes[index]) + } + } + +/** + * Returns the Java type to use when matching type parameters for a Java method. The type is the + * type of the input value that will be used when the method is invoked. For most types, the given + * type is returned. However, if the type represents a Kotlin primitive, it will be converted to a + * Java primitive. This allows matching of methods with parameter types that are non-nullable Kotlin + * primitives. If not translated, methods with parameter types that are nullable Kotlin primitives + * would always be matched instead. + */ +private fun toJavaType(type: Class<*>) = + when (type) { + // This only needs to cover the types used in the test cases. + java.lang.Long::class.java -> java.lang.Long.TYPE + java.lang.Boolean::class.java -> java.lang.Boolean.TYPE + java.lang.Double::class.java -> java.lang.Double.TYPE + else -> type + } + +/** + * Checks that all delegating functions in a delegator class have corresponding unit tests. The + * read-only functions should take no parameters; only return a value. + * + * @param delegatorClass The delegator class whose functions are tested. Every named function in + * this class must be identified in one of the given sources of function names or a failure will + * occur. + * @param delegationTestCases The tests cases that identify the names of delegating functions for + * which parameterized unit tests have been defined. + * @param exceptionalTestedFns The names of delegating functions that are tested separately, not as + * parameterized unit tests. This is usually because they require special handling in the test. + * @param nonDelegatingFns The names of functions that do not perform any delegation and for which + * delegation tests are not required. + */ +internal fun checkAllDelegatorReadFunctionsAreTested( + delegatorClass: KClass<*>, + delegationTestCases: List, + exceptionalTestedFns: Set, + nonDelegatingFns: Set, +) { + val testedFns = delegationTestCases.map { it.functionName }.toSet() + exceptionalTestedFns + val delegatorFunctions = delegatorClass.declaredFunctions + val untestedFunctions = + delegatorFunctions.filter { it.name !in testedFns && it.name !in nonDelegatingFns } + + assertThat(untestedFunctions) + .describedAs( + "Delegation is not tested for function(s):\n" + + untestedFunctions.joinToString("\n") { " - $it" } + ) + .isEmpty() +} + +/** + * Checks that all delegating functions in a delegator class have corresponding unit tests. The + * write-only functions should take parameters and return no value. + * + * @param delegatorClass The delegator class whose functions are tested. Every named function in + * this class must be identified in one of the given sources of function names or a failure will + * occur. + * @param delegationTestCases The tests cases that identify the names of delegating functions for + * which parameterized unit tests have been defined. + * @param exceptionalTestedFns The names of delegating functions that are tested separately, not as + * parameterized unit tests. This is usually because they require special handling in the test. + * @param nonDelegatingFns The names of functions that do not perform any delegation and for which + * delegation tests are not required. + */ +internal fun checkAllDelegatorWriteFunctionsAreTested( + delegatorClass: KClass<*>, + delegationTestCases: List, + exceptionalTestedFns: Set, + nonDelegatingFns: Set, +) { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. There are many overloaded functions, so the + // approach here is to build a list (_not_ a set) of all function names and then "subtract" + // those for which tests are defined and see what remains. For example, there could be eight + // `addMessage` functions, so there must be eight tests defined for functions named `addMessage` + // that will be subtracted from the list of functions matching that name. Parameter types are + // not checked, as that is awkward and probably overkill. Therefore, this scheme is not reliable + // if a function is tested more than once. + val testedFns = + (delegationTestCases.map { it.functionName } + exceptionalTestedFns).toMutableList() + // Only interested in the names of the functions (which may contain duplicates): parameters are + // not matched, so any signatures could be misleading when reporting errors. + val delegatorFns = delegatorClass.declaredFunctions.map { it.name }.toMutableList() + + // Making modifications to the list, so clone it with `toList()` before iterating. + for (fnName in delegatorFns.toList()) { + if (fnName in testedFns) { + testedFns.remove(fnName) + delegatorFns.remove(fnName) + } + if (fnName in nonDelegatingFns) { + delegatorFns.remove(fnName) + } + } + + // If there are function names remaining in `delegatorFns`, then there are tests missing. + assertThat(delegatorFns) + .describedAs { "Delegation is not tested for functions $delegatorFns." } + .isEmpty() + + // If there are function names remaining in `testedFns`, then there are more tests than there + // should be. Functions might be tested twice, or there may be tests for functions that have + // since been removed from the delegate (though those tests probably failed). + assertThat(testedFns) + .describedAs { "Unexpected or redundant tests for functions $testedFns." } + .isEmpty() +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt index fb52ffc6..6ae11ba1 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt @@ -367,7 +367,7 @@ internal class ChatCompletionCreateParamsTest { val body = params.rawParams._body() assertThat(params).isInstanceOf(StructuredChatCompletionCreateParams::class.java) - assertThat(params.responseFormat).isEqualTo(X::class.java) + assertThat(params.responseType).isEqualTo(X::class.java) assertThat(body.messages()) .containsExactly( ChatCompletionMessageParam.ofDeveloper( diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt index 4abd66b6..4b289198 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt @@ -1,23 +1,31 @@ package com.openai.models.chat.completions -import com.openai.core.fromClass +import com.openai.core.BOOLEAN +import com.openai.core.DOUBLE +import com.openai.core.DelegationWriteTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.LIST +import com.openai.core.LONG +import com.openai.core.MAP +import com.openai.core.NULLABLE +import com.openai.core.NULLABLE_BOOLEAN +import com.openai.core.NULLABLE_DOUBLE +import com.openai.core.NULLABLE_LONG +import com.openai.core.OPTIONAL +import com.openai.core.SET +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorWriteFunctionsAreTested +import com.openai.core.checkOneDelegationWrite +import com.openai.core.findDelegationMethod import com.openai.core.http.Headers import com.openai.core.http.QueryParams +import com.openai.core.responseFormatFromClass import com.openai.models.ChatModel import com.openai.models.FunctionDefinition -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_FIELD -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_VALUE -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.MESSAGE -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.OPTIONAL -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.STRING -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.X -import java.lang.reflect.Method -import kotlin.collections.plus -import kotlin.reflect.full.declaredFunctions -import kotlin.reflect.jvm.javaMethod -import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test -import org.junit.jupiter.api.fail import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.mockito.Mockito.mock @@ -39,140 +47,11 @@ import org.mockito.kotlin.verify */ internal class StructuredChatCompletionCreateParamsTest { companion object { - private fun checkOneDelegationWrite( - delegator: Any, - mockDelegate: Any, - testCase: DelegationWriteTestCase, - ) { - invokeMethod(findDelegationMethod(delegator, testCase), delegator, testCase) - - // Verify that the corresponding method on the mock delegate was called exactly once. - verify(mockDelegate, times(1)).apply { - invokeMethod(findDelegationMethod(mockDelegate, testCase), mockDelegate, testCase) - } - verifyNoMoreInteractions(mockDelegate) - } - - private fun invokeMethod(method: Method, target: Any, testCase: DelegationWriteTestCase) { - val numParams = testCase.inputValues.size - val inputValue1 = testCase.inputValues[0] - val inputValue2 = testCase.inputValues.getOrNull(1) - - when (numParams) { - 1 -> method.invoke(target, inputValue1) - 2 -> method.invoke(target, inputValue1, inputValue2) - else -> fail { "Unexpected number of function parameters ($numParams)." } - } - } - - /** - * Finds the java method matching the test case's function name and parameter types in the - * delegator or delegate `target`. - */ - private fun findDelegationMethod(target: Any, testCase: DelegationWriteTestCase): Method { - val numParams = testCase.inputValues.size - val inputValue1: Any? = testCase.inputValues[0] - val inputValue2 = if (numParams > 1) testCase.inputValues[1] else null - - val method = - when (numParams) { - 1 -> - if (inputValue1 != null) { - findJavaMethod( - target.javaClass, - testCase.functionName, - toJavaType(inputValue1.javaClass), - ) - } else { - // Only the first parameter may be nullable and only if it is the only - // parameter. If the first parameter is nullable, it will be the only - // function of the same name with a nullable first parameter. To handle - // the potentially nullable first parameter, Kotlin reflection is - // needed. This allows a function `f(Boolean)` to be distinguished from - // `f(Boolean?)`. For the tests, if the parameter type is nullable, the - // parameter value will always be `null` (if not, the function with the - // nullable parameter would not be matched). - // - // Using Kotlin reflection, the first parameter (zero index) is `this` - // object, so start matching from the second parameter onwards. - target::class - .declaredFunctions - .find { - it.name == testCase.functionName && - it.parameters[1].type.isMarkedNullable - } - ?.javaMethod - } - 2 -> - if (inputValue1 != null && inputValue2 != null) { - findJavaMethod( - target.javaClass, - testCase.functionName, - toJavaType(inputValue1.javaClass), - toJavaType(inputValue2.javaClass), - ) - } else { - // There are no instances where there are two parameters and one of them - // is nullable. - fail { "Function $testCase second parameter must not be null." } - } - else -> fail { "Function $testCase has unsupported number of parameters." } - } - - // Using `if` and `fail`, so the compiler knows the code will not continue and can infer - // that `delegationMethod` is not null. It cannot do this for `assertThat...isNotNull`. - if (method == null) { - fail { "Function $testCase cannot be found in $target." } - } - - return method - } - - private fun findJavaMethod( - clazz: Class<*>, - methodName: String, - vararg parameterTypes: Class<*>, - ): Method? = - clazz.declaredMethods.firstOrNull { method -> - method.name == methodName && - method.parameterTypes.size == parameterTypes.size && - method.parameterTypes.indices.all { index -> - (parameterTypes[index].isPrimitive && - method.parameterTypes[index] == parameterTypes[index]) || - method.parameterTypes[index].isAssignableFrom(parameterTypes[index]) - } - } - - /** - * Returns the Java type to use when matching type parameters for a Java method. The type is - * the type of the input value that will be used when the method is invoked. For most types, - * the given type is returned. However, if the type represents a Kotlin primitive, it will - * be converted to a Java primitive. This allows matching of methods with parameter types - * that are non-nullable Kotlin primitives. If not translated, methods with parameter types - * that are nullable Kotlin primitives would always be matched instead. - */ - private fun toJavaType(type: Class<*>) = - when (type) { - // This only needs to cover the types used in the test cases. - java.lang.Long::class.java -> java.lang.Long.TYPE - java.lang.Boolean::class.java -> java.lang.Boolean.TYPE - java.lang.Double::class.java -> java.lang.Double.TYPE - else -> type - } - - private val NULLABLE = null - private const val BOOLEAN: Boolean = true - private val NULLABLE_BOOLEAN: Boolean? = null - private const val LONG: Long = 42L - private val NULLABLE_LONG: Long? = null - private const val DOUBLE: Double = 42.0 - private val NULLABLE_DOUBLE: Double? = null - private val LIST = listOf(STRING) - private val SET = setOf(STRING) - private val MAP = mapOf(STRING to STRING) - private val CHAT_MODEL = ChatModel.GPT_4 + private val MESSAGE = + ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() + private val USER_MESSAGE_PARAM = ChatCompletionUserMessageParam.builder().content(STRING).build() private val DEV_MESSAGE_PARAM = @@ -226,21 +105,10 @@ internal class StructuredChatCompletionCreateParamsTest { private val HEADERS = Headers.builder().build() private val QUERY_PARAMS = QueryParams.builder().build() - // Want `vararg`, so cannot use `data class`. Need a custom `toString`, anyway. - class DelegationWriteTestCase(val functionName: String, vararg val inputValues: Any?) { - /** - * Gets the string representation that identifies the test function when running JUnit. - */ - override fun toString(): String = - "$functionName(${inputValues.joinToString(", ") { - it?.javaClass?.simpleName ?: "null" - }})" - } - // The list order follows the declaration order in `ChatCompletionCreateParams.Builder` for // easier maintenance. @JvmStatic - fun builderDelegationTestCases() = + private fun builderDelegationTestCases() = listOf( DelegationWriteTestCase("body", PARAMS_BODY), DelegationWriteTestCase("messages", LIST), @@ -394,82 +262,33 @@ internal class StructuredChatCompletionCreateParamsTest { // New instances of the `mockBuilderDelegate` and `builderDelegator` are required for each test // case (each test case runs in its own instance of the test class). - val mockBuilderDelegate: ChatCompletionCreateParams.Builder = + private val mockBuilderDelegate: ChatCompletionCreateParams.Builder = mock(ChatCompletionCreateParams.Builder::class.java) - val builderDelegator = + private val builderDelegator = StructuredChatCompletionCreateParams.builder().inject(mockBuilderDelegate) @Test fun allBuilderDelegateFunctionsExistInDelegator() { // The delegator class does not implement the various `responseFormat` functions of the // delegate class. - StructuredChatCompletionTest.checkAllDelegation( - ChatCompletionCreateParams.Builder::class, - StructuredChatCompletionCreateParams.Builder::class, - "responseFormat", - ) + checkAllDelegation(mockBuilderDelegate::class, builderDelegator::class, "responseFormat") } @Test fun allBuilderDelegatorFunctionsExistInDelegate() { // The delegator implements a different `responseFormat` function from those overloads in // the delegate class. - StructuredChatCompletionTest.checkAllDelegation( - StructuredChatCompletionCreateParams.Builder::class, - ChatCompletionCreateParams.Builder::class, - "responseFormat", - ) + checkAllDelegation(builderDelegator::class, mockBuilderDelegate::class, "responseFormat") } @Test fun allBuilderDelegatorFunctionsAreTested() { - // There are exceptional test cases for some functions. Most other functions are part of the - // list of those using the parameterized test. There are many overloaded functions, so the - // approach here is to build a list (_not_ a set) of all function names and then "subtract" - // those for which tests are defined and see what remains. For example, there are (at this - // time) eight `addMessage` functions, so there must be eight tests defined for functions - // named `addMessage` that will be subtracted from the list of functions matching that name. - // Parameter types are not checked, as that is awkward and probably overkill. Therefore, - // this scheme is not reliable if a function is tested more than once. - val exceptionalTestedFns = listOf("responseFormat") - val testedFns = - (builderDelegationTestCases().map { it.functionName } + exceptionalTestedFns) - .toMutableList() - val nonDelegatingFns = listOf("build", "wrap", "inject") - - val delegatorFns = - StructuredChatCompletionCreateParams.Builder::class.declaredFunctions.toMutableList() - - // Making concurrent modifications to the list, so using an `Iterator`. - val i = delegatorFns.iterator() - - while (i.hasNext()) { - val functionName = i.next().name - - if (functionName in testedFns) { - testedFns.remove(functionName) - i.remove() - } - if (functionName in nonDelegatingFns) { - i.remove() - } - } - - // If there are function names remaining in `delegatorFns`, then there are tests missing. - // Only report the names of the functions not tested: parameters are not matched, so any - // signatures could be misleading. - assertThat(delegatorFns) - .describedAs { - "Delegation is not tested for functions ${delegatorFns.map { it.name }}." - } - .isEmpty() - - // If there are function names remaining in `testedFns`, then there are more tests than - // there should be. Functions might be tested twice, or there may be tests for functions - // that have since been removed from the delegate (though those tests probably failed). - assertThat(testedFns) - .describedAs { "Unexpected or redundant tests for functions $testedFns." } - .isEmpty() + checkAllDelegatorWriteFunctionsAreTested( + builderDelegator::class, + builderDelegationTestCases(), + exceptionalTestedFns = setOf("responseFormat"), + nonDelegatingFns = setOf("build", "wrap", "inject"), + ) } @ParameterizedTest @@ -485,7 +304,7 @@ internal class StructuredChatCompletionCreateParamsTest { val delegatorTestCase = DelegationWriteTestCase("responseFormat", X::class.java) val delegatorMethod = findDelegationMethod(builderDelegator, delegatorTestCase) val mockDelegateTestCase = - DelegationWriteTestCase("responseFormat", fromClass(X::class.java)) + DelegationWriteTestCase("responseFormat", responseFormatFromClass(X::class.java)) val mockDelegateMethod = findDelegationMethod(mockBuilderDelegate, mockDelegateTestCase) delegatorMethod.invoke(builderDelegator, delegatorTestCase.inputValues[0]) diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt index 347788a3..939f9a53 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt @@ -1,15 +1,16 @@ package com.openai.models.chat.completions +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE import com.openai.core.JsonField -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.DelegationReadTestCase -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_FIELD -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_VALUE -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.MESSAGE -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.OPTIONAL -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.X -import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.checkOneDelegationRead +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead import java.util.Optional -import kotlin.reflect.full.declaredFunctions import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest @@ -33,10 +34,13 @@ import org.mockito.kotlin.verify */ internal class StructuredChatCompletionMessageTest { companion object { - // The list order follows the declaration order in `StructuredChatCompletionMessage` for - // easier maintenance. See `StructuredChatCompletionTest` for details on the values used. + private val MESSAGE = + ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() + + // The list order follows the declaration order in `ChatCompletionMessage` for easier + // maintenance. @JvmStatic - fun delegationTestCases() = + private fun delegationTestCases() = listOf( // `content()` is a special case and has its own test function. DelegationReadTestCase("refusal", OPTIONAL), @@ -62,46 +66,30 @@ internal class StructuredChatCompletionMessageTest { // New instances of the `mockDelegate` and `delegator` are required for each test case (each // test case runs in its own instance of the test class). - val mockDelegate: ChatCompletionMessage = mock(ChatCompletionMessage::class.java) - val delegator = StructuredChatCompletionMessage(X::class.java, mockDelegate) + private val mockDelegate: ChatCompletionMessage = mock(ChatCompletionMessage::class.java) + private val delegator = StructuredChatCompletionMessage(X::class.java, mockDelegate) @Test fun allDelegateFunctionsExistInDelegator() { - StructuredChatCompletionTest.checkAllDelegation( - ChatCompletionMessage::class, - StructuredChatCompletionMessage::class, - "toBuilder", - "toParam", - ) + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder", "toParam") } @Test fun allDelegatorFunctionsExistInDelegate() { - StructuredChatCompletionTest.checkAllDelegation( - StructuredChatCompletionMessage::class, - ChatCompletionMessage::class, - ) + checkAllDelegation(delegator::class, mockDelegate::class) } @Test fun allDelegatorFunctionsAreTested() { // There are exceptional test cases for some functions. Most other functions are part of the - // list of those using the parameterized test. - val exceptionalTestedFns = setOf("content", "_content") - val testedFns = delegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns - // A few delegator functions do not delegate, so no test function is necessary. - val nonDelegatingFns = listOf("equals", "hashCode", "toString") - - val delegatorFunctions = StructuredChatCompletionMessage::class.declaredFunctions - - for (delegatorFunction in delegatorFunctions) { - assertThat( - delegatorFunction.name in testedFns || - delegatorFunction.name in nonDelegatingFns - ) - .describedAs("Delegation is not tested for function '${delegatorFunction.name}.") - .isTrue - } + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("content", "_content"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) } @ParameterizedTest diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt index af380bbf..40dc509f 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt @@ -1,12 +1,17 @@ package com.openai.models.chat.completions +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE import com.openai.core.JsonField -import com.openai.core.JsonValue +import com.openai.core.LONG +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead import com.openai.errors.OpenAIInvalidDataException -import java.util.Optional -import kotlin.reflect.KClass -import kotlin.reflect.KVisibility -import kotlin.reflect.full.declaredFunctions import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest @@ -30,75 +35,7 @@ import org.mockito.kotlin.verify */ internal class StructuredChatCompletionTest { companion object { - internal fun checkAllDelegation( - delegateClass: KClass<*>, - delegatorClass: KClass<*>, - vararg exceptFunctionNames: String, - ) { - assertThat(delegateClass != delegatorClass) - .describedAs { "Delegate and delegator classes should not be the same." } - .isTrue - - val delegateFunctions = delegateClass.declaredFunctions - - for (delegateFunction in delegateFunctions) { - if (delegateFunction.visibility != KVisibility.PUBLIC) { - // Non-public methods are just implementation details of each class. - continue - } - - if (delegateFunction.name in exceptFunctionNames) { - // Ignore functions that are known exceptions (e.g., `toBuilder`). - continue - } - - // Drop the first parameter from each function, as it is the implicit "this" object - // and has the type of the class declaring the function, which will never match. - val delegatorFunction = - delegatorClass.declaredFunctions.find { - it.name == delegateFunction.name && - it.parameters.drop(1).map { it.type } == - delegateFunction.parameters.drop(1).map { it.type } - } - - assertThat(delegatorFunction != null) - .describedAs { - "Function $delegateFunction is not found in ${delegatorClass.simpleName}." - } - .isTrue - } - } - - internal fun checkOneDelegationRead( - delegator: Any, - mockDelegate: Any, - testCase: DelegationReadTestCase, - ) { - // Stub the method in the mock delegate using reflection - val delegateMethod = mockDelegate::class.java.getMethod(testCase.functionName) - `when`(delegateMethod.invoke(mockDelegate)).thenReturn(testCase.expectedValue) - - // Call the corresponding method on the delegator using reflection - val delegatorMethod = delegator::class.java.getMethod(testCase.functionName) - val result = delegatorMethod.invoke(delegator) - - // Verify that the corresponding method on the mock delegate was called exactly once - verify(mockDelegate, times(1)).apply { delegateMethod.invoke(mockDelegate) } - verifyNoMoreInteractions(mockDelegate) - - // Assert that the result matches the expected value - assertThat(result).isEqualTo(testCase.expectedValue) - } - - // Where a function returns `Optional`, `JsonField` or `JsonValue` There is no need to - // provide a value that matches the type ``, a simple `String` value of `"a-string"` will - // work OK with the test. Constants have been provided for this purpose. - internal const val STRING = "a-string" - - internal val OPTIONAL = Optional.of(STRING) - internal val JSON_FIELD = JsonField.of(STRING) - internal val JSON_VALUE = JsonValue.from(STRING) - internal val MESSAGE = + private val MESSAGE = ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() private val FINISH_REASON = ChatCompletion.Choice.FinishReason.STOP private val CHOICE = @@ -111,16 +48,13 @@ internal class StructuredChatCompletionTest { ) .build() - data class DelegationReadTestCase(val functionName: String, val expectedValue: Any) - - // The list order follows the declaration order in `StructuredChatCompletionMessage` for - // easier maintenance. + // The list order follows the declaration order in `ChatCompletion` for easier maintenance. @JvmStatic - fun delegationTestCases() = + private fun delegationTestCases() = listOf( DelegationReadTestCase("id", STRING), // `choices()` is a special case and has its own test function. - DelegationReadTestCase("created", 123L), + DelegationReadTestCase("created", LONG), DelegationReadTestCase("model", STRING), DelegationReadTestCase("_object_", JSON_VALUE), DelegationReadTestCase("serviceTier", OPTIONAL), @@ -139,10 +73,10 @@ internal class StructuredChatCompletionTest { ) @JvmStatic - fun choiceDelegationTestCases() = + private fun choiceDelegationTestCases() = listOf( DelegationReadTestCase("finishReason", FINISH_REASON), - DelegationReadTestCase("index", 123L), + DelegationReadTestCase("index", LONG), DelegationReadTestCase("logprobs", OPTIONAL), DelegationReadTestCase("_finishReason", JSON_FIELD), // `message()` is a special case and has its own test function. @@ -153,90 +87,58 @@ internal class StructuredChatCompletionTest { // `validate()` and `isValid()` (which calls `validate()`) are tested separately, // as they require special handling. ) - - /** A basic class used as the generic type when testing. */ - internal class X(val s: String) { - override fun equals(other: Any?) = other is X && other.s == s - - override fun hashCode() = s.hashCode() - } } // New instances of the `mockDelegate` and `delegator` are required for each test case (each // test case runs in its own instance of the test class). - val mockDelegate: ChatCompletion = mock(ChatCompletion::class.java) - val delegator = StructuredChatCompletion(X::class.java, mockDelegate) + private val mockDelegate: ChatCompletion = mock(ChatCompletion::class.java) + private val delegator = StructuredChatCompletion(X::class.java, mockDelegate) - val mockChoiceDelegate: ChatCompletion.Choice = mock(ChatCompletion.Choice::class.java) - val choiceDelegator = StructuredChatCompletion.Choice(X::class.java, mockChoiceDelegate) + private val mockChoiceDelegate: ChatCompletion.Choice = mock(ChatCompletion.Choice::class.java) + private val choiceDelegator = + StructuredChatCompletion.Choice(X::class.java, mockChoiceDelegate) @Test fun allChatCompletionDelegateFunctionsExistInDelegator() { - checkAllDelegation(ChatCompletion::class, StructuredChatCompletion::class, "toBuilder") + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder") } @Test fun allChatCompletionDelegatorFunctionsExistInDelegate() { - checkAllDelegation(StructuredChatCompletion::class, ChatCompletion::class) + checkAllDelegation(delegator::class, mockDelegate::class) } @Test fun allChoiceDelegateFunctionsExistInDelegator() { - checkAllDelegation( - ChatCompletion.Choice::class, - StructuredChatCompletion.Choice::class, - "toBuilder", - ) + checkAllDelegation(mockChoiceDelegate::class, choiceDelegator::class, "toBuilder") } @Test fun allChoiceDelegatorFunctionsExistInDelegate() { - checkAllDelegation(StructuredChatCompletion.Choice::class, ChatCompletion.Choice::class) + checkAllDelegation(choiceDelegator::class, mockChoiceDelegate::class) } @Test fun allDelegatorFunctionsAreTested() { // There are exceptional test cases for some functions. Most other functions are part of the - // list of those using the parameterized test. - val exceptionalTestedFns = setOf("choices", "_choices", "validate", "isValid") - val testedFns = delegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns - // A few delegator functions do not delegate, so no test function is necessary. - val nonDelegatingFns = listOf("equals", "hashCode", "toString") - - val delegatorFunctions = StructuredChatCompletion::class.declaredFunctions - - for (delegatorFunction in delegatorFunctions) { - assertThat( - delegatorFunction.name in testedFns || - delegatorFunction.name in nonDelegatingFns - ) - .describedAs("Delegation is not tested for function '${delegatorFunction.name}.") - .isTrue - } + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("choices", "_choices", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) } @Test fun allChoiceDelegatorFunctionsAreTested() { - // There are exceptional test cases for some functions. Most other functions are part of the - // list of those using the parameterized test. - val exceptionalTestedFns = setOf("message", "_message", "validate", "isValid") - val testedFns = - choiceDelegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns - // A few delegator functions do not delegate, so no test function is necessary. - val nonDelegatingFns = listOf("equals", "hashCode", "toString") - - val delegatorFunctions = StructuredChatCompletion.Choice::class.declaredFunctions - - for (delegatorFunction in delegatorFunctions) { - assertThat( - delegatorFunction.name in testedFns || - delegatorFunction.name in nonDelegatingFns - ) - .describedAs( - "Delegation is not tested for function 'Choice.${delegatorFunction.name}." - ) - .isTrue - } + checkAllDelegatorReadFunctionsAreTested( + choiceDelegator::class, + choiceDelegationTestCases(), + exceptionalTestedFns = setOf("message", "_message", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) } @ParameterizedTest @@ -263,7 +165,7 @@ internal class StructuredChatCompletionTest { verify(mockDelegate, times(1))._choices() verifyNoMoreInteractions(mockDelegate) - assertThat(output[0].choice).isEqualTo(CHOICE) + assertThat(output[0].rawChoice).isEqualTo(CHOICE) } @Test @@ -277,7 +179,7 @@ internal class StructuredChatCompletionTest { verify(mockDelegate, times(1))._choices() verifyNoMoreInteractions(mockDelegate) - assertThat(output.getRequired("_choices")[0].choice).isEqualTo(CHOICE) + assertThat(output.getRequired("_choices")[0].rawChoice).isEqualTo(CHOICE) } @Test @@ -339,7 +241,7 @@ internal class StructuredChatCompletionTest { verify(mockChoiceDelegate, times(1))._message() verifyNoMoreInteractions(mockChoiceDelegate) - assertThat(output.chatCompletionMessage).isEqualTo(MESSAGE) + assertThat(output.rawMessage).isEqualTo(MESSAGE) } @Test @@ -353,7 +255,7 @@ internal class StructuredChatCompletionTest { verify(mockChoiceDelegate, times(1))._message() verifyNoMoreInteractions(mockChoiceDelegate) - assertThat(output.getRequired("_message").chatCompletionMessage).isEqualTo(MESSAGE) + assertThat(output.getRequired("_message").rawMessage).isEqualTo(MESSAGE) } @Test diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt new file mode 100644 index 00000000..1eb68ee9 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt @@ -0,0 +1,245 @@ +package com.openai.models.responses + +import com.openai.core.BOOLEAN +import com.openai.core.DOUBLE +import com.openai.core.DelegationWriteTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.LIST +import com.openai.core.LONG +import com.openai.core.MAP +import com.openai.core.NULLABLE +import com.openai.core.NULLABLE_BOOLEAN +import com.openai.core.NULLABLE_DOUBLE +import com.openai.core.NULLABLE_LONG +import com.openai.core.NULLABLE_STRING +import com.openai.core.OPTIONAL +import com.openai.core.SET +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorWriteFunctionsAreTested +import com.openai.core.checkOneDelegationWrite +import com.openai.core.findDelegationMethod +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.core.textConfigFromClass +import com.openai.models.ChatModel +import com.openai.models.Reasoning +import com.openai.models.ResponsesModel +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponseCreateParams] class (delegator) and its delegation of most + * functions to a wrapped [ResponseCreateParams] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseCreateParamsTest { + companion object { + private val CHAT_MODEL = ChatModel.GPT_4O + private val RESPONSES_MODEL = ResponsesModel.ofChat(CHAT_MODEL) + private val RESPONSES_ONLY_MODEL = ResponsesModel.ResponsesOnlyModel.O1_PRO + private val PARAMS_INPUT = ResponseCreateParams.Input.ofText(STRING) + private val PARAMS_BODY = + ResponseCreateParams.Body.builder().input(PARAMS_INPUT).model(RESPONSES_MODEL).build() + + private val INCLUDABLE = ResponseIncludable.of(STRING) + private val METADATA = ResponseCreateParams.Metadata.builder().build() + private val SERVICE_TIER = ResponseCreateParams.ServiceTier.AUTO + private val REASONING = Reasoning.builder().build() + + private val TOOL_CHOICE_TYPE = ToolChoiceTypes.Type.FILE_SEARCH + private val TOOL_CHOICE_TYPES = ToolChoiceTypes.builder().type(TOOL_CHOICE_TYPE).build() + private val TOOL_CHOICE = ResponseCreateParams.ToolChoice.ofTypes(TOOL_CHOICE_TYPES) + private val TOOL_CHOICE_OPTIONS = ToolChoiceOptions.AUTO + private val TOOL_CHOICE_FUNCTION = ToolChoiceFunction.builder().name(STRING).build() + + private val FUNCTION_TOOL = + FunctionTool.builder().name(STRING).parameters(NULLABLE).strict(BOOLEAN).build() + private val FILE_SEARCH_TOOL = FileSearchTool.builder().vectorStoreIds(LIST).build() + private val WEB_SEARCH_TOOL = + WebSearchTool.builder().type(WebSearchTool.Type.WEB_SEARCH_PREVIEW).build() + private val COMPUTER_TOOL = + ComputerTool.builder() + .displayWidth(LONG) + .displayHeight(LONG) + .environment(ComputerTool.Environment.LINUX) + .build() + private val TOOL = Tool.ofFunction(FUNCTION_TOOL) + + private val HEADERS = Headers.builder().build() + private val QUERY_PARAMS = QueryParams.builder().build() + + // The list order follows the declaration order in `ResponseCreateParams.Builder` for + // easier maintenance. + @JvmStatic + private fun builderDelegationTestCases() = + listOf( + DelegationWriteTestCase("body", PARAMS_BODY), + DelegationWriteTestCase("input", PARAMS_INPUT), + DelegationWriteTestCase("input", JSON_FIELD), + DelegationWriteTestCase("input", STRING), + DelegationWriteTestCase("inputOfResponse", LIST), + DelegationWriteTestCase("model", RESPONSES_MODEL), + DelegationWriteTestCase("model", JSON_FIELD), + DelegationWriteTestCase("model", STRING), + DelegationWriteTestCase("model", CHAT_MODEL), + DelegationWriteTestCase("model", RESPONSES_ONLY_MODEL), + DelegationWriteTestCase("include", LIST), + DelegationWriteTestCase("include", OPTIONAL), + DelegationWriteTestCase("include", JSON_FIELD), + DelegationWriteTestCase("addInclude", INCLUDABLE), + DelegationWriteTestCase("instructions", NULLABLE_STRING), + DelegationWriteTestCase("instructions", OPTIONAL), + DelegationWriteTestCase("instructions", JSON_FIELD), + DelegationWriteTestCase("maxOutputTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxOutputTokens", LONG), + DelegationWriteTestCase("maxOutputTokens", OPTIONAL), + DelegationWriteTestCase("maxOutputTokens", JSON_FIELD), + DelegationWriteTestCase("metadata", METADATA), + DelegationWriteTestCase("metadata", OPTIONAL), + DelegationWriteTestCase("metadata", JSON_FIELD), + DelegationWriteTestCase("parallelToolCalls", NULLABLE_BOOLEAN), + DelegationWriteTestCase("parallelToolCalls", BOOLEAN), + DelegationWriteTestCase("parallelToolCalls", OPTIONAL), + DelegationWriteTestCase("parallelToolCalls", JSON_FIELD), + DelegationWriteTestCase("previousResponseId", NULLABLE_STRING), + DelegationWriteTestCase("previousResponseId", OPTIONAL), + DelegationWriteTestCase("previousResponseId", JSON_FIELD), + DelegationWriteTestCase("reasoning", REASONING), + DelegationWriteTestCase("reasoning", OPTIONAL), + DelegationWriteTestCase("reasoning", JSON_FIELD), + DelegationWriteTestCase("serviceTier", SERVICE_TIER), + DelegationWriteTestCase("serviceTier", OPTIONAL), + DelegationWriteTestCase("serviceTier", JSON_FIELD), + DelegationWriteTestCase("store", NULLABLE_BOOLEAN), + DelegationWriteTestCase("store", BOOLEAN), + DelegationWriteTestCase("store", OPTIONAL), + DelegationWriteTestCase("store", JSON_FIELD), + DelegationWriteTestCase("temperature", NULLABLE_DOUBLE), + DelegationWriteTestCase("temperature", DOUBLE), + DelegationWriteTestCase("temperature", OPTIONAL), + DelegationWriteTestCase("temperature", JSON_FIELD), + // `text()` is a special case and has its own unit tests. + DelegationWriteTestCase("toolChoice", TOOL_CHOICE), + DelegationWriteTestCase("toolChoice", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTIONS), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_TYPES), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_FUNCTION), + DelegationWriteTestCase("tools", LIST), + DelegationWriteTestCase("tools", JSON_FIELD), + DelegationWriteTestCase("addTool", TOOL), + DelegationWriteTestCase("addTool", FILE_SEARCH_TOOL), + DelegationWriteTestCase("addFileSearchTool", LIST), + DelegationWriteTestCase("addTool", FUNCTION_TOOL), + DelegationWriteTestCase("addTool", WEB_SEARCH_TOOL), + DelegationWriteTestCase("addTool", COMPUTER_TOOL), + DelegationWriteTestCase("topP", NULLABLE_DOUBLE), + DelegationWriteTestCase("topP", DOUBLE), + DelegationWriteTestCase("topP", OPTIONAL), + DelegationWriteTestCase("topP", JSON_FIELD), + DelegationWriteTestCase("truncation", NULLABLE), + DelegationWriteTestCase("truncation", OPTIONAL), + DelegationWriteTestCase("truncation", JSON_FIELD), + DelegationWriteTestCase("user", STRING), + DelegationWriteTestCase("user", JSON_FIELD), + DelegationWriteTestCase("additionalBodyProperties", MAP), + DelegationWriteTestCase("putAdditionalBodyProperty", STRING, JSON_VALUE), + DelegationWriteTestCase("putAllAdditionalBodyProperties", MAP), + DelegationWriteTestCase("removeAdditionalBodyProperty", STRING), + DelegationWriteTestCase("removeAllAdditionalBodyProperties", SET), + DelegationWriteTestCase("additionalHeaders", HEADERS), + DelegationWriteTestCase("additionalHeaders", MAP), + DelegationWriteTestCase("putAdditionalHeader", STRING, STRING), + DelegationWriteTestCase("putAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("putAllAdditionalHeaders", MAP), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("replaceAllAdditionalHeaders", MAP), + DelegationWriteTestCase("removeAdditionalHeaders", STRING), + DelegationWriteTestCase("removeAllAdditionalHeaders", SET), + DelegationWriteTestCase("additionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("additionalQueryParams", MAP), + DelegationWriteTestCase("putAdditionalQueryParam", STRING, STRING), + DelegationWriteTestCase("putAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("putAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("removeAdditionalQueryParams", STRING), + DelegationWriteTestCase("removeAllAdditionalQueryParams", SET), + ) + } + + // New instances of the `mockBuilderDelegate` and `builderDelegator` are required for each test + // case (each test case runs in its own instance of the test class). + private val mockBuilderDelegate: ResponseCreateParams.Builder = + mock(ResponseCreateParams.Builder::class.java) + private val builderDelegator = + StructuredResponseCreateParams.builder().inject(mockBuilderDelegate) + + @Test + fun allBuilderDelegateFunctionsExistInDelegator() { + // The delegator class does not implement the various `text` functions of the delegate + // class. + checkAllDelegation(mockBuilderDelegate::class, builderDelegator::class, "text") + } + + @Test + fun allBuilderDelegatorFunctionsExistInDelegate() { + // The delegator implements a different `text` function from those overloads in the delegate + // class. + checkAllDelegation(builderDelegator::class, mockBuilderDelegate::class, "text") + } + + @Test + fun allBuilderDelegatorFunctionsAreTested() { + checkAllDelegatorWriteFunctionsAreTested( + builderDelegator::class, + builderDelegationTestCases(), + exceptionalTestedFns = setOf("text"), + nonDelegatingFns = setOf("build", "wrap", "inject"), + ) + } + + @ParameterizedTest + @MethodSource("builderDelegationTestCases") + fun `delegation of Builder write functions`(testCase: DelegationWriteTestCase) { + checkOneDelegationWrite(builderDelegator, mockBuilderDelegate, testCase) + } + + @Test + fun `delegation of text`() { + // Special unit test case as the delegator method signature does not match that of the + // delegate method. + val delegatorTestCase = DelegationWriteTestCase("text", X::class.java) + val delegatorMethod = findDelegationMethod(builderDelegator, delegatorTestCase) + val mockDelegateTestCase = + DelegationWriteTestCase("text", textConfigFromClass(X::class.java)) + val mockDelegateMethod = findDelegationMethod(mockBuilderDelegate, mockDelegateTestCase) + + delegatorMethod.invoke(builderDelegator, delegatorTestCase.inputValues[0]) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockBuilderDelegate, times(1)).apply { + mockDelegateMethod.invoke(mockBuilderDelegate, mockDelegateTestCase.inputValues[0]) + } + verifyNoMoreInteractions(mockBuilderDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputItemTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputItemTest.kt new file mode 100644 index 00000000..2ba6fa03 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputItemTest.kt @@ -0,0 +1,204 @@ +package com.openai.models.responses + +import com.openai.core.DelegationReadTestCase +import com.openai.core.LIST +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponseOutputItem] class (delegator) and its delegation of most + * functions to a wrapped [ResponseOutputItem] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseOutputItemTest { + companion object { + private val FILE_SEARCH_TOOL_CALL = + ResponseFileSearchToolCall.builder() + .id(STRING) + .queries(LIST) + .status(ResponseFileSearchToolCall.Status.COMPLETED) + .build() + private val FUNCTION_TOOL_CALL = + ResponseFunctionToolCall.builder().arguments(STRING).callId(STRING).name(STRING).build() + private val FUNCTION_WEB_SEARCH = + ResponseFunctionWebSearch.builder() + .id(STRING) + .status(ResponseFunctionWebSearch.Status.COMPLETED) + .build() + private val COMPUTER_TOOL_CALL = + ResponseComputerToolCall.builder() + .id(STRING) + .action(ResponseComputerToolCall.Action.ofWait()) + .callId(STRING) + .pendingSafetyChecks(listOf()) + .status(ResponseComputerToolCall.Status.COMPLETED) + .type(ResponseComputerToolCall.Type.COMPUTER_CALL) + .build() + private val REASONING_ITEM = + ResponseReasoningItem.builder().id(STRING).summary(listOf()).build() + private val MESSAGE = + ResponseOutputMessage.builder() + .id(STRING) + .content(listOf()) + .status(ResponseOutputMessage.Status.COMPLETED) + .build() + + // The list order follows the declaration order in `ResponseOutputItem` for easier + // maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + // `message()` is a special case and has its own test function. + DelegationReadTestCase("fileSearchCall", OPTIONAL), + DelegationReadTestCase("functionCall", OPTIONAL), + DelegationReadTestCase("webSearchCall", OPTIONAL), + DelegationReadTestCase("computerCall", OPTIONAL), + DelegationReadTestCase("reasoning", OPTIONAL), + // `isMessage()` is a special case and has its own test function. + // For the Boolean functions, call each in turn with both `true` and `false` to + // ensure that a return value is not hard-coded. + DelegationReadTestCase("isFileSearchCall", true), + DelegationReadTestCase("isFileSearchCall", false), + DelegationReadTestCase("isFunctionCall", true), + DelegationReadTestCase("isFunctionCall", false), + DelegationReadTestCase("isWebSearchCall", true), + DelegationReadTestCase("isWebSearchCall", false), + DelegationReadTestCase("isComputerCall", true), + DelegationReadTestCase("isComputerCall", false), + DelegationReadTestCase("isReasoning", true), + DelegationReadTestCase("isReasoning", false), + // `asMessage()` is a special case and has its own test function. + DelegationReadTestCase("asFileSearchCall", FILE_SEARCH_TOOL_CALL), + DelegationReadTestCase("asFunctionCall", FUNCTION_TOOL_CALL), + DelegationReadTestCase("asWebSearchCall", FUNCTION_WEB_SEARCH), + DelegationReadTestCase("asComputerCall", COMPUTER_TOOL_CALL), + DelegationReadTestCase("asReasoning", REASONING_ITEM), + DelegationReadTestCase("_json", OPTIONAL), + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: ResponseOutputItem = mock(ResponseOutputItem::class.java) + private val delegator = StructuredResponseOutputItem(X::class.java, mockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + // `toBuilder()` is deliberately not implemented. `accept()` has a different signature. + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder", "accept") + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + // `accept()` has a different signature. + checkAllDelegation(delegator::class, mockDelegate::class, "accept") + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = + setOf("message", "asMessage", "isMessage", "validate", "isValid", "accept"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @Test + fun `delegation of message`() { + // Input and output are different types, so this test is an exceptional case. + // The delegator's `message()` delegates to the delegate's `message()` indirectly via the + // delegator's `message` field initializer. + val input = Optional.of(MESSAGE) + `when`(mockDelegate.message()).thenReturn(input) + val output = delegator.message() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.get().rawMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of asMessage`() { + // Delegation function names do not match, so this test is an exceptional case. + // The delegator's `asMessage()` delegates to the delegate's `message()` (without the "as") + // indirectly via the delegator's `message` field initializer. + val input = Optional.of(MESSAGE) + `when`(mockDelegate.message()).thenReturn(input) + val output = delegator.asMessage() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.rawMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of isMessage`() { + // Delegation function names do not match, so this test is an exceptional case. + // The delegator's `isMessage()` delegates to the delegate's `message()` (without the "is") + // indirectly via the delegator's `message` field initializer. + val input = Optional.of(MESSAGE) + `when`(mockDelegate.message()).thenReturn(input) + val output = delegator.isMessage() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of validate`() { + `when`(mockDelegate.message()).thenReturn(Optional.of(MESSAGE)) + + delegator.validate() + + // Delegator's `validate()` does not call delegate's `validate()`. `message()` is called + // indirectly via the `message` field initializer. + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of isValid`() { + // `isValid` calls `validate()`, so the test is similar to that for `validate()`. + `when`(mockDelegate.message()).thenReturn(Optional.of(MESSAGE)) + + delegator.isValid() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt new file mode 100644 index 00000000..ee3c3a57 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt @@ -0,0 +1,298 @@ +package com.openai.models.responses + +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.MAP +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import com.openai.errors.OpenAIInvalidDataException +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponseOutputMessage] class (delegator) and its delegation of most + * functions to a wrapped [ResponseOutputMessage] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseOutputMessageTest { + companion object { + private val MESSAGE_STATUS = ResponseOutputMessage.Status.COMPLETED + private val OUTPUT_TEXT = + ResponseOutputText.builder().annotations(listOf()).text(STRING).build() + private val OUTPUT_REFUSAL = ResponseOutputRefusal.builder().refusal(STRING).build() + private val CONTENT = ResponseOutputMessage.Content.ofOutputText(OUTPUT_TEXT) + + // The list order follows the declaration order in `ResponseOutputMessage` for easier + // maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + DelegationReadTestCase("id", STRING), + // `content()` is a special case and has its own test function. + DelegationReadTestCase("_role", JSON_VALUE), + DelegationReadTestCase("status", MESSAGE_STATUS), + DelegationReadTestCase("_type", JSON_VALUE), + DelegationReadTestCase("_id", JSON_FIELD), + // `_content()` is a special case and has its own test function. + DelegationReadTestCase("_status", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", MAP), + ) + + // The list order follows the declaration order in `ResponseOutputMessage.Content` for + // easier maintenance. + @JvmStatic + private fun contentDelegationTestCases() = + listOf( + // `outputText()` is a special case and has its own test function. + DelegationReadTestCase("refusal", OPTIONAL), + // For the Boolean functions, pass both `true` and `false` to ensure that one value + // is not hard-coded. + DelegationReadTestCase("isOutputText", true), + DelegationReadTestCase("isOutputText", false), + DelegationReadTestCase("isRefusal", true), + DelegationReadTestCase("isRefusal", false), + // `asOutputText()` is a special case and has its own test function. + DelegationReadTestCase("asRefusal", OUTPUT_REFUSAL), + DelegationReadTestCase("_json", OPTIONAL), + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: ResponseOutputMessage = mock(ResponseOutputMessage::class.java) + private val delegator = StructuredResponseOutputMessage(X::class.java, mockDelegate) + + private val contentMockDelegate: ResponseOutputMessage.Content = + mock(ResponseOutputMessage.Content::class.java) + private val contentDelegator = + StructuredResponseOutputMessage.Content(X::class.java, contentMockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder") + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + checkAllDelegation(delegator::class, mockDelegate::class) + } + + @Test + fun allContentDelegateFunctionsExistInDelegator() { + // The `Content.accept()` function in the delegator takes a different type than that in the + // delegate, so there is no delegation from the former to the latter. `Content.toBuilder` is + // deliberately not implemented. + checkAllDelegation( + contentMockDelegate::class, + contentDelegator::class, + "toBuilder", + "accept", + ) + } + + @Test + fun allContentDelegatorFunctionsExistInDelegate() { + // The `Content.accept()` function in the delegator takes a different type than that in the + // delegate, so there is no delegation from the former to the latter. + checkAllDelegation(contentDelegator::class, contentMockDelegate::class, "accept") + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("content", "_content", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @Test + fun allContentDelegatorFunctionsAreTested() { + checkAllDelegatorReadFunctionsAreTested( + contentDelegator::class, + contentDelegationTestCases(), + exceptionalTestedFns = + setOf("outputText", "asOutputText", "validate", "isValid", "accept"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @ParameterizedTest + @MethodSource("contentDelegationTestCases") + fun `delegation of Content functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(contentDelegator, contentMockDelegate, testCase) + } + + @Test + fun `delegation of content`() { + // Input and output are different types, so this test is an exceptional case. + // `content()` (without an underscore) delegates to `_content()` (with an underscore) + // indirectly via the `content` field initializer. + val input = JsonField.of(listOf(CONTENT)) + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator.content() // Without an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output[0].rawContent).isEqualTo(CONTENT) + } + + @Test + fun `delegation of _content`() { + // Input and output are different types, so this test is an exceptional case. + // `_content()` (with an underscore) delegates to `_content()` (with an underscore) + // indirectly via the `content` field initializer. + val input = JsonField.of(listOf(CONTENT)) + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator._content() // With an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.getRequired("content")[0].rawContent).isEqualTo(CONTENT) + } + + @Test + fun `delegation of validate`() { + `when`(mockDelegate._content()).thenReturn(JsonField.of(listOf(CONTENT))) + `when`(mockDelegate._role()).thenReturn(JsonValue.from("assistant")) + `when`(mockDelegate.status()).thenReturn(ResponseOutputMessage.Status.COMPLETED) + `when`(mockDelegate._type()).thenReturn(JsonValue.from("message")) + + delegator.validate() + + // Delegator's `validate()` does not call delegate's `validate()`. `_content` is called + // indirectly via the `content` field initializer. + verify(mockDelegate, times(1))._content() + verify(mockDelegate, times(1)).id() + verify(mockDelegate, times(1))._role() + verify(mockDelegate, times(1)).status() + verify(mockDelegate, times(1))._type() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of isValid`() { + // `isValid` calls `validate()`, so the test is similar to that for `validate()`. + `when`(mockDelegate._content()).thenReturn(JsonField.of(listOf(CONTENT))) + `when`(mockDelegate._role()).thenReturn(JsonValue.from("assistant")) + `when`(mockDelegate.status()).thenReturn(ResponseOutputMessage.Status.COMPLETED) + `when`(mockDelegate._type()).thenReturn(JsonValue.from("message")) + + delegator.isValid() + + verify(mockDelegate, times(1))._content() + verify(mockDelegate, times(1)).id() + verify(mockDelegate, times(1))._role() + verify(mockDelegate, times(1)).status() + verify(mockDelegate, times(1))._type() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of Content outputText`() { + // Input and output are different types, so this test is an exceptional case. The + // delegator's `outputText()` delegates to the delegate's `outputText()` indirectly via the + // `outputText` field initializer. + val input = + Optional.of( + ResponseOutputText.builder() + .annotations(listOf()) + .text("{\"s\" : \"hello\"}") + .build() + ) + `when`(contentMockDelegate.outputText()).thenReturn(input) + val output = contentDelegator.outputText() + + verify(contentMockDelegate, times(1)).outputText() + verifyNoMoreInteractions(contentMockDelegate) + + assertThat(output).isEqualTo(Optional.of(X("hello"))) + } + + @Test + fun `delegation of Content asOutputText`() { + // Input and output are different types, so this test is an exceptional case. The + // delegator's `asOutputText()` delegates to the delegate's `outputText()` indirectly via + // the + // `outputText` field initializer. + val input = + Optional.of( + ResponseOutputText.builder() + .annotations(listOf()) + .text("{\"s\" : \"hello\"}") + .build() + ) + `when`(contentMockDelegate.outputText()).thenReturn(input) + val output = contentDelegator.asOutputText() + + verify(contentMockDelegate, times(1)).outputText() + verifyNoMoreInteractions(contentMockDelegate) + + assertThat(output).isEqualTo(X("hello")) + } + + @Test + fun `delegation of Content asOutputText missing`() { + val input = Optional.ofNullable(null) + `when`(contentMockDelegate.outputText()).thenReturn(input) + + assertThatThrownBy { contentDelegator.asOutputText() } + .isInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("`outputText` is not present") + + verify(contentMockDelegate, times(1)).outputText() + verifyNoMoreInteractions(contentMockDelegate) + } + + @Test + fun `delegation of Content validate`() { + // No values or passed and only `this` is returned. + contentDelegator.validate() + + verify(contentMockDelegate, times(1)).validate() + verifyNoMoreInteractions(contentMockDelegate) + } + + @Test + fun `delegation of Content isValid`() { + contentDelegator.isValid() + + // `isValid()` calls `validate`, which then calls the mock delegate. + verify(contentMockDelegate, times(1)).validate() + verifyNoMoreInteractions(contentMockDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt new file mode 100644 index 00000000..365476e7 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt @@ -0,0 +1,245 @@ +package com.openai.models.responses + +import com.openai.core.BOOLEAN +import com.openai.core.DOUBLE +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.LIST +import com.openai.core.MAP +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import com.openai.models.ChatModel +import com.openai.models.ResponsesModel +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponse] class (delegator) and its delegation of most functions to + * a wrapped [Response] (delegate). The tests include confirmation of the following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseTest { + companion object { + private val RESPONSES_MODEL = ResponsesModel.ofChat(ChatModel.GPT_4O) + + private val TOOL_CHOICE = + Response.ToolChoice.ofFunction(ToolChoiceFunction.builder().name(STRING).build()) + // A reasoning item is probably the simplest one to create. + private val OUTPUT_ITEM = + ResponseOutputItem.ofReasoning( + ResponseReasoningItem.builder() + .id(STRING) + .summary(listOf(ResponseReasoningItem.Summary.builder().text(STRING).build())) + .build() + ) + + // The list order follows the declaration order in `Response` for easier maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + DelegationReadTestCase("id", STRING), + DelegationReadTestCase("createdAt", DOUBLE), + DelegationReadTestCase("error", OPTIONAL), + DelegationReadTestCase("incompleteDetails", OPTIONAL), + DelegationReadTestCase("instructions", OPTIONAL), + DelegationReadTestCase("metadata", OPTIONAL), + DelegationReadTestCase("model", RESPONSES_MODEL), + DelegationReadTestCase("_object_", JSON_VALUE), + // `output()` is a special case and has its own test function. + DelegationReadTestCase("parallelToolCalls", BOOLEAN), + DelegationReadTestCase("temperature", OPTIONAL), + DelegationReadTestCase("toolChoice", TOOL_CHOICE), + DelegationReadTestCase("tools", LIST), + DelegationReadTestCase("topP", OPTIONAL), + DelegationReadTestCase("maxOutputTokens", OPTIONAL), + DelegationReadTestCase("previousResponseId", OPTIONAL), + DelegationReadTestCase("reasoning", OPTIONAL), + DelegationReadTestCase("serviceTier", OPTIONAL), + DelegationReadTestCase("status", OPTIONAL), + DelegationReadTestCase("text", OPTIONAL), + DelegationReadTestCase("truncation", OPTIONAL), + DelegationReadTestCase("usage", OPTIONAL), + DelegationReadTestCase("user", OPTIONAL), + DelegationReadTestCase("_id", JSON_FIELD), + DelegationReadTestCase("_createdAt", JSON_FIELD), + DelegationReadTestCase("_error", JSON_FIELD), + DelegationReadTestCase("_incompleteDetails", JSON_FIELD), + DelegationReadTestCase("_instructions", JSON_FIELD), + DelegationReadTestCase("_metadata", JSON_FIELD), + DelegationReadTestCase("_model", JSON_FIELD), + // `_output()` is a special case and has its own test function. + DelegationReadTestCase("_parallelToolCalls", JSON_FIELD), + DelegationReadTestCase("_temperature", JSON_FIELD), + DelegationReadTestCase("_toolChoice", JSON_FIELD), + DelegationReadTestCase("_tools", JSON_FIELD), + DelegationReadTestCase("_topP", JSON_FIELD), + DelegationReadTestCase("_maxOutputTokens", JSON_FIELD), + DelegationReadTestCase("_previousResponseId", JSON_FIELD), + DelegationReadTestCase("_reasoning", JSON_FIELD), + DelegationReadTestCase("_serviceTier", JSON_FIELD), + DelegationReadTestCase("_status", JSON_FIELD), + DelegationReadTestCase("_text", JSON_FIELD), + DelegationReadTestCase("_truncation", JSON_FIELD), + DelegationReadTestCase("_usage", JSON_FIELD), + DelegationReadTestCase("_user", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", MAP), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: Response = mock(Response::class.java) + private val delegator = StructuredResponse(X::class.java, mockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder") + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + checkAllDelegation(delegator::class, mockDelegate::class) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("output", "_output", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @Test + fun `delegation of output`() { + // Input and output are different types, so this test is an exceptional case. + // `output()` (without an underscore) delegates to `_output()` (with an underscore) + // indirectly via the `output` field initializer. + val input = JsonField.of(listOf(OUTPUT_ITEM)) + `when`(mockDelegate._output()).thenReturn(input) + val output = delegator.output() // Without an underscore. + + verify(mockDelegate, times(1))._output() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output[0].rawOutputItem).isEqualTo(OUTPUT_ITEM) + } + + @Test + fun `delegation of _output`() { + // Input and output are different types, so this test is an exceptional case. + // `_output()` delegates to `_output()` indirectly via the `output` field initializer. + val input = JsonField.of(listOf(OUTPUT_ITEM)) + `when`(mockDelegate._output()).thenReturn(input) + val output = delegator._output() // With an underscore. + + verify(mockDelegate, times(1))._output() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.getRequired("_output")[0].rawOutputItem).isEqualTo(OUTPUT_ITEM) + } + + @Test + fun `delegation of validate`() { + `when`(mockDelegate.model()).thenReturn(RESPONSES_MODEL) + `when`(mockDelegate._object_()).thenReturn(JsonValue.from("response")) + `when`(mockDelegate._output()).thenReturn(JsonField.of(listOf(OUTPUT_ITEM))) + `when`(mockDelegate.toolChoice()).thenReturn(TOOL_CHOICE) + + delegator.validate() + + // Delegator's `validate()` does not call delegate's `validate()`. `_content` is called + // indirectly via the `content` field initializer. + verify(mockDelegate, times(1)).id() + verify(mockDelegate, times(1)).createdAt() + verify(mockDelegate, times(1)).error() + verify(mockDelegate, times(1)).incompleteDetails() + verify(mockDelegate, times(1)).instructions() + verify(mockDelegate, times(1)).metadata() + verify(mockDelegate, times(1)).model() + verify(mockDelegate, times(1))._object_() + verify(mockDelegate, times(1))._output() // Indirect + verify(mockDelegate, times(1)).parallelToolCalls() + verify(mockDelegate, times(1)).temperature() + verify(mockDelegate, times(1)).toolChoice() + verify(mockDelegate, times(1)).tools() + verify(mockDelegate, times(1)).topP() + verify(mockDelegate, times(1)).maxOutputTokens() + verify(mockDelegate, times(1)).previousResponseId() + verify(mockDelegate, times(1)).reasoning() + verify(mockDelegate, times(1)).serviceTier() + verify(mockDelegate, times(1)).status() + verify(mockDelegate, times(1)).text() + verify(mockDelegate, times(1)).truncation() + verify(mockDelegate, times(1)).usage() + verify(mockDelegate, times(1)).user() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of isValid`() { + `when`(mockDelegate.model()).thenReturn(RESPONSES_MODEL) + `when`(mockDelegate._object_()).thenReturn(JsonValue.from("response")) + `when`(mockDelegate._output()).thenReturn(JsonField.of(listOf(OUTPUT_ITEM))) + `when`(mockDelegate.toolChoice()).thenReturn(TOOL_CHOICE) + + // `isValid()` calls `validate()`, so the test is more-or-less the same. + delegator.isValid() + + verify(mockDelegate, times(1)).id() + verify(mockDelegate, times(1)).createdAt() + verify(mockDelegate, times(1)).error() + verify(mockDelegate, times(1)).incompleteDetails() + verify(mockDelegate, times(1)).instructions() + verify(mockDelegate, times(1)).metadata() + verify(mockDelegate, times(1)).model() + verify(mockDelegate, times(1))._object_() + verify(mockDelegate, times(1))._output() // Indirect + verify(mockDelegate, times(1)).parallelToolCalls() + verify(mockDelegate, times(1)).temperature() + verify(mockDelegate, times(1)).toolChoice() + verify(mockDelegate, times(1)).tools() + verify(mockDelegate, times(1)).topP() + verify(mockDelegate, times(1)).maxOutputTokens() + verify(mockDelegate, times(1)).previousResponseId() + verify(mockDelegate, times(1)).reasoning() + verify(mockDelegate, times(1)).serviceTier() + verify(mockDelegate, times(1)).status() + verify(mockDelegate, times(1)).text() + verify(mockDelegate, times(1)).truncation() + verify(mockDelegate, times(1)).usage() + verify(mockDelegate, times(1)).user() + verifyNoMoreInteractions(mockDelegate) + } +} diff --git a/openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java b/openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java new file mode 100644 index 00000000..9ef89183 --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java @@ -0,0 +1,73 @@ +package com.openai.example; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import java.util.List; + +public final class ResponsesStructuredOutputsExample { + + public static class Person { + @JsonPropertyDescription("The first name and surname of the person.") + public String name; + + public int birthYear; + + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; + + @Override + public String toString() { + return name + " (" + birthYear + '-' + deathYear + ')'; + } + } + + public static class Book { + public String title; + + public Person author; + + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; + + public String genre; + + @JsonIgnore + public String isbn; + + @Override + public String toString() { + return '"' + title + "\" (" + publicationYear + ") [" + genre + "] by " + author; + } + } + + public static class BookList { + public List books; + } + + private ResponsesStructuredOutputsExample() {} + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + + StructuredResponseCreateParams createParams = ResponseCreateParams.builder() + .input("List some famous late twentieth century novels.") + .text(BookList.class) + .model(ChatModel.GPT_4O) + .build(); + + client.responses().create(createParams).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(" - " + book)); + } +} From c2a9508650420d628b2ff27c303b8963edb611cb Mon Sep 17 00:00:00 2001 From: D Gardner Date: Wed, 14 May 2025 15:51:19 +0100 Subject: [PATCH 7/9] structured-outputs: removed support for Responses params Builder.body function --- .../models/responses/StructuredResponseCreateParams.kt | 4 +--- .../responses/StructuredResponseCreateParamsTest.kt | 10 ++++------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt index f2adad97..aae191ac 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt @@ -57,9 +57,7 @@ class StructuredResponseCreateParams( this.paramsBuilder = paramsBuilder } - // TODO: Probably not correct, as text config could be overwritten. - /** @see ResponseCreateParams.Builder.body */ - fun body(body: ResponseCreateParams.Body) = apply { paramsBuilder.body(body) } + // The `body(...)` function is deliberately not supported. /** @see ResponseCreateParams.Builder.input */ fun input(input: ResponseCreateParams.Input) = apply { paramsBuilder.input(input) } diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt index 1eb68ee9..3d205072 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt @@ -52,8 +52,6 @@ internal class StructuredResponseCreateParamsTest { private val RESPONSES_MODEL = ResponsesModel.ofChat(CHAT_MODEL) private val RESPONSES_ONLY_MODEL = ResponsesModel.ResponsesOnlyModel.O1_PRO private val PARAMS_INPUT = ResponseCreateParams.Input.ofText(STRING) - private val PARAMS_BODY = - ResponseCreateParams.Body.builder().input(PARAMS_INPUT).model(RESPONSES_MODEL).build() private val INCLUDABLE = ResponseIncludable.of(STRING) private val METADATA = ResponseCreateParams.Metadata.builder().build() @@ -87,7 +85,7 @@ internal class StructuredResponseCreateParamsTest { @JvmStatic private fun builderDelegationTestCases() = listOf( - DelegationWriteTestCase("body", PARAMS_BODY), + // The `body(...)` function is deliberately not supported: too messy. DelegationWriteTestCase("input", PARAMS_INPUT), DelegationWriteTestCase("input", JSON_FIELD), DelegationWriteTestCase("input", STRING), @@ -196,9 +194,9 @@ internal class StructuredResponseCreateParamsTest { @Test fun allBuilderDelegateFunctionsExistInDelegator() { - // The delegator class does not implement the various `text` functions of the delegate - // class. - checkAllDelegation(mockBuilderDelegate::class, builderDelegator::class, "text") + // The delegator class does not implement the various `text` functions or the `body` + // function of the delegate class. + checkAllDelegation(mockBuilderDelegate::class, builderDelegator::class, "body", "text") } @Test From 3edd9504e5be40422e7dc489923512ee5c3dd493 Mon Sep 17 00:00:00 2001 From: D Gardner Date: Wed, 14 May 2025 17:38:49 +0100 Subject: [PATCH 8/9] structured-outputs: extra docs and simpler code --- .../com/openai/core/StructuredOutputs.kt | 11 +--- .../models/responses/StructuredResponse.kt | 35 +---------- .../StructuredResponseOutputMessage.kt | 20 +------ .../services/blocking/ResponseService.kt | 14 ++++- .../blocking/chat/ChatCompletionService.kt | 16 ++++- .../StructuredResponseOutputMessageTest.kt | 21 ++----- .../responses/StructuredResponseTest.kt | 60 ++----------------- 7 files changed, 39 insertions(+), 138 deletions(-) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index 747995e1..df48c1af 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -77,16 +77,7 @@ internal fun textConfigFromClass( .format( ResponseFormatTextJsonSchemaConfig.builder() .name("json-schema-from-${type.simpleName}") - .schema( - ResponseFormatTextJsonSchemaConfig.Schema.builder() - .additionalProperties( - extractAndValidateSchema(type, localValidation) - .fields() - .asSequence() - .associate { it.key to JsonValue.fromJsonNode(it.value) } - ) - .build() - ) + .schema(JsonValue.fromJsonNode(extractAndValidateSchema(type, localValidation))) // Ensure the model's output strictly adheres to this JSON schema. This is the // essential "ON switch" for Structured Outputs. .strict(true) diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt index d71450cf..f28865b0 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt @@ -164,43 +164,10 @@ class StructuredResponse( /** @see Response._additionalProperties */ fun _additionalProperties(): Map = rawResponse._additionalProperties() - private var validated: Boolean = false - /** @see Response.validate */ fun validate(): StructuredResponse = apply { - if (validated) { - return@apply - } - - id() - createdAt() - error().ifPresent { it.validate() } - incompleteDetails().ifPresent { it.validate() } - instructions() - metadata().ifPresent { it.validate() } - model().validate() - _object_().let { - if (it != JsonValue.from("response")) { - throw OpenAIInvalidDataException("'object_' is invalid, received $it") - } - } - // `output()` is a different type to that in the delegate class. output().forEach { it.validate() } - parallelToolCalls() - temperature() - toolChoice().validate() - tools().forEach { it.validate() } - topP() - maxOutputTokens() - previousResponseId() - reasoning().ifPresent { it.validate() } - serviceTier().ifPresent { it.validate() } - status().ifPresent { it.validate() } - text().ifPresent { it.validate() } - truncation().ifPresent { it.validate() } - usage().ifPresent { it.validate() } - user() - validated = true + rawResponse.validate() } /** @see Response.isValid */ diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt index e131df2c..b7083a25 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt @@ -52,29 +52,11 @@ class StructuredResponseOutputMessage( /** @see ResponseOutputMessage._additionalProperties */ fun _additionalProperties(): Map = rawMessage._additionalProperties() - private var validated: Boolean = false - /** @see ResponseOutputMessage.validate */ fun validate(): StructuredResponseOutputMessage = apply { - if (validated) { - return@apply - } - - id() // `content()` is a different type to that in the delegate class. content().forEach { it.validate() } - _role().let { - if (it != JsonValue.from("assistant")) { - throw OpenAIInvalidDataException("'role' is invalid, received $it") - } - } - status().validate() - _type().let { - if (it != JsonValue.from("message")) { - throw OpenAIInvalidDataException("'type' is invalid, received $it") - } - } - validated = true + rawMessage.validate() } /** @see ResponseOutputMessage.isValid */ diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt index 69570003..12ed80dd 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt @@ -44,11 +44,21 @@ interface ResponseService { requestOptions: RequestOptions = RequestOptions.none(), ): Response - /** @see create */ + /** + * Creates a model response. The model's structured output in JSON form will be deserialized + * automatically into an instance of the class `T`. See the SDK documentation for more details. + * + * @see create + */ fun create(params: StructuredResponseCreateParams): StructuredResponse = create(params, RequestOptions.none()) - /** @see create */ + /** + * Creates a model response. The model's structured output in JSON form will be deserialized + * automatically into an instance of the class `T`. See the SDK documentation for more details. + * + * @see create + */ fun create( params: StructuredResponseCreateParams, requestOptions: RequestOptions = RequestOptions.none(), diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt index 29792149..43f372d2 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt @@ -55,12 +55,24 @@ interface ChatCompletionService { requestOptions: RequestOptions = RequestOptions.none(), ): ChatCompletion - /** @see create */ + /** + * Creates a model response for the given chat conversation. The model's structured output in + * JSON form will be deserialized automatically into an instance of the class `T`. See the SDK + * documentation for more details. + * + * @see create + */ fun create( params: StructuredChatCompletionCreateParams ): StructuredChatCompletion = create(params, RequestOptions.none()) - /** @see create */ + /** + * Creates a model response for the given chat conversation. The model's structured output in + * JSON form will be deserialized automatically into an instance of the class `T`. See the SDK + * documentation for more details. + * + * @see create + */ fun create( params: StructuredChatCompletionCreateParams, requestOptions: RequestOptions = RequestOptions.none(), diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt index ee3c3a57..2b093cf5 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt @@ -4,7 +4,6 @@ import com.openai.core.DelegationReadTestCase import com.openai.core.JSON_FIELD import com.openai.core.JSON_VALUE import com.openai.core.JsonField -import com.openai.core.JsonValue import com.openai.core.MAP import com.openai.core.OPTIONAL import com.openai.core.STRING @@ -188,19 +187,13 @@ internal class StructuredResponseOutputMessageTest { @Test fun `delegation of validate`() { `when`(mockDelegate._content()).thenReturn(JsonField.of(listOf(CONTENT))) - `when`(mockDelegate._role()).thenReturn(JsonValue.from("assistant")) - `when`(mockDelegate.status()).thenReturn(ResponseOutputMessage.Status.COMPLETED) - `when`(mockDelegate._type()).thenReturn(JsonValue.from("message")) delegator.validate() - // Delegator's `validate()` does not call delegate's `validate()`. `_content` is called - // indirectly via the `content` field initializer. + // Delegator's `validate()` calls delegate's `validate()`. `_content` is called indirectly + // via the `content` field initializer. verify(mockDelegate, times(1))._content() - verify(mockDelegate, times(1)).id() - verify(mockDelegate, times(1))._role() - verify(mockDelegate, times(1)).status() - verify(mockDelegate, times(1))._type() + verify(mockDelegate, times(1)).validate() verifyNoMoreInteractions(mockDelegate) } @@ -208,17 +201,11 @@ internal class StructuredResponseOutputMessageTest { fun `delegation of isValid`() { // `isValid` calls `validate()`, so the test is similar to that for `validate()`. `when`(mockDelegate._content()).thenReturn(JsonField.of(listOf(CONTENT))) - `when`(mockDelegate._role()).thenReturn(JsonValue.from("assistant")) - `when`(mockDelegate.status()).thenReturn(ResponseOutputMessage.Status.COMPLETED) - `when`(mockDelegate._type()).thenReturn(JsonValue.from("message")) delegator.isValid() verify(mockDelegate, times(1))._content() - verify(mockDelegate, times(1)).id() - verify(mockDelegate, times(1))._role() - verify(mockDelegate, times(1)).status() - verify(mockDelegate, times(1))._type() + verify(mockDelegate, times(1)).validate() verifyNoMoreInteractions(mockDelegate) } diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt index 365476e7..9a1b76cd 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt @@ -6,7 +6,6 @@ import com.openai.core.DelegationReadTestCase import com.openai.core.JSON_FIELD import com.openai.core.JSON_VALUE import com.openai.core.JsonField -import com.openai.core.JsonValue import com.openai.core.LIST import com.openai.core.MAP import com.openai.core.OPTIONAL @@ -172,74 +171,27 @@ internal class StructuredResponseTest { @Test fun `delegation of validate`() { - `when`(mockDelegate.model()).thenReturn(RESPONSES_MODEL) - `when`(mockDelegate._object_()).thenReturn(JsonValue.from("response")) `when`(mockDelegate._output()).thenReturn(JsonField.of(listOf(OUTPUT_ITEM))) - `when`(mockDelegate.toolChoice()).thenReturn(TOOL_CHOICE) delegator.validate() - // Delegator's `validate()` does not call delegate's `validate()`. `_content` is called - // indirectly via the `content` field initializer. - verify(mockDelegate, times(1)).id() - verify(mockDelegate, times(1)).createdAt() - verify(mockDelegate, times(1)).error() - verify(mockDelegate, times(1)).incompleteDetails() - verify(mockDelegate, times(1)).instructions() - verify(mockDelegate, times(1)).metadata() - verify(mockDelegate, times(1)).model() - verify(mockDelegate, times(1))._object_() + // Delegator's `validate()` calls the delegate's `validate()`. Delegate's `_output()` is + // called indirectly via the `output` field initializer. verify(mockDelegate, times(1))._output() // Indirect - verify(mockDelegate, times(1)).parallelToolCalls() - verify(mockDelegate, times(1)).temperature() - verify(mockDelegate, times(1)).toolChoice() - verify(mockDelegate, times(1)).tools() - verify(mockDelegate, times(1)).topP() - verify(mockDelegate, times(1)).maxOutputTokens() - verify(mockDelegate, times(1)).previousResponseId() - verify(mockDelegate, times(1)).reasoning() - verify(mockDelegate, times(1)).serviceTier() - verify(mockDelegate, times(1)).status() - verify(mockDelegate, times(1)).text() - verify(mockDelegate, times(1)).truncation() - verify(mockDelegate, times(1)).usage() - verify(mockDelegate, times(1)).user() + verify(mockDelegate, times(1)).validate() verifyNoMoreInteractions(mockDelegate) } @Test fun `delegation of isValid`() { - `when`(mockDelegate.model()).thenReturn(RESPONSES_MODEL) - `when`(mockDelegate._object_()).thenReturn(JsonValue.from("response")) + // `isValid()` calls `validate()` which delegates to `validate()`, so the test is + // more-or-less the same as for `validate()`. `when`(mockDelegate._output()).thenReturn(JsonField.of(listOf(OUTPUT_ITEM))) - `when`(mockDelegate.toolChoice()).thenReturn(TOOL_CHOICE) - // `isValid()` calls `validate()`, so the test is more-or-less the same. delegator.isValid() - verify(mockDelegate, times(1)).id() - verify(mockDelegate, times(1)).createdAt() - verify(mockDelegate, times(1)).error() - verify(mockDelegate, times(1)).incompleteDetails() - verify(mockDelegate, times(1)).instructions() - verify(mockDelegate, times(1)).metadata() - verify(mockDelegate, times(1)).model() - verify(mockDelegate, times(1))._object_() verify(mockDelegate, times(1))._output() // Indirect - verify(mockDelegate, times(1)).parallelToolCalls() - verify(mockDelegate, times(1)).temperature() - verify(mockDelegate, times(1)).toolChoice() - verify(mockDelegate, times(1)).tools() - verify(mockDelegate, times(1)).topP() - verify(mockDelegate, times(1)).maxOutputTokens() - verify(mockDelegate, times(1)).previousResponseId() - verify(mockDelegate, times(1)).reasoning() - verify(mockDelegate, times(1)).serviceTier() - verify(mockDelegate, times(1)).status() - verify(mockDelegate, times(1)).text() - verify(mockDelegate, times(1)).truncation() - verify(mockDelegate, times(1)).usage() - verify(mockDelegate, times(1)).user() + verify(mockDelegate, times(1)).validate() verifyNoMoreInteractions(mockDelegate) } } From f4b259cb70c0b32616eaa6b652557d2c678db95a Mon Sep 17 00:00:00 2001 From: Tomer Aberbach Date: Wed, 14 May 2025 13:02:10 -0400 Subject: [PATCH 9/9] docs: swap primary structured outputs example --- README.md | 2 +- .../StructuredOutputsClassExample.java | 72 ------------------- .../example/StructuredOutputsExample.java | 67 ++++++++++++----- ... => StructuredOutputsRawAsyncExample.java} | 5 +- .../example/StructuredOutputsRawExample.java | 42 +++++++++++ 5 files changed, 93 insertions(+), 95 deletions(-) delete mode 100644 openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java rename openai-java-example/src/main/java/com/openai/example/{StructuredOutputsAsyncExample.java => StructuredOutputsRawAsyncExample.java} (91%) create mode 100644 openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawExample.java diff --git a/README.md b/README.md index 105aa4b9..dece1cd2 100644 --- a/README.md +++ b/README.md @@ -346,7 +346,7 @@ and setting it on the input parameters. However, for greater convenience, a JSON be derived automatically from the structure of an arbitrary Java class. The JSON content from the response will then be converted automatically to an instance of that Java class. A full, working example of the use of Structured Outputs with arbitrary Java classes can be seen in -[`StructuredOutputsClassExample`](openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java). +[`StructuredOutputsExample`](openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java). Java classes can contain fields declared to be instances of other classes and can use collections: diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java deleted file mode 100644 index 3f65a991..00000000 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java +++ /dev/null @@ -1,72 +0,0 @@ -package com.openai.example; - -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import com.openai.client.OpenAIClient; -import com.openai.client.okhttp.OpenAIOkHttpClient; -import com.openai.models.ChatModel; -import com.openai.models.chat.completions.ChatCompletionCreateParams; -import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; -import java.util.List; - -public final class StructuredOutputsClassExample { - - public static class Person { - @JsonPropertyDescription("The first name and surname of the person.") - public String name; - - public int birthYear; - - @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") - public String deathYear; - - @Override - public String toString() { - return name + " (" + birthYear + '-' + deathYear + ')'; - } - } - - public static class Book { - public String title; - - public Person author; - - @JsonPropertyDescription("The year in which the book was first published.") - public int publicationYear; - - public String genre; - - @JsonIgnore - public String isbn; - - @Override - public String toString() { - return '"' + title + "\" (" + publicationYear + ") [" + genre + "] by " + author; - } - } - - public static class BookList { - public List books; - } - - private StructuredOutputsClassExample() {} - - public static void main(String[] args) { - // Configures using one of: - // - The `OPENAI_API_KEY` environment variable - // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables - OpenAIClient client = OpenAIOkHttpClient.fromEnv(); - - StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() - .model(ChatModel.GPT_4O_MINI) - .maxCompletionTokens(2048) - .responseFormat(BookList.class) - .addUserMessage("List some famous late twentieth century novels.") - .build(); - - client.chat().completions().create(createParams).choices().stream() - .flatMap(choice -> choice.message().content().stream()) - .flatMap(bookList -> bookList.books.stream()) - .forEach(book -> System.out.println(" - " + book)); - } -} diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java index 9d3ce9da..b4af9999 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java @@ -1,15 +1,54 @@ package com.openai.example; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; -import com.openai.core.JsonValue; import com.openai.models.ChatModel; -import com.openai.models.ResponseFormatJsonSchema; -import com.openai.models.ResponseFormatJsonSchema.JsonSchema; import com.openai.models.chat.completions.ChatCompletionCreateParams; -import java.util.Map; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; +import java.util.List; public final class StructuredOutputsExample { + + public static class Person { + @JsonPropertyDescription("The first name and surname of the person.") + public String name; + + public int birthYear; + + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; + + @Override + public String toString() { + return name + " (" + birthYear + '-' + deathYear + ')'; + } + } + + public static class Book { + public String title; + + public Person author; + + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; + + public String genre; + + @JsonIgnore + public String isbn; + + @Override + public String toString() { + return '"' + title + "\" (" + publicationYear + ") [" + genre + "] by " + author; + } + } + + public static class BookList { + public List books; + } + private StructuredOutputsExample() {} public static void main(String[] args) { @@ -18,26 +57,16 @@ public static void main(String[] args) { // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables OpenAIClient client = OpenAIOkHttpClient.fromEnv(); - // TODO: Update this once we support extracting JSON schemas from Java classes - JsonSchema.Schema schema = JsonSchema.Schema.builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty( - "properties", JsonValue.from(Map.of("employees", Map.of("items", Map.of("type", "string"))))) - .build(); - ChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() .model(ChatModel.GPT_4O_MINI) .maxCompletionTokens(2048) - .responseFormat(ResponseFormatJsonSchema.builder() - .jsonSchema(JsonSchema.builder() - .name("employee-list") - .schema(schema) - .build()) - .build()) - .addUserMessage("Who works at OpenAI?") + .responseFormat(BookList.class) + .addUserMessage("List some famous late twentieth century novels.") .build(); client.chat().completions().create(createParams).choices().stream() .flatMap(choice -> choice.message().content().stream()) - .forEach(System.out::println); + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(" - " + book)); } } diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsAsyncExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawAsyncExample.java similarity index 91% rename from openai-java-example/src/main/java/com/openai/example/StructuredOutputsAsyncExample.java rename to openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawAsyncExample.java index a645f6cb..f726e0cf 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsAsyncExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawAsyncExample.java @@ -9,8 +9,8 @@ import com.openai.models.chat.completions.ChatCompletionCreateParams; import java.util.Map; -public final class StructuredOutputsAsyncExample { - private StructuredOutputsAsyncExample() {} +public final class StructuredOutputsRawAsyncExample { + private StructuredOutputsRawAsyncExample() {} public static void main(String[] args) { // Configures using one of: @@ -18,7 +18,6 @@ public static void main(String[] args) { // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables OpenAIClientAsync client = OpenAIOkHttpClientAsync.fromEnv(); - // TODO: Update this once we support extracting JSON schemas from Java classes JsonSchema.Schema schema = JsonSchema.Schema.builder() .putAdditionalProperty("type", JsonValue.from("object")) .putAdditionalProperty( diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawExample.java new file mode 100644 index 00000000..3b9ff03a --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawExample.java @@ -0,0 +1,42 @@ +package com.openai.example; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonValue; +import com.openai.models.ChatModel; +import com.openai.models.ResponseFormatJsonSchema; +import com.openai.models.ResponseFormatJsonSchema.JsonSchema; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import java.util.Map; + +public final class StructuredOutputsRawExample { + private StructuredOutputsRawExample() {} + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + + JsonSchema.Schema schema = JsonSchema.Schema.builder() + .putAdditionalProperty("type", JsonValue.from("object")) + .putAdditionalProperty( + "properties", JsonValue.from(Map.of("employees", Map.of("items", Map.of("type", "string"))))) + .build(); + ChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .maxCompletionTokens(2048) + .responseFormat(ResponseFormatJsonSchema.builder() + .jsonSchema(JsonSchema.builder() + .name("employee-list") + .schema(schema) + .build()) + .build()) + .addUserMessage("Who works at OpenAI?") + .build(); + + client.chat().completions().create(createParams).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .forEach(System.out::println); + } +}