diff --git a/README.md b/README.md index 39125ce8..dece1cd2 100644 --- a/README.md +++ b/README.md @@ -286,7 +286,7 @@ OpenAIClient client = OpenAIOkHttpClient.builder() The SDK provides conveniences for streamed chat completions. A [`ChatCompletionAccumulator`](openai-java-core/src/main/kotlin/com/openai/helpers/ChatCompletionAccumulator.kt) -can record the stream of chat completion chunks in the response as they are processed and accumulate +can record the stream of chat completion chunks in the response as they are processed and accumulate a [`ChatCompletion`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletion.kt) object similar to that which would have been returned by the non-streaming API. @@ -334,6 +334,205 @@ client.chat() ChatCompletion chatCompletion = chatCompletionAccumulator.chatCompletion(); ``` +## Structured outputs with JSON schemas + +Open AI [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat) +is a feature that ensures that the model will always generate responses that adhere to a supplied +[JSON schema](https://json-schema.org/overview/what-is-jsonschema). + +A JSON schema can be defined by creating a +[`ResponseFormatJsonSchema`](openai-java-core/src/main/kotlin/com/openai/models/ResponseFormatJsonSchema.kt) +and setting it on the input parameters. However, for greater convenience, a JSON schema can instead +be derived automatically from the structure of an arbitrary Java class. The JSON content from the +response will then be converted automatically to an instance of that Java class. A full, working +example of the use of Structured Outputs with arbitrary Java classes can be seen in +[`StructuredOutputsExample`](openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java). + +Java classes can contain fields declared to be instances of other classes and can use collections: + +```java +class Person { + public String name; + public int birthYear; +} + +class Book { + public String title; + public Person author; + public int publicationYear; +} + +class BookList { + public List books; +} +``` + +Pass the top-level class—`BookList` in this example—to `responseFormat(Class)` when building the +parameters and then access an instance of `BookList` from the generated message content in the +response: + +```java +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; + +StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .addUserMessage("List some famous late twentieth century novels.") + .model(ChatModel.GPT_4_1) + .responseFormat(BookList.class) + .build(); + +client.chat().completions().create(params).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(book.title + " by " + book.author.name)); +``` + +You can start building the parameters with an instance of +[`ChatCompletionCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt) +or +[`StructuredChatCompletionCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt). +If you start with the former (which allows for more compact code) the builder type will change to +the latter when `ChatCompletionCreateParams.Builder.responseFormat(Class)` is called. + +If a field in a class is optional and does not require a defined value, you can represent this using +the [`java.util.Optional`](https://docs.oracle.com/javase/8/docs/api/java/util/Optional.html) class. +It is up to the AI model to decide whether to provide a value for that field or leave it empty. + +```java +import java.util.Optional; + +class Book { + public String title; + public Person author; + public int publicationYear; + public Optional isbn; +} +``` + +Generic type information for fields is retained in the class's metadata, but _generic type erasure_ +applies in other scopes. While, for example, a JSON schema defining an array of books can be derived +from the `BookList.books` field with type `List`, a valid JSON schema cannot be derived from a +local variable of that same type, so the following will _not_ work: + +```java +List books = new ArrayList<>(); + +StructuredChatCompletionCreateParams> params = ChatCompletionCreateParams.builder() + .responseFormat(books.getClass()) + // ... + .build(); +``` + +If an error occurs while converting a JSON response to an instance of a Java class, the error +message will include the JSON response to assist in diagnosis. For instance, if the response is +truncated, the JSON data will be incomplete and cannot be converted to a class instance. If your +JSON response may contain sensitive information, avoid logging it directly, or ensure that you +redact any sensitive details from the error message. + +### Local JSON schema validation + +Structured Outputs supports a +[subset](https://platform.openai.com/docs/guides/structured-outputs#supported-schemas) of the JSON +Schema language. Schemas are generated automatically from classes to align with this subset. +However, due to the inherent structure of the classes, the generated schema may still violate +certain OpenAI schema restrictions, such as exceeding the maximum nesting depth or utilizing +unsupported data types. + +To facilitate compliance, the method `responseFormat(Class)` performs a validation check on the +schema derived from the specified class. This validation ensures that all restrictions are adhered +to. If any issues are detected, an exception will be thrown, providing a detailed message outlining +the reasons for the validation failure. + +- **Local Validation**: The validation process occurs locally, meaning no requests are sent to the +remote AI model. If the schema passes local validation, it is likely to pass remote validation as +well. +- **Remote Validation**: The remote AI model will conduct its own validation upon receiving the JSON +schema in the request. +- **Version Compatibility**: There may be instances where local validation fails while remote +validation succeeds. This can occur if the SDK version is outdated compared to the restrictions +enforced by the remote AI model. +- **Disabling Local Validation**: If you encounter compatibility issues and wish to bypass local +validation, you can disable it by passing +[`JsonSchemaLocalValidation.NO`](openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt) +to the `responseFormat(Class, JsonSchemaLocalValidation)` method when building the parameters. +(The default value for this parameter is `JsonSchemaLocalValidation.YES`.) + +```java +import com.openai.core.JsonSchemaLocalValidation; +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; + +StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .addUserMessage("List some famous late twentieth century novels.") + .model(ChatModel.GPT_4_1) + .responseFormat(BookList.class, JsonSchemaLocalValidation.NO) + .build(); +``` + +By following these guidelines, you can ensure that your structured outputs conform to the necessary +schema requirements and minimize the risk of remote validation errors. + +### Usage with the Responses API + +_Structured Outputs_ are also supported for the Responses API. The usage is the same as described +except where the Responses API differs slightly from the Chat Completions API. Pass the top-level +class to `text(Class)` when building the parameters and then access an instance of the class from +the generated message content in the response. + +You can start building the parameters with an instance of +[`ResponseCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt) +or +[`StructuredResponseCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt). +If you start with the former (which allows for more compact code) the builder type will change to +the latter when `ResponseCreateParams.Builder.text(Class)` is called. + +For a full example of the usage of _Structured Outputs_ with the Responses API, see +[`ResponsesStructuredOutputsExample`](openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java). + +### Annotating classes and JSON schemas + +You can use annotations to add further information to the JSON schema derived from your Java +classes, or to exclude individual fields from the schema. Details from annotations captured in the +JSON schema may be used by the AI model to improve its response. The SDK supports the use of +[Jackson Databind](https://github.com/FasterXML/jackson-databind) annotations. + +```java +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +class Person { + @JsonPropertyDescription("The first name and surname of the person") + public String name; + public int birthYear; + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; +} + +@JsonClassDescription("The details of one published book") +class Book { + public String title; + public Person author; + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; + @JsonIgnore public String genre; +} + +class BookList { + public List books; +} +``` + +- Use `@JsonClassDescription` to add a detailed description to a class. +- Use `@JsonPropertyDescription` to add a detailed description to a field of a class. +- Use `@JsonIgnore` to omit a field of a class from the generated JSON schema. + +If you use `@JsonProperty(required = false)`, the `false` value will be ignored. OpenAI JSON schemas +must mark all properties as _required_, so the schema generated from your Java classes will respect +that restriction and ignore any annotation that would violate it. + ## File uploads The SDK defines methods that accept files. @@ -652,7 +851,7 @@ If the SDK threw an exception, but you're _certain_ the version is compatible, t ## Microsoft Azure -To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the same +To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the same OpenAI client builder but with the Azure-specific configuration. ```java @@ -665,7 +864,7 @@ OpenAIClient client = OpenAIOkHttpClient.builder() .build(); ``` -See the complete Azure OpenAI example in the [`openai-java-example`](openai-java-example/src/main/java/com/openai/example/AzureEntraIdExample.java) directory. The other examples in the directory also work with Azure as long as the client is configured to use it. +See the complete Azure OpenAI example in the [`openai-java-example`](openai-java-example/src/main/java/com/openai/example/AzureEntraIdExample.java) directory. The other examples in the directory also work with Azure as long as the client is configured to use it. ## Network options diff --git a/openai-java-core/build.gradle.kts b/openai-java-core/build.gradle.kts index 08a91e0d..894f0e23 100644 --- a/openai-java-core/build.gradle.kts +++ b/openai-java-core/build.gradle.kts @@ -27,6 +27,8 @@ dependencies { implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.18.2") implementation("org.apache.httpcomponents.core5:httpcore5:5.2.4") implementation("org.apache.httpcomponents.client5:httpclient5:5.3.1") + implementation("com.github.victools:jsonschema-generator:4.38.0") + implementation("com.github.victools:jsonschema-module-jackson:4.38.0") testImplementation(kotlin("test")) testImplementation(project(":openai-java-client-okhttp")) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt new file mode 100644 index 00000000..9a3ae799 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt @@ -0,0 +1,19 @@ +package com.openai.core + +/** + * Options for local validation of JSON schemas derived from arbitrary classes before a request is + * executed. + */ +enum class JsonSchemaLocalValidation { + /** + * Validate the JSON schema locally before the request is executed. The remote AI model will + * also validate the JSON schema. + */ + YES, + + /** + * Do not validate the JSON schema locally before the request is executed. The remote AI model + * will always validate the JSON schema. + */ + NO, +} diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt new file mode 100644 index 00000000..85c20b43 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt @@ -0,0 +1,668 @@ +package com.openai.core + +import com.fasterxml.jackson.databind.JsonNode +import com.openai.core.JsonSchemaValidator.Companion.MAX_ENUM_TOTAL_STRING_LENGTH +import com.openai.core.JsonSchemaValidator.Companion.UNRESTRICTED_ENUM_VALUES_LIMIT + +/** + * A validator that ensures that a JSON schema complies with the rules and restrictions imposed by + * the OpenAI API specification for the input schemas used to define structured outputs. Only a + * subset of the JSON Schema language is supported. The purpose of this validator is to perform a + * quick check of a schema so that it can be determined to be likely to be accepted when passed in + * the request for an AI inference. + * + * This validator assumes that the JSON schema represents the structure of Java/Kotlin classes; it + * is not a general-purpose JSON schema validator. Assumptions are also made that the generator will + * be well-behaved, so the validation is not a check for strict conformance to the JSON Schema + * specification, but to the OpenAI API specification's restrictions on JSON schemas. + */ +internal class JsonSchemaValidator private constructor() { + + companion object { + // The names of the supported schema keywords. All other keywords will be rejected. + private const val SCHEMA = "\$schema" + private const val ID = "\$id" + private const val DEFS = "\$defs" + private const val REF = "\$ref" + private const val PROPS = "properties" + private const val ANY_OF = "anyOf" + private const val TYPE = "type" + private const val REQUIRED = "required" + private const val DESC = "description" + private const val TITLE = "title" + private const val ITEMS = "items" + private const val CONST = "const" + private const val ENUM = "enum" + private const val ADDITIONAL_PROPS = "additionalProperties" + + // The names of the supported schema data types. + // + // JSON Schema does not define an "integer" type, only a "number" type, but it allows any + // schema to define its own "vocabulary" of type names. "integer" is supported by OpenAI. + private const val TYPE_ARRAY = "array" + private const val TYPE_OBJECT = "object" + private const val TYPE_BOOLEAN = "boolean" + private const val TYPE_STRING = "string" + private const val TYPE_NUMBER = "number" + private const val TYPE_INTEGER = "integer" + private const val TYPE_NULL = "null" + + // The validator checks that unsupported type-specific keywords are not present in a + // property node. The OpenAI API specification states: + // + // "Notable keywords not supported include: + // + // - For strings: `minLength`, `maxLength`, `pattern`, `format` + // - For numbers: `minimum`, `maximum`, `multipleOf` + // - For objects: `patternProperties`, `unevaluatedProperties`, `propertyNames`, + // `minProperties`, `maxProperties` + // - For arrays: `unevaluatedItems`, `contains`, `minContains`, `maxContains`, `minItems`, + // `maxItems`, `uniqueItems`" + // + // As that list is not exhaustive, and no keywords are explicitly named as supported, this + // validation allows _no_ type-specific keywords. The following sets define the allowed + // keywords in different contexts and all others are rejected. + + /** + * The set of allowed keywords in the root schema only, not including the keywords that are + * also allowed in a sub-schema. + */ + private val ALLOWED_KEYWORDS_ROOT_SCHEMA_ONLY = setOf(SCHEMA, ID, DEFS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"anyOf"` field is + * present. OpenAI allows the `"anyOf"` field in sub-schemas, but not in the root schema. + */ + private val ALLOWED_KEYWORDS_ANY_OF_SUB_SCHEMA = setOf(ANY_OF, TITLE, DESC) + + /** + * The set of allowed keywords when defining sub-schemas when the `"$ref"` field is present. + */ + private val ALLOWED_KEYWORDS_REF_SUB_SCHEMA = setOf(REF, TITLE, DESC) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"object"`. + */ + private val ALLOWED_KEYWORDS_OBJECT_SUB_SCHEMA = + setOf(TYPE, REQUIRED, ADDITIONAL_PROPS, TITLE, DESC, PROPS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"array"`. + */ + private val ALLOWED_KEYWORDS_ARRAY_SUB_SCHEMA = setOf(TYPE, TITLE, DESC, ITEMS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"boolean"`, `"integer"`, `"number"`, or `"string"`. + */ + private val ALLOWED_KEYWORDS_SIMPLE_SUB_SCHEMA = setOf(TYPE, TITLE, DESC, ENUM, CONST) + + /** + * The maximum total length of all strings used in the schema for property names, definition + * names, enum values and const values. The OpenAI specification states: + * > In a schema, total string length of all property names, definition names, enum values, + * > and const values cannot exceed 15,000 characters. + */ + private const val MAX_TOTAL_STRING_LENGTH = 15_000 + + /** The maximum number of object properties allowed in a schema. */ + private const val MAX_OBJECT_PROPERTIES = 100 + + /** The maximum number of enum values across all enums in the schema. */ + private const val MAX_ENUM_VALUES = 500 + + /** + * The number of enum values in any one enum with string values beyond which a limit of + * [MAX_ENUM_TOTAL_STRING_LENGTH] is imposed on the total length of all the string values of + * that one enum. + */ + private const val UNRESTRICTED_ENUM_VALUES_LIMIT = 250 + + /** + * The maximum total length of all string values of a single enum where the number of values + * exceeds [UNRESTRICTED_ENUM_VALUES_LIMIT]. + */ + private const val MAX_ENUM_TOTAL_STRING_LENGTH = 7_500 + + /** The maximum depth (number of levels) of nesting allowed in a schema. */ + private const val MAX_NESTING_DEPTH = 5 + + /** The depth value that corresponds to the root level of the schema. */ + private const val ROOT_DEPTH = 0 + + /** + * The path string that identifies the root node in the schema when appearing in error + * messages or references. + */ + private const val ROOT_PATH = "#" + + /** + * Creates a new [JsonSchemaValidator]. After calling [validate], the validator instance + * holds information about the errors that occurred during validation (if any). A validator + * instance can be used only once to validate a schema; to validate another schema, create + * another validator. + */ + fun create() = JsonSchemaValidator() + } + + /** + * The total length of all strings used in the schema for property names, definition names, enum + * values and const values. + */ + private var totalStringLength: Int = 0 + + /** The total number of values across all enums in the schema. */ + private var totalEnumValues: Int = 0 + + /** The total number of object properties found in the schema, including in definitions. */ + private var totalObjectProperties: Int = 0 + + /** + * The set of valid references that may appear in the schema. This set includes the root schema + * and any definitions within the root schema. This is used to verify that references elsewhere + * in the schema are valid. This will always contain the root schema, but that may be the only + * member. + */ + private var validReferences: MutableSet = mutableSetOf(ROOT_PATH) + + /** The list of error messages accumulated during the validation process. */ + private val errors: MutableList = mutableListOf() + + /** + * Indicates if this validator has validated a schema or not. If a schema has been validated, + * this validator cannot be used again. + */ + private var isValidationComplete = false + + /** + * Gets the list of errors that were recorded during the validation pass. + * + * @return The list of errors. The list may be empty if no errors were recorded. In that case, + * the schema was found to be valid, or has not yet been validated by calling [validate]. + */ + fun errors(): List = errors.toImmutable() + + /** + * Indicates if a validated schema is valid or not. + * + * @return `true` if a schema has been validated by calling [validate] and no errors were + * reported; or `false` if errors were reported or if a schema has not yet been validated. + */ + fun isValid(): Boolean = isValidationComplete && errors.isEmpty() + + /** + * Validates a schema with respect to the OpenAI API specifications. + * + * @param rootSchema The root node of the tree representing the JSON schema definition. + * @return This schema validator for convenience, such as when chaining calls. + * @throws IllegalStateException If called a second time. Create a new validator to validate + * each new schema. + */ + fun validate(rootSchema: JsonNode): JsonSchemaValidator { + check(!isValidationComplete) { "Validation already complete." } + isValidationComplete = true + + validateSchema(rootSchema, ROOT_PATH, ROOT_DEPTH) + + // Verify total counts/lengths. These are not localized to a specific element in the schema, + // as no one element is the cause of the error; it is the combination of all elements that + // exceed the limits. Therefore, the root path is used in the error messages. + verify(totalEnumValues <= MAX_ENUM_VALUES, ROOT_PATH) { + "Total number of enum values ($totalEnumValues) exceeds limit of $MAX_ENUM_VALUES." + } + verify(totalStringLength <= MAX_TOTAL_STRING_LENGTH, ROOT_PATH) { + "Total string length of all values ($totalStringLength) exceeds " + + "limit of $MAX_TOTAL_STRING_LENGTH." + } + verify(totalObjectProperties <= MAX_OBJECT_PROPERTIES, ROOT_PATH) { + "Total number of object properties ($totalObjectProperties) exceeds " + + "limit of $MAX_OBJECT_PROPERTIES." + } + + return this + } + + /** + * Validates a schema. This may be the root schema or a sub-schema. Some validations are + * specific to the root schema, which is identified by the [depth] being equal to zero. + * + * This method is recursive: it will validate the given schema and any sub-schemas that it + * contains at any depth. References to other schemas (either the root schema or definition + * sub-schemas) do not increase the depth of nesting, as those references are not followed + * recursively, only checked to be valid internal schema references. + * + * @param schema The schema to be validated. This may be the root schema or any sub-schema. + * @param path The path that identifies the location of this schema within the JSON schema. For + * example, for the root schema, this will be `"#"`; for a definition sub-schema of a `Person` + * object, this will be `"#/$defs/Person"`. + * @param depth The current depth of nesting. The OpenAI API specification places a maximum + * limit on the depth of nesting, which will result in an error if it is exceeded. The nesting + * depth increases with each recursion into a nested sub-schema. For the root schema, the + * nesting depth is zero; all other sub-schemas will have a nesting depth greater than zero. + */ + private fun validateSchema(schema: JsonNode, path: String, depth: Int) { + verify(depth <= MAX_NESTING_DEPTH, path) { + "Current nesting depth is $depth, but maximum is $MAX_NESTING_DEPTH." + } + + verify(schema.isObject, path, { "Schema or sub-schema is not an object." }) { + // If the schema is not an object, perform no further validations. + return + } + + verify(!schema.isEmpty, path) { "Schema or sub-schema is empty." } + + if (depth == ROOT_DEPTH) { + // Sanity check for the presence of the "$schema" field, as this makes it more likely + // that the schema with `depth == 0` is actually the root node of a JSON schema, not + // just a generic JSON document that is being validated in error. + verify(schema.get(SCHEMA) != null, path) { "Root schema missing '$SCHEMA' field." } + } + + // Before sub-schemas can be validated, the list of definitions must be recorded to ensure + // that "$ref" references can be checked for validity. Definitions are optional and only + // appear in the root schema. + validateDefinitions(schema.get(DEFS), "$path/$DEFS", depth) + + val anyOf = schema.get(ANY_OF) + val type = schema.get(TYPE) + val ref = schema.get(REF) + + verify( + (anyOf != null).xor(type != null).xor(ref != null), + path, + { "Expected exactly one of '$TYPE' or '$ANY_OF' or '$REF'." }, + ) { + // Validation cannot continue if none are set, or if more than one is set. + return + } + + validateAnyOfSchema(schema, path, depth) + validateTypeSchema(schema, path, depth) + validateRefSchema(schema, path, depth) + } + + /** + * Validates a schema if it has an `"anyOf"` field. OpenAI does not support the use of `"anyOf"` + * at the root of a JSON schema. The value is the field is expected to be an array of valid + * sub-schemas. If the schema has no `"anyOf"` field, no action is taken. + */ + private fun validateAnyOfSchema(schema: JsonNode, path: String, depth: Int) { + val anyOf = schema.get(ANY_OF) + + if (anyOf == null) return + + validateKeywords(schema, ALLOWED_KEYWORDS_ANY_OF_SUB_SCHEMA, path, depth) + + verify( + anyOf.isArray && !anyOf.isEmpty, + path, + { "'$ANY_OF' field is not a non-empty array." }, + ) { + return + } + + // Validates that the root schema does not contain an `anyOf` field. This is a restriction + // imposed by the OpenAI API specification. `anyOf` fields _can_ appear at other depths. + verify(depth != ROOT_DEPTH, path) { "Root schema contains '$ANY_OF' field." } + + // Each entry must be a valid sub-schema. + anyOf.forEachIndexed { index, subSchema -> + validateSchema(subSchema, "$path/$ANY_OF[$index]", depth + 1) + } + } + + /** + * Validates a schema if it has a `"$ref"` field. The reference is checked to ensure it + * corresponds to a valid definition, or is a reference to the root schema. Recursive references + * are allowed. If no `"$ref"` field is found in the schema, no action is taken. + */ + private fun validateRefSchema(schema: JsonNode, path: String, depth: Int) { + val ref = schema.get(REF) + + if (ref == null) return + + validateKeywords(schema, ALLOWED_KEYWORDS_REF_SUB_SCHEMA, path, depth) + + val refPath = "$path/$REF" + + verify(ref.isTextual, refPath, { "'$REF' field is not a text value." }) { + // No point checking the reference has a referent if it is definitely malformed. + return + } + verify(ref.asText() in validReferences, refPath) { + "Invalid or unsupported reference: '${ref.asText()}'." + } + } + + /** + * Validates a schema if it has a `"type"` field. This includes most sub-schemas, except those + * that have a `"$ref"` or `"anyOf"` field instead. The `"type"` field may be set to a text + * value that is the name of the type (e.g., `"object"`, `"array"`, `"number"`), or it may be + * set to an array that contains two text values: the name of the type and `"null"`. The OpenAI + * API specification explains that this is how a property can be both required (i.e., it must + * appear in the JSON document), but its value can be optional (i.e., it can be set explicitly + * to `"null"`). If the schema has no `"type"` field, no action is taken. + */ + private fun validateTypeSchema(schema: JsonNode, path: String, depth: Int) { + val type = schema.get(TYPE) + + if (type == null) return + + val typeName = + if (type.isTextual) { + // Type will be something like `"type" : "string"` + type.asText() + } else if (type.isArray) { + // Type will be something like `"type" : [ "string", "null" ]`. This corresponds to + // the use of "Optional" in Java/Kotlin. + getTypeNameFromTypeArray(type, "$path/$TYPE") + } else { + error(path) { "'$TYPE' field is not a type name or array of type names." } + return + } + + when (typeName) { + TYPE_ARRAY -> validateArraySchema(schema, path, depth) + TYPE_OBJECT -> validateObjectSchema(schema, path, depth) + + TYPE_BOOLEAN, + TYPE_INTEGER, + TYPE_NUMBER, + TYPE_STRING -> validateSimpleSchema(schema, typeName, path, depth) + + // The type name could not be determined from a type name array. An error will already + // have been logged by `getTypeNameFromTypeArray`, so no need to do anything more here. + null -> return + + else -> error("$path/$TYPE") { "Unsupported '$TYPE' value: '$typeName'." } + } + } + + /** + * Validates a schema whose `"type"` is `"object"`. It is the responsibility of the caller to + * ensure that [schema] contains that type definition. If no type, or a different type is + * present, the behavior is not defined. + */ + private fun validateObjectSchema(schema: JsonNode, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_OBJECT_SUB_SCHEMA, path, depth) + + // The schema must declare that additional properties are not allowed. For this check, it + // does not matter if there are no "properties" in the schema. + verify( + schema.get(ADDITIONAL_PROPS) != null && + schema.get(ADDITIONAL_PROPS).asBoolean() == false, + path, + ) { + "'$ADDITIONAL_PROPS' field is missing or is not set to 'false'." + } + + val properties = schema.get(PROPS) + + // The "properties" field may be missing (there may be no properties to declare), but if it + // is present, it must be a non-empty object, or validation cannot continue. + // TODO: Decide if a missing or empty "properties" field is OK or not. + verify( + properties == null || (properties.isObject && !properties.isEmpty), + path, + { "'$PROPS' field is not a non-empty object." }, + ) { + return + } + + if (properties != null) { // Must be an object. + // If a "properties" field is present, there must also be a "required" field. All + // properties must be named in the list of required properties. + validatePropertiesRequired( + properties.fieldNames().asSequence().toSet(), + schema.get(REQUIRED), + "$path/$REQUIRED", + ) + validateProperties(properties, "$path/$PROPS", depth) + } + } + + /** + * Validates a schema whose `"type"` is `"array"`. It is the responsibility of the caller to + * ensure that [schema] contains that type definition. If no type, or a different type is + * present, the behavior is not defined. + * + * An array schema must have an `"items"` field whose value is an object representing a valid + * sub-schema. + */ + private fun validateArraySchema(schema: JsonNode, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_ARRAY_SUB_SCHEMA, path, depth) + + val items = schema.get(ITEMS) + + verify( + items != null && items.isObject, + path, + { "'$ITEMS' field is missing or is not an object." }, + ) { + return + } + + validateSchema(items, "$path/$ITEMS", depth + 1) + } + + /** + * Validates a schema whose `"type"` is one of the supported simple type names other than + * `"object"` and `"array"`. It is the responsibility of the caller to ensure that [schema] + * contains the correct type definition. If no type, or a different type is present, the + * behavior is not defined. + * + * @param typeName The name of the specific type of the schema. Where the field value is + * optional and the type is defined as an array of a type name and a `"null"`, this is the + * value of the non-`"null"` type name. For example `"string"`, or `"number"`. + */ + private fun validateSimpleSchema(schema: JsonNode, typeName: String, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_SIMPLE_SUB_SCHEMA, path, depth) + + val enumField = schema.get(ENUM) + + // OpenAI API specification: "For a single enum property with string values, the total + // string length of all enum values cannot exceed 7,500 characters when there are more than + // 250 enum values." + val isString = typeName == TYPE_STRING + var numEnumValues = 0 + var stringLength = 0 + + enumField?.forEach { value -> + // OpenAI places limits on the total string length of all enum values across all enums + // without being specific about the type of those enums (unlike for enums with string + // values, which have their own restrictions noted above). The specification does not + // indicate how to count the string length for boolean or number values. Here it is + // assumed that their simple string representations should be counted. + val length = value.asText().length + + totalStringLength += length + totalEnumValues++ + + if (isString) { + numEnumValues++ + stringLength += length + } + } + + verify( + !isString || + numEnumValues <= UNRESTRICTED_ENUM_VALUES_LIMIT || + stringLength <= MAX_ENUM_TOTAL_STRING_LENGTH, + "$path/$ENUM", + ) { + "Total string length ($stringLength) of values of an enum with $numEnumValues " + + "values exceeds limit of $MAX_ENUM_TOTAL_STRING_LENGTH." + } + + schema.get(CONST)?.let { constValue -> totalStringLength += constValue.asText().length } + } + + /** + * Validates that the definitions (if present) contain fields that each define a valid schema. + * Records the names of any definitions to construct the set of possible valid references to + * those definitions. This set will be used to validate any references from within definition + * sub-schemas, or any other sub-schemas validated at a later time. + * + * @param defs The node containing the definitions. Definitions are optional, so this node may + * be `null`. Definitions may appear in the root schema, but will not appear in any + * sub-schemas. If no definitions are present, the list of valid references will not be + * changed and no errors will be recorded. + * @param path The path that identifies the location within the schema of the `"$defs"` node. + * @param depth The current depth of nesting. If definitions are present, this will be zero, as + * that is the depth of the root schema. + */ + private fun validateDefinitions(defs: JsonNode?, path: String, depth: Int) { + // Definitions are optional. If present, expect an object whose fields are named from the + // classes the definitions were extracted from. If not present, do not continue. + verify(defs == null || defs.isObject, path, { "'$DEFS' must be an object." }) { + return + } + + // First, record the valid references to definitions, as any definition sub-schema may + // contain a reference to any other definitions sub-schema (including itself) and those + // references need to be validated. + defs?.fieldNames()?.asSequence()?.forEach { defName -> + val reference = "$path/$defName" + + // Consider that there might be duplicate definition names if two different classes + // (from different packages) have the same simple name. That would be an error, but + // there is no need to stop the validations. + // TODO: How should duplicate names be handled? Will the generator use longer names? + verify(reference !in validReferences, path) { "Duplicate definition of '$defName'." } + validReferences += reference + } + + // Second, recursively validate the definition sub-schemas. + defs?.fieldNames()?.asSequence()?.forEach { defName -> + totalStringLength += defName.length + validateSchema(defs.get(defName), "$path/$DEFS/$defName", depth + 1) + } + } + + /** + * Validates that every property in a collection of property names appears in the array of + * property names in a `"required"` field. + * + * @param propertyNames The collection of property names to check in the array of required + * properties. This collection will not be empty. + * @param required The `"required"` field. This is expected to be a non-`null` array field with + * a set of property names. + * @param path The path identifying the location of the `"required"` field within the schema. + */ + private fun validatePropertiesRequired( + propertyNames: Collection, + required: JsonNode?, + path: String, + ) { + val requiredNames = required?.map { it.asText() }?.toSet() ?: emptySet() + + propertyNames.forEach { propertyName -> + verify(propertyName in requiredNames, path) { + "'$PROPS' field '$propertyName' is not listed as '$REQUIRED'." + } + } + } + + /** + * Validates that each named entry in the `"properties"` field of an object schema has a value + * that is a valid sub-schema. + * + * @param properties The `"properties"` field node of an object schema. + * @param path The path identifying the location of the `"properties"` field within the schema. + */ + private fun validateProperties(properties: JsonNode, path: String, depth: Int) { + val propertyNames = properties.fieldNames().asSequence().toList() + + propertyNames.forEach { propertyName -> + totalObjectProperties++ + totalStringLength += propertyName.length + validateSchema(properties.get(propertyName), "$path/$propertyName", depth + 1) + } + } + + /** + * Validates that the names of all fields in the given schema node are present in a collection + * of allowed keywords. + * + * @param depth The nesting depth of the [schema] node. If this depth is zero, an additional set + * of allowed keywords will be included automatically for keywords that are allowed to appear + * only at the root level of the schema (e.g., `"$schema"`, `"$defs"`). + */ + private fun validateKeywords( + schema: JsonNode, + allowedKeywords: Collection, + path: String, + depth: Int, + ) { + schema.fieldNames().forEach { keyword -> + verify( + keyword in allowedKeywords || + (depth == ROOT_DEPTH && keyword in ALLOWED_KEYWORDS_ROOT_SCHEMA_ONLY), + path, + ) { + "Use of '$keyword' is not supported here." + } + } + } + + /** + * Gets the name of a type from the given `"type"` field, where the field is an array that + * contains exactly two string values: a type name and a `"null"` (in any order). + * + * @param type The type node. This must be a field with an array value. If this is not an array + * field, the behavior is undefined. It is the responsibility of the caller to ensure that + * this function is only called for array fields. + * @return The type name in the array that is not the `"null"` type; or `null` if no such type + * name was found, or if the array does not contain exactly two expected values: the type name + * and a `"null"` type. If `null`, one or more validation errors will be recorded. + */ + private fun getTypeNameFromTypeArray(type: JsonNode, path: String): String? { + val types = type.asSequence().toList() + + if (types.size == 2 && types.all { it.isTextual }) { + // Allow one type name and one "null". Be lenient about the order. Assume that there are + // no oddities like type names that are empty strings, etc., as the schemas are expected + // to be generated. + if (types[1].asText() == TYPE_NULL && types[0].asText() != TYPE_NULL) { + return types[0].asText() + } else if (types[0].asText() == TYPE_NULL && types[1].asText() != TYPE_NULL) { + return types[1].asText() + } else { + error(path) { "Expected one type name and one \"$TYPE_NULL\"." } + } + } else { + error(path) { "Expected exactly two types, both strings." } + } + + return null + } + + private inline fun verify(value: Boolean, path: String, lazyMessage: () -> Any) { + verify(value, path, lazyMessage) {} + } + + private inline fun verify( + value: Boolean, + path: String, + lazyMessage: () -> Any, + onFalse: () -> Unit, + ) { + if (!value) { + error(path, lazyMessage) + onFalse() + } + } + + private inline fun error(path: String, lazyMessage: () -> Any) { + errors.add("$path: ${lazyMessage()}") + } + + override fun toString(): String = + "${javaClass.simpleName}{isValidationComplete=$isValidationComplete, " + + "totalStringLength=$totalStringLength, " + + "totalObjectProperties=$totalObjectProperties, " + + "totalEnumValues=$totalEnumValues, errors=$errors}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt new file mode 100644 index 00000000..df48c1af --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -0,0 +1,132 @@ +package com.openai.core + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.json.JsonMapper +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule +import com.fasterxml.jackson.module.kotlin.kotlinModule +import com.github.victools.jsonschema.generator.Option +import com.github.victools.jsonschema.generator.OptionPreset +import com.github.victools.jsonschema.generator.SchemaGenerator +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder +import com.github.victools.jsonschema.module.jackson.JacksonModule +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.ResponseFormatJsonSchema +import com.openai.models.responses.ResponseFormatTextJsonSchemaConfig +import com.openai.models.responses.ResponseTextConfig + +// The SDK `ObjectMappers.jsonMapper()` requires that all fields of classes be marked with +// `@JsonProperty`, which is not desirable in this context, as it impedes usability. Therefore, a +// custom JSON mapper configuration is required. +private val MAPPER = + JsonMapper.builder() + .addModule(kotlinModule()) + .addModule(Jdk8Module()) + .addModule(JavaTimeModule()) + .build() + +/** + * Builds a response format using a JSON schema derived from the structure of an arbitrary Java + * class. + */ +@JvmSynthetic +internal fun responseFormatFromClass( + type: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, +): ResponseFormatJsonSchema = + ResponseFormatJsonSchema.builder() + .jsonSchema( + ResponseFormatJsonSchema.JsonSchema.builder() + .name("json-schema-from-${type.simpleName}") + .schema(JsonValue.fromJsonNode(extractAndValidateSchema(type, localValidation))) + // Ensure the model's output strictly adheres to this JSON schema. This is the + // essential "ON switch" for Structured Outputs. + .strict(true) + .build() + ) + .build() + +private fun extractAndValidateSchema( + type: Class, + localValidation: JsonSchemaLocalValidation, +): JsonNode { + val schema = extractSchema(type) + + if (localValidation == JsonSchemaLocalValidation.YES) { + val validator = JsonSchemaValidator.create().validate(schema) + + require(validator.isValid()) { + "Local validation failed for JSON schema derived from $type:\n" + + validator.errors().joinToString("\n") { " - $it" } + } + } + + return schema +} + +/** + * Builds a text configuration with its format set to a JSON schema derived from the structure of an + * arbitrary Java class. + */ +@JvmSynthetic +internal fun textConfigFromClass( + type: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, +): ResponseTextConfig = + ResponseTextConfig.builder() + .format( + ResponseFormatTextJsonSchemaConfig.builder() + .name("json-schema-from-${type.simpleName}") + .schema(JsonValue.fromJsonNode(extractAndValidateSchema(type, localValidation))) + // Ensure the model's output strictly adheres to this JSON schema. This is the + // essential "ON switch" for Structured Outputs. + .strict(true) + .build() + ) + .build() + +/** + * Derives a JSON schema from the structure of an arbitrary Java class. + * + * Validation is not performed by this function, as it allows extraction of the schema and + * validation of the schema to be controlled more easily when unit testing, as no exceptions will be + * thrown and any recorded validation errors can be inspected at leisure by the tests. + */ +@JvmSynthetic +internal fun extractSchema(type: Class): JsonNode { + val configBuilder = + SchemaGeneratorConfigBuilder( + com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, + OptionPreset.PLAIN_JSON, + ) + // Add `"additionalProperties" : false` to all object schemas (see OpenAI). + .with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT) + // Use `JacksonModule` to support the use of Jackson annotations to set property and + // class names and descriptions and to mark fields with `@JsonIgnore`. + .with(JacksonModule()) + + configBuilder + .forFields() + // For OpenAI schemas, _all_ properties _must_ be required. Override the interpretation of + // the Jackson `required` parameter to the `@JsonProperty` annotation: it will always be + // assumed to be `true`, even if explicitly `false` and even if there is no `@JsonProperty` + // annotation present. + .withRequiredCheck { true } + + return SchemaGenerator(configBuilder.build()).generateSchema(type) +} + +/** + * Creates an instance of a Java class using data from a JSON. The JSON data should conform to the + * JSON schema previously extracted from the Java class. + */ +@JvmSynthetic +internal fun responseTypeFromJson(json: String, responseType: Class): T = + try { + MAPPER.readValue(json, responseType) + } catch (e: Exception) { + // The JSON document is included in the exception message to aid diagnosis of the problem. + // It is the responsibility of the SDK user to ensure that exceptions that may contain + // sensitive data are not exposed in logs. + throw OpenAIInvalidDataException("Error parsing JSON: $json", e) + } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index a3281dc6..85d2582f 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -19,6 +19,7 @@ import com.openai.core.Enum import com.openai.core.ExcludeMissing import com.openai.core.JsonField import com.openai.core.JsonMissing +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.Params import com.openai.core.allMaxBy @@ -1297,6 +1298,31 @@ private constructor( body.responseFormat(jsonObject) } + /** + * Sets the response format to a JSON schema derived from the structure of the given class. + * This changes the builder to a type-safe [StructuredChatCompletionCreateParams.Builder] + * that will build a [StructuredChatCompletionCreateParams] instance when `build()` is + * called. + * + * @param responseType A class from which a JSON schema will be derived to define the + * response format. + * @param localValidation [com.openai.core.JsonSchemaLocalValidation.YES] (the default) to + * validate the JSON schema locally when it is generated by this method to confirm that it + * adheres to the requirements and restrictions on JSON schemas imposed by the OpenAI + * specification; or [com.openai.core.JsonSchemaLocalValidation.NO] to skip local + * validation and rely only on remote validation. See the SDK documentation for more + * details. + * @throws IllegalArgumentException If local validation is enabled, but it fails because a + * valid JSON schema cannot be derived from the given class. + */ + @JvmOverloads + fun responseFormat( + responseType: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = + StructuredChatCompletionCreateParams.builder() + .wrap(responseType, this, localValidation) + /** * This feature is in Beta. If specified, our system will make a best effort to sample * deterministically, such that repeated requests with the same `seed` and parameters should diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt new file mode 100644 index 00000000..872135a4 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt @@ -0,0 +1,176 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.chat.completions.ChatCompletion.Choice.FinishReason +import com.openai.models.chat.completions.ChatCompletion.Choice.Logprobs +import com.openai.models.chat.completions.ChatCompletion.ServiceTier +import com.openai.models.completions.CompletionUsage +import java.util.Objects +import java.util.Optional + +/** + * A wrapper for [ChatCompletion] that provides type-safe access to the [choices] when using the + * _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary class. + * See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the response will be deserialized. + */ +class StructuredChatCompletion( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawChatCompletion") val rawChatCompletion: ChatCompletion, +) { + /** @see ChatCompletion.id */ + fun id(): String = rawChatCompletion.id() + + private val choices by lazy { + rawChatCompletion._choices().map { choices -> choices.map { Choice(responseType, it) } } + } + + /** @see ChatCompletion.choices */ + fun choices(): List> = choices.getRequired("choices") + + /** @see ChatCompletion.created */ + fun created(): Long = rawChatCompletion.created() + + /** @see ChatCompletion.model */ + fun model(): String = rawChatCompletion.model() + + /** @see ChatCompletion._object_ */ + fun _object_(): JsonValue = rawChatCompletion._object_() + + /** @see ChatCompletion.serviceTier */ + fun serviceTier(): Optional = rawChatCompletion.serviceTier() + + /** @see ChatCompletion.systemFingerprint */ + fun systemFingerprint(): Optional = rawChatCompletion.systemFingerprint() + + /** @see ChatCompletion.usage */ + fun usage(): Optional = rawChatCompletion.usage() + + /** @see ChatCompletion._id */ + fun _id(): JsonField = rawChatCompletion._id() + + /** @see ChatCompletion._choices */ + fun _choices(): JsonField>> = choices + + /** @see ChatCompletion._created */ + fun _created(): JsonField = rawChatCompletion._created() + + /** @see ChatCompletion._model */ + fun _model(): JsonField = rawChatCompletion._model() + + /** @see ChatCompletion._serviceTier */ + fun _serviceTier(): JsonField = rawChatCompletion._serviceTier() + + /** @see ChatCompletion._systemFingerprint */ + fun _systemFingerprint(): JsonField = rawChatCompletion._systemFingerprint() + + /** @see ChatCompletion._usage */ + fun _usage(): JsonField = rawChatCompletion._usage() + + /** @see ChatCompletion._additionalProperties */ + fun _additionalProperties(): Map = rawChatCompletion._additionalProperties() + + class Choice + internal constructor( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawChoice") val rawChoice: ChatCompletion.Choice, + ) { + /** @see ChatCompletion.Choice.finishReason */ + fun finishReason(): FinishReason = rawChoice.finishReason() + + /** @see ChatCompletion.Choice.index */ + fun index(): Long = rawChoice.index() + + /** @see ChatCompletion.Choice.logprobs */ + fun logprobs(): Optional = rawChoice.logprobs() + + /** @see ChatCompletion.Choice._finishReason */ + fun _finishReason(): JsonField = rawChoice._finishReason() + + private val message by lazy { + rawChoice._message().map { StructuredChatCompletionMessage(responseType, it) } + } + + /** @see ChatCompletion.Choice.message */ + fun message(): StructuredChatCompletionMessage = message.getRequired("message") + + /** @see ChatCompletion.Choice._index */ + fun _index(): JsonField = rawChoice._index() + + /** @see ChatCompletion.Choice._logprobs */ + fun _logprobs(): JsonField = rawChoice._logprobs() + + /** @see ChatCompletion.Choice._message */ + fun _message(): JsonField> = message + + /** @see ChatCompletion.Choice._additionalProperties */ + fun _additionalProperties(): Map = rawChoice._additionalProperties() + + /** @see ChatCompletion.Choice.validate */ + fun validate(): Choice = apply { + message().validate() + rawChoice.validate() + } + + /** @see ChatCompletion.Choice.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is Choice<*> && + responseType == other.responseType && + rawChoice == other.rawChoice + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawChoice) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawChoice=$rawChoice}" + } + + /** @see ChatCompletion.validate */ + fun validate() = apply { + choices().forEach { it.validate() } + rawChatCompletion.validate() + } + + /** @see ChatCompletion.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletion<*> && + responseType == other.responseType && + rawChatCompletion == other.rawChatCompletion + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawChatCompletion) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawChatCompletion=$rawChatCompletion}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt new file mode 100644 index 00000000..73741b58 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -0,0 +1,765 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonSchemaLocalValidation +import com.openai.core.JsonValue +import com.openai.core.checkRequired +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.core.responseFormatFromClass +import com.openai.models.ChatModel +import com.openai.models.ReasoningEffort +import java.util.Objects +import java.util.Optional + +/** + * A wrapper for [ChatCompletionCreateParams] that provides a type-safe [Builder] that can record + * the [responseType] used to derive a JSON schema from an arbitrary class when using the + * _Structured Outputs_ feature. When a JSON response is received, it is deserialized to am instance + * of that type. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class that will be used to derive the JSON schema in the request and to + * which the JSON response will be deserialized. + */ +class StructuredChatCompletionCreateParams +internal constructor( + @get:JvmName("responseType") val responseType: Class, + /** + * The raw, underlying chat completion create parameters wrapped by this structured instance of + * the parameters. + */ + @get:JvmName("rawParams") val rawParams: ChatCompletionCreateParams, +) { + + companion object { + @JvmStatic fun builder() = Builder() + } + + class Builder internal constructor() { + private var responseType: Class? = null + private var paramsBuilder = ChatCompletionCreateParams.builder() + + @JvmSynthetic + internal fun wrap( + responseType: Class, + paramsBuilder: ChatCompletionCreateParams.Builder, + localValidation: JsonSchemaLocalValidation, + ) = apply { + this.responseType = responseType + this.paramsBuilder = paramsBuilder + // Convert the class to a JSON schema and apply it to the delegate `Builder`. + responseFormat(responseType, localValidation) + } + + /** Injects a given `ChatCompletionCreateParams.Builder`. For use only when testing. */ + @JvmSynthetic + internal fun inject(paramsBuilder: ChatCompletionCreateParams.Builder) = apply { + this.paramsBuilder = paramsBuilder + } + + /** @see ChatCompletionCreateParams.Builder.body */ + fun body(body: ChatCompletionCreateParams.Body) = apply { paramsBuilder.body(body) } + + /** @see ChatCompletionCreateParams.Builder.messages */ + fun messages(messages: List) = apply { + paramsBuilder.messages(messages) + } + + /** @see ChatCompletionCreateParams.Builder.messages */ + fun messages(messages: JsonField>) = apply { + paramsBuilder.messages(messages) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(message: ChatCompletionMessageParam) = apply { + paramsBuilder.addMessage(message) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(developer: ChatCompletionDeveloperMessageParam) = apply { + paramsBuilder.addMessage(developer) + } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessage */ + fun addDeveloperMessage(content: ChatCompletionDeveloperMessageParam.Content) = apply { + paramsBuilder.addDeveloperMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessage */ + fun addDeveloperMessage(text: String) = apply { paramsBuilder.addDeveloperMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessageOfArrayOfContentParts */ + fun addDeveloperMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addDeveloperMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(system: ChatCompletionSystemMessageParam) = apply { + paramsBuilder.addMessage(system) + } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessage */ + fun addSystemMessage(content: ChatCompletionSystemMessageParam.Content) = apply { + paramsBuilder.addSystemMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessage */ + fun addSystemMessage(text: String) = apply { paramsBuilder.addSystemMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessageOfArrayOfContentParts */ + fun addSystemMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addSystemMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(user: ChatCompletionUserMessageParam) = apply { + paramsBuilder.addMessage(user) + } + + /** @see ChatCompletionCreateParams.Builder.addUserMessage */ + fun addUserMessage(content: ChatCompletionUserMessageParam.Content) = apply { + paramsBuilder.addUserMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addUserMessage */ + fun addUserMessage(text: String) = apply { paramsBuilder.addUserMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addUserMessageOfArrayOfContentParts */ + fun addUserMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addUserMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(assistant: ChatCompletionAssistantMessageParam) = apply { + paramsBuilder.addMessage(assistant) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(assistant: ChatCompletionMessage) = apply { + paramsBuilder.addMessage(assistant) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(tool: ChatCompletionToolMessageParam) = apply { + paramsBuilder.addMessage(tool) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + @Deprecated("deprecated") + fun addMessage(function: ChatCompletionFunctionMessageParam) = apply { + paramsBuilder.addMessage(function) + } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(model: ChatModel) = apply { paramsBuilder.model(model) } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(model: JsonField) = apply { paramsBuilder.model(model) } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(value: String) = apply { paramsBuilder.model(value) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: ChatCompletionAudioParam?) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: Optional) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: JsonField) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Double?) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Double) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Optional) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: JsonField) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCall: ChatCompletionCreateParams.FunctionCall) = apply { + paramsBuilder.functionCall(functionCall) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCall: JsonField) = apply { + paramsBuilder.functionCall(functionCall) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(mode: ChatCompletionCreateParams.FunctionCall.FunctionCallMode) = apply { + paramsBuilder.functionCall(mode) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCallOption: ChatCompletionFunctionCallOption) = apply { + paramsBuilder.functionCall(functionCallOption) + } + + /** @see ChatCompletionCreateParams.Builder.functions */ + @Deprecated("deprecated") + fun functions(functions: List) = apply { + paramsBuilder.functions(functions) + } + + /** @see ChatCompletionCreateParams.Builder.functions */ + @Deprecated("deprecated") + fun functions(functions: JsonField>) = apply { + paramsBuilder.functions(functions) + } + + /** @see ChatCompletionCreateParams.Builder.addFunction */ + @Deprecated("deprecated") + fun addFunction(function: ChatCompletionCreateParams.Function) = apply { + paramsBuilder.addFunction(function) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: ChatCompletionCreateParams.LogitBias?) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: Optional) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: JsonField) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Boolean?) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Boolean) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Optional) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: JsonField) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Long?) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Long) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Optional) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: JsonField) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Long?) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Long) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Optional) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: JsonField) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: ChatCompletionCreateParams.Metadata?) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: Optional) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: JsonField) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: List?) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: Optional>) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: JsonField>) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.addModality */ + fun addModality(modality: ChatCompletionCreateParams.Modality) = apply { + paramsBuilder.addModality(modality) + } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Long?) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Long) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Optional) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: JsonField) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Boolean) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ChatCompletionCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: JsonField) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: ChatCompletionPredictionContent?) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: Optional) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: JsonField) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Double?) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Double) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Optional) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: JsonField) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: ReasoningEffort?) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: Optional) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: JsonField) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** + * Sets the response format to a JSON schema derived from the structure of the given class. + * + * @see ChatCompletionCreateParams.Builder.responseFormat + */ + @JvmOverloads + fun responseFormat( + responseType: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { + this.responseType = responseType + paramsBuilder.responseFormat(responseFormatFromClass(responseType, localValidation)) + } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Long?) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Long) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Optional) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: JsonField) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: ChatCompletionCreateParams.ServiceTier?) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: Optional) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: JsonField) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: ChatCompletionCreateParams.Stop?) = apply { paramsBuilder.stop(stop) } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: Optional) = apply { + paramsBuilder.stop(stop) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: JsonField) = apply { + paramsBuilder.stop(stop) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(string: String) = apply { paramsBuilder.stop(string) } + + /** @see ChatCompletionCreateParams.Builder.stopOfStrings */ + fun stopOfStrings(strings: List) = apply { paramsBuilder.stopOfStrings(strings) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Boolean?) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Boolean) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Optional) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: JsonField) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: ChatCompletionStreamOptions?) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: Optional) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: JsonField) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Double?) = apply { paramsBuilder.temperature(temperature) } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Double) = apply { paramsBuilder.temperature(temperature) } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Optional) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: JsonField) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: ChatCompletionToolChoiceOption) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: JsonField) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(auto: ChatCompletionToolChoiceOption.Auto) = apply { + paramsBuilder.toolChoice(auto) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(namedToolChoice: ChatCompletionNamedToolChoice) = apply { + paramsBuilder.toolChoice(namedToolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.tools */ + fun tools(tools: List) = apply { paramsBuilder.tools(tools) } + + /** @see ChatCompletionCreateParams.Builder.tools */ + fun tools(tools: JsonField>) = apply { paramsBuilder.tools(tools) } + + /** @see ChatCompletionCreateParams.Builder.addTool */ + fun addTool(tool: ChatCompletionTool) = apply { paramsBuilder.addTool(tool) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Long?) = apply { paramsBuilder.topLogprobs(topLogprobs) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Long) = apply { paramsBuilder.topLogprobs(topLogprobs) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Optional) = apply { + paramsBuilder.topLogprobs(topLogprobs) + } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: JsonField) = apply { + paramsBuilder.topLogprobs(topLogprobs) + } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Double?) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Double) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Optional) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: JsonField) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.user */ + fun user(user: String) = apply { paramsBuilder.user(user) } + + /** @see ChatCompletionCreateParams.Builder.user */ + fun user(user: JsonField) = apply { paramsBuilder.user(user) } + + /** @see ChatCompletionCreateParams.Builder.webSearchOptions */ + fun webSearchOptions(webSearchOptions: ChatCompletionCreateParams.WebSearchOptions) = + apply { + paramsBuilder.webSearchOptions(webSearchOptions) + } + + /** @see ChatCompletionCreateParams.Builder.webSearchOptions */ + fun webSearchOptions( + webSearchOptions: JsonField + ) = apply { paramsBuilder.webSearchOptions(webSearchOptions) } + + /** @see ChatCompletionCreateParams.Builder.additionalBodyProperties */ + fun additionalBodyProperties(additionalBodyProperties: Map) = apply { + paramsBuilder.additionalBodyProperties(additionalBodyProperties) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalBodyProperty */ + fun putAdditionalBodyProperty(key: String, value: JsonValue) = apply { + paramsBuilder.putAdditionalBodyProperty(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalBodyProperties */ + fun putAllAdditionalBodyProperties(additionalBodyProperties: Map) = + apply { + paramsBuilder.putAllAdditionalBodyProperties(additionalBodyProperties) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalBodyProperty */ + fun removeAdditionalBodyProperty(key: String) = apply { + paramsBuilder.removeAdditionalBodyProperty(key) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalBodyProperties */ + fun removeAllAdditionalBodyProperties(keys: Set) = apply { + paramsBuilder.removeAllAdditionalBodyProperties(keys) + } + + /** @see ChatCompletionCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalHeader */ + fun putAdditionalHeader(name: String, value: String) = apply { + paramsBuilder.putAdditionalHeader(name, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalHeaders */ + fun putAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.putAdditionalHeaders(name, values) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, value: String) = apply { + paramsBuilder.replaceAdditionalHeaders(name, value) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalHeaders(name, values) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalHeaders */ + fun removeAdditionalHeaders(name: String) = apply { + paramsBuilder.removeAdditionalHeaders(name) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalHeaders */ + fun removeAllAdditionalHeaders(names: Set) = apply { + paramsBuilder.removeAllAdditionalHeaders(names) + } + + /** @see ChatCompletionCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: Map>) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalQueryParam */ + fun putAdditionalQueryParam(key: String, value: String) = apply { + paramsBuilder.putAdditionalQueryParam(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalQueryParams */ + fun putAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.putAdditionalQueryParams(key, values) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, value: String) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, values) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalQueryParams */ + fun removeAdditionalQueryParams(key: String) = apply { + paramsBuilder.removeAdditionalQueryParams(key) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalQueryParams */ + fun removeAllAdditionalQueryParams(keys: Set) = apply { + paramsBuilder.removeAllAdditionalQueryParams(keys) + } + + /** + * Returns an immutable instance of [StructuredChatCompletionCreateParams]. + * + * Further updates to this [Builder] will not mutate the returned instance. + * + * The following fields are required: + * ```java + * .messages() + * .model() + * .responseFormat() + * ``` + * + * @throws IllegalStateException If any required field is unset. + */ + fun build() = + StructuredChatCompletionCreateParams( + checkRequired("responseType", responseType), + paramsBuilder.build(), + ) + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletionCreateParams<*> && + responseType == other.responseType && + rawParams == other.rawParams + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawParams) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawParams=$rawParams}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt new file mode 100644 index 00000000..27e8a3c7 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt @@ -0,0 +1,95 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.responseTypeFromJson +import com.openai.models.chat.completions.ChatCompletionMessage.FunctionCall +import java.util.Objects +import java.util.Optional + +/** + * A wrapper for [ChatCompletionMessage] that provides type-safe access to the [content] when using + * the _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary + * class. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the content will be deserialized when + * [content] is called. + */ +class StructuredChatCompletionMessage +internal constructor( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawMessage") val rawMessage: ChatCompletionMessage, +) { + + private val content: JsonField by lazy { + rawMessage._content().map { responseTypeFromJson(it, responseType) } + } + + /** @see ChatCompletionMessage.content */ + fun content(): Optional = content.getOptional("content") + + /** @see ChatCompletionMessage.refusal */ + fun refusal(): Optional = rawMessage.refusal() + + /** @see ChatCompletionMessage._role */ + fun _role(): JsonValue = rawMessage._role() + + /** @see ChatCompletionMessage.annotations */ + fun annotations(): Optional> = rawMessage.annotations() + + /** @see ChatCompletionMessage.audio */ + fun audio(): Optional = rawMessage.audio() + + /** @see ChatCompletionMessage.functionCall */ + @Deprecated("deprecated") fun functionCall(): Optional = rawMessage.functionCall() + + /** @see ChatCompletionMessage.toolCalls */ + fun toolCalls(): Optional> = rawMessage.toolCalls() + + /** @see ChatCompletionMessage._content */ + fun _content(): JsonField = content + + /** @see ChatCompletionMessage._refusal */ + fun _refusal(): JsonField = rawMessage._refusal() + + /** @see ChatCompletionMessage._annotations */ + fun _annotations(): JsonField> = + rawMessage._annotations() + + /** @see ChatCompletionMessage._audio */ + fun _audio(): JsonField = rawMessage._audio() + + /** @see ChatCompletionMessage._functionCall */ + @Deprecated("deprecated") + fun _functionCall(): JsonField = rawMessage._functionCall() + + /** @see ChatCompletionMessage._toolCalls */ + fun _toolCalls(): JsonField> = rawMessage._toolCalls() + + /** @see ChatCompletionMessage._additionalProperties */ + fun _additionalProperties(): Map = rawMessage._additionalProperties() + + /** @see ChatCompletionMessage.validate */ + // `content()` is not included in the validation by the delegate method, so just call it. + fun validate(): ChatCompletionMessage = rawMessage.validate() + + /** @see ChatCompletionMessage.isValid */ + fun isValid(): Boolean = rawMessage.isValid() + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletionMessage<*> && + responseType == other.responseType && + rawMessage == other.rawMessage + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawMessage) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawMessage=$rawMessage}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt index 6b512ed9..345f6a93 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt @@ -19,6 +19,7 @@ import com.openai.core.Enum import com.openai.core.ExcludeMissing import com.openai.core.JsonField import com.openai.core.JsonMissing +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.Params import com.openai.core.allMaxBy @@ -784,6 +785,28 @@ private constructor( */ fun text(text: JsonField) = apply { body.text(text) } + /** + * Sets the text configuration's format to a JSON schema derived from the structure of the + * given class. This changes the builder to a type-safe + * [StructuredResponseCreateParams.Builder] that will build a + * [StructuredResponseCreateParams] instance when `build()` is called. + * + * @param responseType A class from which a JSON schema will be derived to define the text + * configuration's format. + * @param localValidation [JsonSchemaLocalValidation.YES] (the default) to validate the JSON + * schema locally when it is generated by this method to confirm that it adheres to the + * requirements and restrictions on JSON schemas imposed by the OpenAI specification; or + * [JsonSchemaLocalValidation.NO] to skip local validation and rely only on remote + * validation. See the SDK documentation for more details. + * @throws IllegalArgumentException If local validation is enabled, but it fails because a + * valid JSON schema cannot be derived from the given class. + */ + @JvmOverloads + fun text( + responseType: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = StructuredResponseCreateParams.builder().wrap(responseType, this, localValidation) + /** * How the model should select which tool (or tools) to use when generating a response. See * the `tools` parameter to see how to specify which tools the model can call. diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt new file mode 100644 index 00000000..f28865b0 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponse.kt @@ -0,0 +1,197 @@ +package com.openai.models.responses + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.Reasoning +import com.openai.models.ResponsesModel +import java.util.Objects +import java.util.Optional + +/** + * A wrapper for [Response] that provides type-safe access to the [output] when using the + * _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary class. + * See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the response will be deserialized. + */ +class StructuredResponse( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawResponse") val rawResponse: Response, +) { + /** @see Response.id */ + fun id(): String = rawResponse.id() + + /** @see Response.createdAt */ + fun createdAt(): Double = rawResponse.createdAt() + + /** @see Response.error */ + fun error(): Optional = rawResponse.error() + + /** @see Response.incompleteDetails */ + fun incompleteDetails(): Optional = rawResponse.incompleteDetails() + + /** @see Response.instructions */ + fun instructions(): Optional = rawResponse.instructions() + + /** @see Response.metadata */ + fun metadata(): Optional = rawResponse.metadata() + + /** @see Response.model */ + fun model(): ResponsesModel = rawResponse.model() + + /** @see Response._object_ */ + fun _object_(): JsonValue = rawResponse._object_() + + private val output by lazy { + rawResponse._output().map { outputs -> + outputs.map { StructuredResponseOutputItem(responseType, it) } + } + } + + /** @see Response.output */ + fun output(): List> = output.getRequired("output") + + /** @see Response.parallelToolCalls */ + fun parallelToolCalls(): Boolean = rawResponse.parallelToolCalls() + + /** @see Response.temperature */ + fun temperature(): Optional = rawResponse.temperature() + + /** @see Response.toolChoice */ + fun toolChoice(): Response.ToolChoice = rawResponse.toolChoice() + + /** @see Response.tools */ + fun tools(): List = rawResponse.tools() + + /** @see Response.topP */ + fun topP(): Optional = rawResponse.topP() + + /** @see Response.maxOutputTokens */ + fun maxOutputTokens(): Optional = rawResponse.maxOutputTokens() + + /** @see Response.previousResponseId */ + fun previousResponseId(): Optional = rawResponse.previousResponseId() + + /** @see Response.reasoning */ + fun reasoning(): Optional = rawResponse.reasoning() + + /** @see Response.serviceTier */ + fun serviceTier(): Optional = rawResponse.serviceTier() + + /** @see Response.status */ + fun status(): Optional = rawResponse.status() + + /** @see Response.text */ + fun text(): Optional = rawResponse.text() + + /** @see Response.truncation */ + fun truncation(): Optional = rawResponse.truncation() + + /** @see Response.usage */ + fun usage(): Optional = rawResponse.usage() + + /** @see Response.user */ + fun user(): Optional = rawResponse.user() + + /** @see Response._id */ + fun _id(): JsonField = rawResponse._id() + + /** @see Response._createdAt */ + fun _createdAt(): JsonField = rawResponse._createdAt() + + /** @see Response._error */ + fun _error(): JsonField = rawResponse._error() + + /** @see Response._incompleteDetails */ + fun _incompleteDetails(): JsonField = + rawResponse._incompleteDetails() + + /** @see Response._instructions */ + fun _instructions(): JsonField = rawResponse._instructions() + + /** @see Response._metadata */ + fun _metadata(): JsonField = rawResponse._metadata() + + /** @see Response._model */ + fun _model(): JsonField = rawResponse._model() + + /** @see Response._output */ + fun _output(): JsonField>> = output + + /** @see Response._parallelToolCalls */ + fun _parallelToolCalls(): JsonField = rawResponse._parallelToolCalls() + + /** @see Response._temperature */ + fun _temperature(): JsonField = rawResponse._temperature() + + /** @see Response._toolChoice */ + fun _toolChoice(): JsonField = rawResponse._toolChoice() + + /** @see Response._tools */ + fun _tools(): JsonField> = rawResponse._tools() + + /** @see Response._topP */ + fun _topP(): JsonField = rawResponse._topP() + + /** @see Response._maxOutputTokens */ + fun _maxOutputTokens(): JsonField = rawResponse._maxOutputTokens() + + /** @see Response._previousResponseId */ + fun _previousResponseId(): JsonField = rawResponse._previousResponseId() + + /** @see Response._reasoning */ + fun _reasoning(): JsonField = rawResponse._reasoning() + + /** @see Response._serviceTier */ + fun _serviceTier(): JsonField = rawResponse._serviceTier() + + /** @see Response._status */ + fun _status(): JsonField = rawResponse._status() + + /** @see Response._text */ + fun _text(): JsonField = rawResponse._text() + + /** @see Response._truncation */ + fun _truncation(): JsonField = rawResponse._truncation() + + /** @see Response._usage */ + fun _usage(): JsonField = rawResponse._usage() + + /** @see Response._user */ + fun _user(): JsonField = rawResponse._user() + + /** @see Response._additionalProperties */ + fun _additionalProperties(): Map = rawResponse._additionalProperties() + + /** @see Response.validate */ + fun validate(): StructuredResponse = apply { + output().forEach { it.validate() } + rawResponse.validate() + } + + /** @see Response.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (e: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + return other is StructuredResponse<*> && + responseType == other.responseType && + rawResponse == other.rawResponse + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawResponse) } + + override fun hashCode(): Int = hashCode + + override fun toString(): String = + "${javaClass.simpleName}{responseType=$responseType, rawResponse=$rawResponse}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt new file mode 100644 index 00000000..aae191ac --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt @@ -0,0 +1,526 @@ +package com.openai.models.responses + +import com.openai.core.JsonField +import com.openai.core.JsonSchemaLocalValidation +import com.openai.core.JsonValue +import com.openai.core.checkRequired +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.core.textConfigFromClass +import com.openai.models.ChatModel +import com.openai.models.Reasoning +import com.openai.models.ResponsesModel +import java.util.Objects +import java.util.Optional + +/** + * A wrapper for [ResponseCreateParams] that provides a type-safe [Builder] that can record the + * [responseType] used to derive a JSON schema from an arbitrary class when using the _Structured + * Outputs_ feature. When a JSON response is received, it is deserialized to am instance of that + * type. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class that will be used to derive the JSON schema in the request and to + * which the JSON response will be deserialized. + */ +class StructuredResponseCreateParams( + @get:JvmName("responseType") val responseType: Class, + /** + * The raw, underlying response create parameters wrapped by this structured instance of the + * parameters. + */ + @get:JvmName("rawParams") val rawParams: ResponseCreateParams, +) { + + companion object { + /** @see ResponseCreateParams.builder */ + @JvmStatic fun builder() = Builder() + } + + class Builder internal constructor() { + private var responseType: Class? = null + private var paramsBuilder = ResponseCreateParams.builder() + + @JvmSynthetic + internal fun wrap( + responseType: Class, + paramsBuilder: ResponseCreateParams.Builder, + localValidation: JsonSchemaLocalValidation, + ) = apply { + this.responseType = responseType + this.paramsBuilder = paramsBuilder + text(responseType, localValidation) + } + + /** Injects a given `ResponseCreateParams.Builder`. For use only when testing. */ + @JvmSynthetic + internal fun inject(paramsBuilder: ResponseCreateParams.Builder) = apply { + this.paramsBuilder = paramsBuilder + } + + // The `body(...)` function is deliberately not supported. + + /** @see ResponseCreateParams.Builder.input */ + fun input(input: ResponseCreateParams.Input) = apply { paramsBuilder.input(input) } + + /** @see ResponseCreateParams.Builder.input */ + fun input(input: JsonField) = apply { + paramsBuilder.input(input) + } + + /** @see ResponseCreateParams.Builder.input */ + fun input(text: String) = apply { paramsBuilder.input(text) } + + /** @see ResponseCreateParams.Builder.inputOfResponse */ + fun inputOfResponse(response: List) = apply { + paramsBuilder.inputOfResponse(response) + } + + /** @see ResponseCreateParams.Builder.model */ + fun model(model: ResponsesModel) = apply { paramsBuilder.model(model) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(model: JsonField) = apply { paramsBuilder.model(model) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(string: String) = apply { paramsBuilder.model(string) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(chat: ChatModel) = apply { paramsBuilder.model(chat) } + + /** @see ResponseCreateParams.Builder.model */ + fun model(only: ResponsesModel.ResponsesOnlyModel) = apply { paramsBuilder.model(only) } + + /** @see ResponseCreateParams.Builder.include */ + fun include(include: List?) = apply { paramsBuilder.include(include) } + + /** @see ResponseCreateParams.Builder.include */ + fun include(include: Optional>) = apply { + paramsBuilder.include(include) + } + + /** @see ResponseCreateParams.Builder.include */ + fun include(include: JsonField>) = apply { + paramsBuilder.include(include) + } + + /** @see ResponseCreateParams.Builder.addInclude */ + fun addInclude(include: ResponseIncludable) = apply { paramsBuilder.addInclude(include) } + + /** @see ResponseCreateParams.Builder.instructions */ + fun instructions(instructions: String?) = apply { paramsBuilder.instructions(instructions) } + + /** @see ResponseCreateParams.Builder.instructions */ + fun instructions(instructions: Optional) = apply { + paramsBuilder.instructions(instructions) + } + + /** @see ResponseCreateParams.Builder.instructions */ + fun instructions(instructions: JsonField) = apply { + paramsBuilder.instructions(instructions) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: Long?) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: Long) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: Optional) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.maxOutputTokens */ + fun maxOutputTokens(maxOutputTokens: JsonField) = apply { + paramsBuilder.maxOutputTokens(maxOutputTokens) + } + + /** @see ResponseCreateParams.Builder.metadata */ + fun metadata(metadata: ResponseCreateParams.Metadata?) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ResponseCreateParams.Builder.metadata */ + fun metadata(metadata: Optional) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ResponseCreateParams.Builder.metadata */ + fun metadata(metadata: JsonField) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Boolean?) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Boolean) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Optional) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: JsonField) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ResponseCreateParams.Builder.previousResponseId */ + fun previousResponseId(previousResponseId: String?) = apply { + paramsBuilder.previousResponseId(previousResponseId) + } + + /** @see ResponseCreateParams.Builder.previousResponseId */ + fun previousResponseId(previousResponseId: Optional) = apply { + paramsBuilder.previousResponseId(previousResponseId) + } + + /** @see ResponseCreateParams.Builder.previousResponseId */ + fun previousResponseId(previousResponseId: JsonField) = apply { + paramsBuilder.previousResponseId(previousResponseId) + } + + /** @see ResponseCreateParams.Builder.reasoning */ + fun reasoning(reasoning: Reasoning?) = apply { paramsBuilder.reasoning(reasoning) } + + /** @see ResponseCreateParams.Builder.reasoning */ + fun reasoning(reasoning: Optional) = apply { paramsBuilder.reasoning(reasoning) } + + /** @see ResponseCreateParams.Builder.reasoning */ + fun reasoning(reasoning: JsonField) = apply { + paramsBuilder.reasoning(reasoning) + } + + /** @see ResponseCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: ResponseCreateParams.ServiceTier?) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ResponseCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: Optional) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ResponseCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: JsonField) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: Boolean?) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: Boolean) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: Optional) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.store */ + fun store(store: JsonField) = apply { paramsBuilder.store(store) } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: Double?) = apply { paramsBuilder.temperature(temperature) } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: Double) = apply { paramsBuilder.temperature(temperature) } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: Optional) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ResponseCreateParams.Builder.temperature */ + fun temperature(temperature: JsonField) = apply { + paramsBuilder.temperature(temperature) + } + + /** + * Sets the text configuration's format to a JSON schema derived from the structure of the + * given class. + * + * @see ResponseCreateParams.Builder.text + */ + @JvmOverloads + fun text( + responseType: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { + this.responseType = responseType + paramsBuilder.text(textConfigFromClass(responseType, localValidation)) + } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: ResponseCreateParams.ToolChoice) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: JsonField) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(options: ToolChoiceOptions) = apply { paramsBuilder.toolChoice(options) } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(types: ToolChoiceTypes) = apply { paramsBuilder.toolChoice(types) } + + /** @see ResponseCreateParams.Builder.toolChoice */ + fun toolChoice(function: ToolChoiceFunction) = apply { paramsBuilder.toolChoice(function) } + + /** @see ResponseCreateParams.Builder.tools */ + fun tools(tools: List) = apply { paramsBuilder.tools(tools) } + + /** @see ResponseCreateParams.Builder.tools */ + fun tools(tools: JsonField>) = apply { paramsBuilder.tools(tools) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(tool: Tool) = apply { paramsBuilder.addTool(tool) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(fileSearch: FileSearchTool) = apply { paramsBuilder.addTool(fileSearch) } + + /** @see ResponseCreateParams.Builder.addFileSearchTool */ + fun addFileSearchTool(vectorStoreIds: List) = apply { + paramsBuilder.addFileSearchTool(vectorStoreIds) + } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(function: FunctionTool) = apply { paramsBuilder.addTool(function) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(webSearch: WebSearchTool) = apply { paramsBuilder.addTool(webSearch) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(computerUsePreview: ComputerTool) = apply { + paramsBuilder.addTool(computerUsePreview) + } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: Double?) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: Double) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: Optional) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.topP */ + fun topP(topP: JsonField) = apply { paramsBuilder.topP(topP) } + + /** @see ResponseCreateParams.Builder.truncation */ + fun truncation(truncation: ResponseCreateParams.Truncation?) = apply { + paramsBuilder.truncation(truncation) + } + + /** @see ResponseCreateParams.Builder.truncation */ + fun truncation(truncation: Optional) = apply { + paramsBuilder.truncation(truncation) + } + + /** @see ResponseCreateParams.Builder.truncation */ + fun truncation(truncation: JsonField) = apply { + paramsBuilder.truncation(truncation) + } + + /** @see ResponseCreateParams.Builder.user */ + fun user(user: String) = apply { paramsBuilder.user(user) } + + /** @see ResponseCreateParams.Builder.user */ + fun user(user: JsonField) = apply { paramsBuilder.user(user) } + + /** @see ResponseCreateParams.Builder.additionalBodyProperties */ + fun additionalBodyProperties(additionalBodyProperties: Map) = apply { + paramsBuilder.additionalBodyProperties(additionalBodyProperties) + } + + /** @see ResponseCreateParams.Builder.putAdditionalBodyProperty */ + fun putAdditionalBodyProperty(key: String, value: JsonValue) = apply { + paramsBuilder.putAdditionalBodyProperty(key, value) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalBodyProperties */ + fun putAllAdditionalBodyProperties(additionalBodyProperties: Map) = + apply { + paramsBuilder.putAllAdditionalBodyProperties(additionalBodyProperties) + } + + /** @see ResponseCreateParams.Builder.removeAdditionalBodyProperty */ + fun removeAdditionalBodyProperty(key: String) = apply { + paramsBuilder.removeAdditionalBodyProperty(key) + } + + /** @see ResponseCreateParams.Builder.removeAllAdditionalBodyProperties */ + fun removeAllAdditionalBodyProperties(keys: Set) = apply { + paramsBuilder.removeAllAdditionalBodyProperties(keys) + } + + /** @see ResponseCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.putAdditionalHeader */ + fun putAdditionalHeader(name: String, value: String) = apply { + paramsBuilder.putAdditionalHeader(name, value) + } + + /** @see ResponseCreateParams.Builder.putAdditionalHeaders */ + fun putAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.putAdditionalHeaders(name, values) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, value: String) = apply { + paramsBuilder.replaceAdditionalHeaders(name, value) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalHeaders(name, values) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ResponseCreateParams.Builder.removeAdditionalHeaders */ + fun removeAdditionalHeaders(name: String) = apply { + paramsBuilder.removeAdditionalHeaders(name) + } + + /** @see ResponseCreateParams.Builder.removeAllAdditionalHeaders */ + fun removeAllAdditionalHeaders(names: Set) = apply { + paramsBuilder.removeAllAdditionalHeaders(names) + } + + /** @see ResponseCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: Map>) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.putAdditionalQueryParam */ + fun putAdditionalQueryParam(key: String, value: String) = apply { + paramsBuilder.putAdditionalQueryParam(key, value) + } + + /** @see ResponseCreateParams.Builder.putAdditionalQueryParams */ + fun putAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.putAdditionalQueryParams(key, values) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, value: String) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, value) + } + + /** @see ResponseCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, values) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ResponseCreateParams.Builder.removeAdditionalQueryParams */ + fun removeAdditionalQueryParams(key: String) = apply { + paramsBuilder.removeAdditionalQueryParams(key) + } + + /** @see ResponseCreateParams.Builder.removeAllAdditionalQueryParams */ + fun removeAllAdditionalQueryParams(keys: Set) = apply { + paramsBuilder.removeAllAdditionalQueryParams(keys) + } + + /** + * Returns an immutable instance of [ResponseCreateParams]. + * + * Further updates to this [Builder] will not mutate the returned instance. + * + * The following fields are required: + * ```java + * .input() + * .model() + * .text() + * ``` + * + * @throws IllegalStateException if any required field is unset. + */ + fun build() = + StructuredResponseCreateParams( + checkRequired("responseType", responseType), + paramsBuilder.build(), + ) + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredResponseCreateParams<*> && + responseType == other.responseType && + rawParams == other.rawParams + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawParams) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawParams=$rawParams}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputItem.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputItem.kt new file mode 100644 index 00000000..39cd6b6b --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputItem.kt @@ -0,0 +1,189 @@ +package com.openai.models.responses + +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import java.util.Objects +import java.util.Optional +import kotlin.jvm.optionals.getOrElse +import kotlin.jvm.optionals.getOrNull + +/** + * A wrapper for [ResponseOutputItem] that provides type-safe access to the [message] when using the + * _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary class. + * See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the content will be deserialized when + * [message] is called. + */ +class StructuredResponseOutputItem( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawOutputItem") val rawOutputItem: ResponseOutputItem, +) { + private val message by lazy { + rawOutputItem.message().map { StructuredResponseOutputMessage(responseType, it) } + } + + /** @see ResponseOutputItem.message */ + fun message(): Optional> = message + + /** @see ResponseOutputItem.fileSearchCall */ + fun fileSearchCall(): Optional = rawOutputItem.fileSearchCall() + + /** @see ResponseOutputItem.functionCall */ + fun functionCall(): Optional = rawOutputItem.functionCall() + + /** @see ResponseOutputItem.webSearchCall */ + fun webSearchCall(): Optional = rawOutputItem.webSearchCall() + + /** @see ResponseOutputItem.computerCall */ + fun computerCall(): Optional = rawOutputItem.computerCall() + + /** @see ResponseOutputItem.reasoning */ + fun reasoning(): Optional = rawOutputItem.reasoning() + + /** @see ResponseOutputItem.isMessage */ + fun isMessage(): Boolean = message().isPresent + + /** @see ResponseOutputItem.isFileSearchCall */ + fun isFileSearchCall(): Boolean = rawOutputItem.isFileSearchCall() + + /** @see ResponseOutputItem.isFunctionCall */ + fun isFunctionCall(): Boolean = rawOutputItem.isFunctionCall() + + /** @see ResponseOutputItem.isWebSearchCall */ + fun isWebSearchCall(): Boolean = rawOutputItem.isWebSearchCall() + + /** @see ResponseOutputItem.isComputerCall */ + fun isComputerCall(): Boolean = rawOutputItem.isComputerCall() + + /** @see ResponseOutputItem.isReasoning */ + fun isReasoning(): Boolean = rawOutputItem.isReasoning() + + /** @see ResponseOutputItem.asMessage */ + fun asMessage(): StructuredResponseOutputMessage = + message.getOrElse { + // Same behavior as `com.openai.core.getOrThrow` used by the delegate class. + throw OpenAIInvalidDataException("`message` is not present") + } + + /** @see ResponseOutputItem.asFileSearchCall */ + fun asFileSearchCall(): ResponseFileSearchToolCall = rawOutputItem.asFileSearchCall() + + /** @see ResponseOutputItem.asFunctionCall */ + fun asFunctionCall(): ResponseFunctionToolCall = rawOutputItem.asFunctionCall() + + /** @see ResponseOutputItem.asWebSearchCall */ + fun asWebSearchCall(): ResponseFunctionWebSearch = rawOutputItem.asWebSearchCall() + + /** @see ResponseOutputItem.asComputerCall */ + fun asComputerCall(): ResponseComputerToolCall = rawOutputItem.asComputerCall() + + /** @see ResponseOutputItem.asReasoning */ + fun asReasoning(): ResponseReasoningItem = rawOutputItem.asReasoning() + + /** @see ResponseOutputItem._json */ + fun _json(): Optional = rawOutputItem._json() + + /** @see ResponseOutputItem.accept */ + fun accept(visitor: Visitor): R = + when { + isMessage() -> visitor.visitMessage(asMessage()) + isFileSearchCall() -> visitor.visitFileSearchCall(asFileSearchCall()) + isFunctionCall() -> visitor.visitFunctionCall(asFunctionCall()) + isWebSearchCall() -> visitor.visitWebSearchCall(asWebSearchCall()) + isComputerCall() -> visitor.visitComputerCall(asComputerCall()) + isReasoning() -> visitor.visitReasoning(asReasoning()) + else -> visitor.unknown(_json().getOrNull()) + } + + private var validated: Boolean = false + + /** @see ResponseOutputItem.validate */ + fun validate(): StructuredResponseOutputItem = apply { + if (validated) { + return@apply + } + + accept( + object : Visitor { + override fun visitMessage(message: StructuredResponseOutputMessage) { + message.validate() + } + + override fun visitFileSearchCall(fileSearchCall: ResponseFileSearchToolCall) { + fileSearchCall.validate() + } + + override fun visitFunctionCall(functionCall: ResponseFunctionToolCall) { + functionCall.validate() + } + + override fun visitWebSearchCall(webSearchCall: ResponseFunctionWebSearch) { + webSearchCall.validate() + } + + override fun visitComputerCall(computerCall: ResponseComputerToolCall) { + computerCall.validate() + } + + override fun visitReasoning(reasoning: ResponseReasoningItem) { + reasoning.validate() + } + } + ) + validated = true + } + + /** @see ResponseOutputItem.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredResponseOutputItem<*> && + responseType == other.responseType && + rawOutputItem == other.rawOutputItem + } + + override fun hashCode(): Int = Objects.hash(responseType, rawOutputItem) + + override fun toString(): String = + "${javaClass.simpleName}{responseType=$responseType, rawOutputItem=$rawOutputItem}" + + /** @see ResponseOutputItem.Visitor */ + // In keeping with the delegate's `Visitor`, `T` is used to refer to the return type of each + // function. `R` (for "Response") is used to refer to the response type, which is otherwise + // named `T` in the outer class, but confusion here is probably preferable to confusion there. + interface Visitor { + /** @see ResponseOutputItem.Visitor.visitMessage */ + fun visitMessage(message: StructuredResponseOutputMessage): T + + /** @see ResponseOutputItem.Visitor.visitFileSearchCall */ + fun visitFileSearchCall(fileSearchCall: ResponseFileSearchToolCall): T + + /** @see ResponseOutputItem.Visitor.visitFunctionCall */ + fun visitFunctionCall(functionCall: ResponseFunctionToolCall): T + + /** @see ResponseOutputItem.Visitor.visitWebSearchCall */ + fun visitWebSearchCall(webSearchCall: ResponseFunctionWebSearch): T + + /** @see ResponseOutputItem.Visitor.visitComputerCall */ + fun visitComputerCall(computerCall: ResponseComputerToolCall): T + + /** @see ResponseOutputItem.Visitor.visitReasoning */ + fun visitReasoning(reasoning: ResponseReasoningItem): T + + /** @see ResponseOutputItem.Visitor.unknown */ + fun unknown(json: JsonValue?): T { + throw OpenAIInvalidDataException("Unknown ResponseOutputItem: $json") + } + } +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt new file mode 100644 index 00000000..b7083a25 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseOutputMessage.kt @@ -0,0 +1,181 @@ +package com.openai.models.responses + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.responseTypeFromJson +import com.openai.errors.OpenAIInvalidDataException +import java.util.Objects +import java.util.Optional +import kotlin.jvm.optionals.getOrElse +import kotlin.jvm.optionals.getOrNull + +/** + * A wrapper for [ResponseOutputMessage] that provides type-safe access to the [content] when using + * the _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary + * class. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the content will be deserialized when + * the output text of the [content] is retrieved. + */ +class StructuredResponseOutputMessage( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawMessage") val rawMessage: ResponseOutputMessage, +) { + /** @see ResponseOutputMessage.id */ + fun id(): String = rawMessage.id() + + private val content by lazy { + rawMessage._content().map { contents -> contents.map { Content(responseType, it) } } + } + + /** @see ResponseOutputMessage.content */ + fun content(): List> = content.getRequired("content") + + /** @see ResponseOutputMessage._role */ + fun _role(): JsonValue = rawMessage._role() + + /** @see ResponseOutputMessage.status */ + fun status(): ResponseOutputMessage.Status = rawMessage.status() + + /** @see ResponseOutputMessage._type */ + fun _type(): JsonValue = rawMessage._type() + + /** @see ResponseOutputMessage._id */ + fun _id(): JsonField = rawMessage._id() + + /** @see ResponseOutputMessage._content */ + fun _content(): JsonField>> = content + + /** @see ResponseOutputMessage._status */ + fun _status(): JsonField = rawMessage._status() + + /** @see ResponseOutputMessage._additionalProperties */ + fun _additionalProperties(): Map = rawMessage._additionalProperties() + + /** @see ResponseOutputMessage.validate */ + fun validate(): StructuredResponseOutputMessage = apply { + // `content()` is a different type to that in the delegate class. + content().forEach { it.validate() } + rawMessage.validate() + } + + /** @see ResponseOutputMessage.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + /** @see ResponseOutputMessage.Content */ + class Content( + @get:JvmName("responseType") val responseType: Class, + @get:JvmName("rawContent") val rawContent: ResponseOutputMessage.Content, + ) { + private val outputText by lazy { + rawContent.outputText().map { responseTypeFromJson(it.text(), responseType) } + } + + /** + * Gets the output text, but deserialized to an instance of the response type class. + * + * @see ResponseOutputMessage.Content.outputText + */ + fun outputText(): Optional = outputText + + /** @see ResponseOutputMessage.Content.refusal */ + fun refusal(): Optional = rawContent.refusal() + + /** @see ResponseOutputMessage.Content.isOutputText */ + // No need to check `outputText`; the delegate can just check the source value is present. + fun isOutputText(): Boolean = rawContent.isOutputText() + + /** @see ResponseOutputMessage.Content.isRefusal */ + fun isRefusal(): Boolean = rawContent.isRefusal() + + /** @see ResponseOutputMessage.Content.asOutputText */ + fun asOutputText(): T = + outputText.getOrElse { + // Same behavior as `com.openai.core.getOrThrow` used by the delegate class. + throw OpenAIInvalidDataException("`outputText` is not present") + } + + /** @see ResponseOutputMessage.Content.asRefusal */ + fun asRefusal(): ResponseOutputRefusal = rawContent.asRefusal() + + /** @see ResponseOutputMessage.Content._json */ + fun _json(): Optional = rawContent._json() + + /** @see ResponseOutputMessage.Content.accept */ + fun accept(visitor: Visitor): R = + when { + outputText.isPresent -> visitor.visitOutputText(outputText.get()) + refusal().isPresent -> visitor.visitRefusal(refusal().get()) + else -> visitor.unknown(_json().getOrNull()) + } + + /** @see ResponseOutputMessage.Content.validate */ + fun validate(): Content = apply { + // The `outputText` object, as it is a user-defined type that is unlikely to have a + // `validate()` function/method, so validate the underlying `ResponseOutputText` from + // which it is derived. That can be done by the delegate class. + rawContent.validate() + } + + /** @see ResponseOutputMessage.Content.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is Content<*> && + rawContent == other.rawContent && + responseType == other.responseType + } + + override fun hashCode(): Int = Objects.hash(rawContent, responseType) + + override fun toString(): String = + "${javaClass.simpleName}{responseType=$responseType, rawContent=$rawContent}" + + /** @see ResponseOutputMessage.Content.Visitor */ + interface Visitor { + /** @see ResponseOutputMessage.Content.Visitor.visitOutputText */ + fun visitOutputText(outputText: T): R + + /** @see ResponseOutputMessage.Content.Visitor.visitRefusal */ + fun visitRefusal(refusal: ResponseOutputRefusal): R + + /** @see ResponseOutputMessage.Content.Visitor.unknown */ + fun unknown(json: JsonValue?): R { + throw OpenAIInvalidDataException("Unknown Content: $json") + } + } + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredResponseOutputMessage<*> && + responseType == other.responseType && + rawMessage == other.rawMessage + } + + private val hashCode: Int by lazy { Objects.hash(responseType, rawMessage) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseType=$responseType, rawMessage=$rawMessage}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt index b6e1f962..12ed80dd 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/ResponseService.kt @@ -12,6 +12,8 @@ import com.openai.models.responses.ResponseCreateParams import com.openai.models.responses.ResponseDeleteParams import com.openai.models.responses.ResponseRetrieveParams import com.openai.models.responses.ResponseStreamEvent +import com.openai.models.responses.StructuredResponse +import com.openai.models.responses.StructuredResponseCreateParams import com.openai.services.blocking.responses.InputItemService interface ResponseService { @@ -42,6 +44,27 @@ interface ResponseService { requestOptions: RequestOptions = RequestOptions.none(), ): Response + /** + * Creates a model response. The model's structured output in JSON form will be deserialized + * automatically into an instance of the class `T`. See the SDK documentation for more details. + * + * @see create + */ + fun create(params: StructuredResponseCreateParams): StructuredResponse = + create(params, RequestOptions.none()) + + /** + * Creates a model response. The model's structured output in JSON form will be deserialized + * automatically into an instance of the class `T`. See the SDK documentation for more details. + * + * @see create + */ + fun create( + params: StructuredResponseCreateParams, + requestOptions: RequestOptions = RequestOptions.none(), + ): StructuredResponse = + StructuredResponse(params.responseType, create(params.rawParams, requestOptions)) + /** * Creates a model response. Provide [text](https://platform.openai.com/docs/guides/text) or * [image](https://platform.openai.com/docs/guides/images) inputs to generate diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt index 8febcf12..43f372d2 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt @@ -15,6 +15,8 @@ import com.openai.models.chat.completions.ChatCompletionListPage import com.openai.models.chat.completions.ChatCompletionListParams import com.openai.models.chat.completions.ChatCompletionRetrieveParams import com.openai.models.chat.completions.ChatCompletionUpdateParams +import com.openai.models.chat.completions.StructuredChatCompletion +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams import com.openai.services.blocking.chat.completions.MessageService interface ChatCompletionService { @@ -53,6 +55,30 @@ interface ChatCompletionService { requestOptions: RequestOptions = RequestOptions.none(), ): ChatCompletion + /** + * Creates a model response for the given chat conversation. The model's structured output in + * JSON form will be deserialized automatically into an instance of the class `T`. See the SDK + * documentation for more details. + * + * @see create + */ + fun create( + params: StructuredChatCompletionCreateParams + ): StructuredChatCompletion = create(params, RequestOptions.none()) + + /** + * Creates a model response for the given chat conversation. The model's structured output in + * JSON form will be deserialized automatically into an instance of the class `T`. See the SDK + * documentation for more details. + * + * @see create + */ + fun create( + params: StructuredChatCompletionCreateParams, + requestOptions: RequestOptions = RequestOptions.none(), + ): StructuredChatCompletion = + StructuredChatCompletion(params.responseType, create(params.rawParams, requestOptions)) + /** * **Starting a new project?** We recommend trying * [Responses](https://platform.openai.com/docs/api-reference/responses) to take advantage of diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt new file mode 100644 index 00000000..dd5cdd57 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -0,0 +1,1520 @@ +package com.openai.core + +import com.fasterxml.jackson.annotation.JsonClassDescription +import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.annotation.JsonPropertyDescription +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode +import com.openai.errors.OpenAIInvalidDataException +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatNoException +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.AfterTestExecutionCallback +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.extension.RegisterExtension + +/** Tests for the `StructuredOutputs` functions and the [JsonSchemaValidator]. */ +internal class StructuredOutputsTest { + companion object { + private const val SCHEMA = "\$schema" + private const val SCHEMA_VER = "https://json-schema.org/draft/2020-12/schema" + private const val DEFS = "\$defs" + private const val REF = "\$ref" + + /** + * `true` to print the schema and validation errors for all executed tests, or `false` to + * print them only for failed tests. + */ + private const val VERBOSE_MODE = false + + private fun parseJson(schemaString: String) = ObjectMapper().readTree(schemaString) + } + + /** + * A validator that can be used by each unit test. A new validation instance is created for each + * test, as each test is run from its own instance of the test class. If a test fails, any + * validation errors are automatically printed to standard output to aid diagnosis. + */ + val validator = JsonSchemaValidator.create() + + /** + * The schema that was created by the unit test. This may be printed out after a test fails to + * aid in diagnosing the cause of the failure. In that case, this property must be set, or an + * error will occur. However, it will only be printed if the failed test method has the name + * prefix `schemaTest_`, so only test methods with that naming pattern need to set this field. + */ + lateinit var schema: JsonNode + + /** + * An extension to JUnit that prints the [schema] and the validation status (including any + * errors) when a test fails. This applies only to test methods whose names are prefixed with + * `schemaTest_`. An error will occur if [schema] was not set, but this can be avoided by only + * using the method name prefix for test methods that set [schema]. This reporting is intended + * as an aid to diagnosing test failures. + */ + @Suppress("unused") + @RegisterExtension + val printValidationErrorsOnFailure: AfterTestExecutionCallback = + object : AfterTestExecutionCallback { + @Throws(Exception::class) + override fun afterTestExecution(context: ExtensionContext) { + if ( + context.displayName.startsWith("schemaTest_") && + (VERBOSE_MODE || context.executionException.isPresent) + ) { + // Test failed. + println("Schema: ${schema.toPrettyString()}\n") + println("$validator\n") + } + } + } + + // NOTE: In most of these tests, it is assumed that the schema is generated as expected; it is + // not examined in fine detail if the validator succeeds or fails with the expected errors. + + @Test + fun schemaTest_minimalSchema() { + class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_minimalListSchema() { + val s: List = listOf() + + schema = extractSchema(s.javaClass) + validator.validate(schema) + + // Currently, the generated schema looks like this: + // + // { + // "$schema" : "https://json-schema.org/draft/2020-12/schema", + // "type" : "array", + // "items" : { } + // } + // + // That causes an error, as the `"items"` object is empty when it should be a valid + // sub-schema. Something like this is what would be valid: + // + // { + // "$schema" : "https://json-schema.org/draft/2020-12/schema", + // "type" : "array", + // "items" : { + // "type" : "string" + // } + // } + // + // The reason for the failure is that generic type information is erased for scopes like + // local variables, but generic type information for fields is retained as part of the class + // metadata. This is the expected behavior in Java, so this test expects an invalid schema. + assertThat(validator.isValid()).isFalse + assertThat(validator.errors()).hasSize(2) + assertThat(validator.errors()[0]).isEqualTo("#/items: Schema or sub-schema is empty.") + assertThat(validator.errors()[1]) + .isEqualTo("#/items: Expected exactly one of 'type' or 'anyOf' or '$REF'.") + } + + @Test + fun schemaTest_listFieldSchema() { + @Suppress("unused") class X(val s: List) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + // This gives a root schema with `"type" : "string"` and `"const" : "HELLO"` + // Unfortunately, an "enum class" cannot be defined within a function or within a class within + // a function. + @Suppress("unused") + enum class MinimalEnum1 { + HELLO + } + + @Test + fun schemaTest_minimalEnumSchema1() { + schema = extractSchema(MinimalEnum1::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + // This gives a root schema with `"type" : "string"` and `"enum" : [ "HELLO", "WORLD" ]` + @Suppress("unused") + enum class MinimalEnum2 { + HELLO, + WORLD, + } + + @Test + fun schemaTest_minimalEnumSchema2() { + schema = extractSchema(MinimalEnum2::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_nonStringEnum() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "integer", + "enum" : [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchema() { + @Suppress("unused") class X(val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalString() { + // Using an `Optional` will result in this JSON: `"type" : [ "string", "null" ]`. + // That is supported by the OpenAI Structured Outputs API spec, as long as the field is also + // marked as required. Though required, it is still allowed for the field to be explicitly + // set to `"null"`. + @Suppress("unused") class X(val s: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalBoolean() { + @Suppress("unused") class X(val b: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalInteger() { + @Suppress("unused") class X(val i: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalNumber() { + @Suppress("unused") class X(val n: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_arraySchemaFromOptional() { + @Suppress("unused") class X(val s: Optional>) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_arrayTypeMissingItems() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "array" + } + """ + ) + validator.validate(schema) + + // Check once here that "validator.isValid()" returns "false" when there is an error. In + // the other tests, there is no need to repeat this assertion, as it would be redundant. + assertThat(validator.isValid()).isFalse + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'items' field is missing or is not an object.") + } + + @Test + fun schemaTest_arrayTypeWithWrongItemsType() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "array", + "items" : [ "should_not_be_an_array" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'items' field is missing or is not an object.") + } + + @Test + @Suppress("unused") + fun schemaTest_objectSubSchemaFromOptional() { + class X(val s: Optional) + class Y(val x: Optional) + + schema = extractSchema(Y::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_badOptionalTypeNotArray() { + // Testing more for code coverage than for anything expected to go wrong in practice. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : { "type" : "string" } + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'type' field is not a type name or array of type names.") + } + + @Test + fun schemaTest_badOptionalTypeNoNull1() { + // Testing more for code coverage than for anything expected to go wrong in practice. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeNoNull2() { + // If "type" is an array, one of the two "type" values must be "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", "number" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected one type name and one \"null\".") + } + + @Test + fun schemaTest_badOptionalTypeNoNull3() { + // If "type" is an array, there must be two type values only, one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", "number", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeNoStringTypeNames() { + // If "type" is an array, there must be two type values only, one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", null ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeAllNull() { + // If "type" is an array, there must be two type values only, and only one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "null", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected one type name and one \"null\".") + } + + @Test + fun schemaTest_badOptionalTypeUnknown() { + // If "type" is an array, there must be two type values only, and only one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "unknown", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#/type: Unsupported 'type' value: 'unknown'.") + } + + @Test + fun schemaTest_goodOptionalTypeNullFirst() { + // The validator should be lenient about the order of the null/not-null types in the array. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "null", "string" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinyRecursiveSchema() { + @Suppress("unused") class X(val s: String, val x: X?) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_unsupportedKeywords() { + // OpenAI lists a set of keywords that are not allowed, but the set is not exhaustive. Check + // that everything named in that set is identified as not allowed, as that is the minimum + // level of validation expected. Check at the root schema and a sub-schema. There is no need + // to match the keywords to their expected schema types or be concerned about the values of + // the keyword fields, which makes testing easier. + val keywordsNotAllowed = + listOf( + "minLength", + "maxLength", + "pattern", + "format", + "minimum", + "maximum", + "multipleOf", + "patternProperties", + "unevaluatedProperties", + "propertyNames", + "minProperties", + "maxProperties", + "unevaluatedItems", + "contains", + "minContains", + "maxContains", + "minItems", + "maxItems", + "uniqueItems", + ) + val notAllowedUses = keywordsNotAllowed.joinToString(", ") { "\"$it\" : \"\"" } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "x" : { + "type" : "string", + $notAllowedUses + } + }, + $notAllowedUses, + "additionalProperties" : false, + "required" : [ "x" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(keywordsNotAllowed.size * 2) + keywordsNotAllowed.forEachIndexed { index, keyword -> + assertThat(validator.errors()[index]) + .isEqualTo("#: Use of '$keyword' is not supported here.") + assertThat(validator.errors()[index + keywordsNotAllowed.size]) + .isEqualTo("#/properties/x: Use of '$keyword' is not supported here.") + } + } + + @Test + fun schemaTest_propertyNotMarkedRequired() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_requiredArrayNull() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : null + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_requiredArrayMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_additionalPropertiesMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "required" : [ "name" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'additionalProperties' field is missing or is not set to 'false'.") + } + + @Test + fun schemaTest_additionalPropertiesTrue() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : true, + "required" : [ "name" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'additionalProperties' field is missing or is not set to 'false'.") + } + + @Test + fun schemaTest_objectPropertiesMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + // For now, allow that an object may have no properties. Update this if that is revised. + assertThat(validator.isValid()).isTrue() + } + + @Test + fun schemaTest_objectPropertiesNotObject() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : [ "name", "age" ], + "additionalProperties" : false, + "required" : [ "name", "age" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is not a non-empty object.") + } + + @Test + fun schemaTest_objectPropertiesEmpty() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { }, + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is not a non-empty object.") + } + + @Test + fun schemaTest_anyOfInRootSchema() { + // OpenAI does not allow `"anyOf"` to appear at the root level of a schema. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "anyOf" : [ { + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : ["name"] + }, { + "type" : "array", + "items" : { + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : ["name"] + } + } ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#: Root schema contains 'anyOf' field.") + } + + @Test + fun schemaTest_anyOfNotArray() { + // Unlikely that this can occur in a generated schema, so this is more about code coverage. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "name" : { + "anyOf" : { + "type" : "string" + } + } + }, + "additionalProperties" : false, + "required" : ["name"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/properties/name: 'anyOf' field is not a non-empty array.") + } + + @Test + fun schemaTest_anyOfIsEmptyArray() { + // Unlikely that this can occur in a generated schema, so this is more about code coverage. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "name" : { + "anyOf" : [ ] + } + }, + "additionalProperties" : false, + "required" : ["name"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/properties/name: 'anyOf' field is not a non-empty array.") + } + + @Test + fun schemaTest_anyOfInSubSchemaArray() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "value" : { + "anyOf" : [ + { "type" : "string" }, + { "type" : "number" } + ] + } + }, + "additionalProperties" : false, + "required" : ["value"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_noSchemaFieldRootSchema() { + @Suppress("unused") class X(val s: String) + + schema = extractSchema(X::class.java) + (schema as ObjectNode).remove(SCHEMA) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#: Root schema missing '$SCHEMA' field.") + } + + @Test + @Suppress("unused") + fun schemaTest_deepNestingAtLimit() { + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + + schema = extractSchema(Y::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + @Suppress("unused") + fun schemaTest_deepNestingBeyondLimit() { + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + schema = extractSchema(Z::class.java) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).contains("Current nesting depth is 6, but maximum is 5.") + } + + @Test + fun schemaTest_stringEnum250ValueOverSizeLimit() { + // OpenAI specification: "For a single enum property with string values, the total string + // length of all enum values cannot exceed 7,500 characters when there are more than 250 + // enum values." + + // This test creates an enum with exactly 250 string values with more than 7,500 characters + // in total (31 characters per value for a total of 7,750 characters). No error is expected. + val values = (1..250).joinToString(", ") { "\"%s%03d\"".format("x".repeat(28), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_stringEnum251ValueUnderSizeLimit() { + // This test creates an enum with exactly 251 string values with fewer than 7,500 characters + // in total (29 characters per value for a total of 7,279 characters). No error is expected. + val values = (1..251).joinToString(", ") { "\"%s%03d\"".format("x".repeat(26), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_stringEnum251ValueOverSizeLimit() { + // This test creates an enum with exactly 251 string values with fewer than 7,500 characters + // in total (30 characters per value for a total of 7,530 characters). An error is expected. + val values = (1..251).joinToString(", ") { "\"%s%03d\"".format("x".repeat(27), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo( + "#/enum: Total string length (7530) of values of an enum " + + "with 251 values exceeds limit of 7500." + ) + } + + @Test + fun schemaTest_totalEnumValuesAtLimit() { + // OpenAI specification: "A schema may have up to 500 enum values across all enum + // properties." + + // This test creates two enums with a total of 500 values. The total string length of the + // values is well within the limits (2,000 characters). + val valuesA = (1..250).joinToString(", ") { "\"a%03d\"".format(it) } + val valuesB = (1..250).joinToString(", ") { "\"b%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "string", + "enum" : [ $valuesA ] + }, + "b" : { + "type" : "string", + "enum" : [ $valuesB ] + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_totalEnumValuesOverLimit() { + // This test creates two enums with a total of 501 values. The total string length of the + // values is well within the limits (2,004 characters). + val valuesA = (1..250).joinToString(", ") { "\"a%03d\"".format(it) } + val valuesB = (1..251).joinToString(", ") { "\"b%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "string", + "enum" : [ $valuesA ] + }, + "b" : { + "type" : "string", + "enum" : [ $valuesB ] + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total number of enum values (501) exceeds limit of 500.") + } + + @Test + fun schemaTest_maxObjectPropertiesAtLimit() { + // This test creates two object schemas with a total of 100 object properties. OpenAI does + // not support more than 100 properties total in the whole schema. Two objects are used to + // ensure that counting is not done per object, but across all objects. Note that each + // object schema is itself a property, so there are two properties at the top level and 49 + // properties each at the next level. No error is expected, as the limit is not exceeded. + val propUses = + (1..49).joinToString(", ") { "\"x%02d\" : { \"type\" : \"string\" }".format(it) } + val propNames = (1..49).joinToString(", ") { "\"x%02d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + }, + "b" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_maxObjectPropertiesOverLimit() { + // This test creates two object schemas with a total of 101 object properties. OpenAI does + // not support more than 100 properties total in the whole schema. Expect an error. + val propUses = + (1..49).joinToString(", ") { "\"x_%02d\" : { \"type\" : \"string\" }".format(it) } + val propNames = (1..49).joinToString(", ") { "\"x_%02d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + }, + "b" : { + "type" : "object", + "properties" : { + $propUses, + "property_101" : { "type" : "string" } + }, + "required" : [ $propNames, "property_101" ], + "additionalProperties" : false + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total number of object properties (101) exceeds limit of 100.") + } + + @Test + fun schemaTest_maxStringLengthAtLimit() { + // OpenAI specification: "In a schema, total string length of all property names, definition + // names, enum values, and const values cannot exceed 15,000 characters." + // + // This test creates a schema with many property names, definition names, enum values, and + // const values calculated to have a total string length of 15,000 characters. No error is + // expected. + // + // The test creates a schema that looks like the following, with the numbers adjusted to + // achieve a total of 15,000 characters for the relevant elements. + // + // { + // "$schema" : "...", + // "$defs" : { + // "d_001" : { + // "type" : "string", + // "const" : "c_001" + // }, + // ..., + // "d_nnn" : { + // "type" : "string", + // "const" : "c_nnn" + // } + // }, + // "type" : "object", + // "properties" : { + // "p_001" : { + // "type" : "string", + // "enum" : [ "eeeee..._001", ..., "eeeee..._nnn" ] + // }, + // ..., + // "p_nnn" : { + // "type" : "string", + // "enum" : [ "eeeee..._001", ..., "eeeee..._nnn" ] + // } + // }, + // "required" : [ "p_001", ..., "p_nnn" ], + // "additionalProperties" : false + // } + + val numDefs = 65 // Each also has one "const" value. + val numProps = 70 // Each also has "numEnumValues" enum values. + val nameLen = 5 // Length of names of definitions, properties and const values. + val numEnumValues = 5 // numProps * numEnumValues <= 500 limit (OpenAI) + val enumValueLen = 40 // Length of enum values. + val expectedTotalStringLength = + nameLen * (numProps + numDefs * 2) + numProps * enumValueLen * numEnumValues + + val enumValues = + (1..numEnumValues).joinToString(", ") { "\"%s_%03d\"".format("e".repeat(36), it) } + val defs = + (1..numDefs).joinToString(", ") { + "\"d_%03d\" : { \"type\" : \"string\", \"const\" : \"c_%03d\" }".format(it, it) + } + val props = + (1..numProps).joinToString(", ") { + "\"p_%03d\" : { \"type\" : \"string\", \"enum\" : [ $enumValues ] }".format(it) + } + val propNames = (1..numProps).joinToString(", ") { "\"p_%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { $defs }, + "type" : "object", + "properties" : { $props }, + "required" : [ $propNames ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(expectedTotalStringLength).isEqualTo(15_000) // Exactly on the limit. + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_maxStringLengthOverLimit() { + // OpenAI specification: "In a schema, total string length of all property names, definition + // names, enum values, and const values cannot exceed 15,000 characters." + // + // This test creates a schema with many property names, definition names, enum values, and + // const values calculated to have a total string length of just over 15,000 characters. An + // error is expected. + + val numDefs = 66 // Each also has one "const" value. + val numProps = 70 // Each also has "numEnumValues" enum values. + val numEnumValues = 5 // numProps * numEnumValues <= 500 limit (OpenAI) + val nameLen = 5 // Length of names of definitions, properties and const values. + val enumValueLen = 40 // Length of enum values. + val expectedTotalStringLength = + nameLen * (numProps + numDefs * 2) + numProps * enumValueLen * numEnumValues + + val enumValues = + (1..numEnumValues).joinToString(", ") { "\"%s_%03d\"".format("e".repeat(36), it) } + val defs = + (1..numDefs).joinToString(", ") { + "\"d_%03d\" : { \"type\" : \"string\", \"const\" : \"c_%03d\" }".format(it, it) + } + val props = + (1..numProps).joinToString(", ") { + "\"p_%03d\" : { \"type\" : \"string\", \"enum\" : [ $enumValues ] }".format(it) + } + val propNames = (1..numProps).joinToString(", ") { "\"p_%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { $defs }, + "type" : "object", + "properties" : { $props }, + "required" : [ $propNames ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(expectedTotalStringLength).isGreaterThan(15_000) + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total string length of all values (15010) exceeds limit of 15000.") + } + + @Test + fun schemaTest_annotatedWithJsonClassDescription() { + // Add a "description" to the root schema using an annotation. + @JsonClassDescription("A simple schema.") class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val desc = schema.get("description") + + assertThat(validator.isValid()).isTrue + assertThat(desc).isNotNull + assertThat(desc.isTextual).isTrue + assertThat(desc.asText()).isEqualTo("A simple schema.") + } + + @Test + fun schemaTest_annotatedWithJsonPropertyDescription() { + // Add a "description" to the property using an annotation. + @Suppress("unused") class X(@get:JsonPropertyDescription("A string value.") val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("s") + val desc = stringProperty.get("description") + + assertThat(validator.isValid()).isTrue + assertThat(desc).isNotNull + assertThat(desc.isTextual).isTrue + assertThat(desc.asText()).isEqualTo("A string value.") + } + + @Test + fun schemaTest_annotatedWithJsonProperty() { + // Override the default name of the property using the annotation. + @Suppress("unused") class X(@get:JsonProperty("a_string") val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("a_string") + + assertThat(validator.isValid()).isTrue + assertThat(stringProperty).isNotNull + } + + @Test + fun schemaTest_annotatedWithJsonPropertyRejectDefaultValue() { + // Set a default value for the property. It should be ignored when the schema is generated, + // as default property values are not supported in OpenAI JSON schemas. (The Victools docs + // have examples of how to add support for this default values via annotations or initial + // values, should support for default values be needed in the future.) + // + // Lack of support is not mentioned in the specification, but see the evidence at: + // https://engineering.fractional.ai/openai-structured-output-fixes + @Suppress("unused") + class X( + @get:JsonProperty(defaultValue = "default_value_1") val s: String = "default_value_2" + ) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("s") + + assertThat(validator.isValid()).isTrue + assertThat(stringProperty).isNotNull + assertThat(stringProperty.get("default")).isNull() + } + + @Test + fun schemaTest_annotatedWithJsonIgnore() { + // Override the default name of the property using the annotation. + @Suppress("unused") class X(@get:JsonIgnore val s1: String, val s2: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val s1Property = properties.get("s1") + val s2Property = properties.get("s2") + + assertThat(validator.isValid()).isTrue + assertThat(s1Property).isNull() + assertThat(s2Property).isNotNull + } + + @Test + fun schemaTest_emptyDefinitions() { + // Be lenient about empty definitions. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "type" : "string" + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_referenceMissingReferent() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "$REF" : "#/$DEFS/Person" + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/$REF: Invalid or unsupported reference: '#/$DEFS/Person'.") + } + + @Test + fun schemaTest_referenceFieldIsNotTextual() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "$REF" : 42 + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#/$REF: '$REF' field is not a text value.") + } + + @Test + fun validatorBeforeValidation() { + assertThat(validator.errors()).isEmpty() + assertThat(validator.isValid()).isFalse + } + + @Test + fun validatorReused() { + class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Should fail if an attempt is made to reuse the validator. + assertThatThrownBy { validator.validate(schema) } + .isExactlyInstanceOf(IllegalStateException::class.java) + .hasMessageContaining("Validation already complete.") + } + + @Test + @Suppress("unused") + fun schemaTest_largeLaureatesSchema() { + // This covers many cases: large and complex "$defs", resolution of references, recursive + // references, etc. The output is assumed to be good (it has been checked by eye) and the + // test just shows that the validator can handle the complexity without crashing or emitting + // spurious errors. + class Name(val givenName: String, val familyName: String) + + class Person( + @get:JsonPropertyDescription("The name of the person.") val name: Name, + @get:JsonProperty(value = "date_of_birth", defaultValue = "unknown_1") + @get:JsonPropertyDescription("The date of birth of the person.") + var dateOfBirth: String, + @get:JsonPropertyDescription("The country of citizenship of the person.") + var nationality: String, + // A child being a `Person` results in a recursive schema. + @get:JsonPropertyDescription("The children (if any) of the person.") + val children: List, + ) { + @get:JsonPropertyDescription("The other name of the person.") + var otherName: Name = Name("Bob", "Smith") + } + + class Laureate( + val laureate: Person, + val majorContribution: String, + val yearOfWinning: String, + @get:JsonIgnore val favoriteColor: String, + ) + + class Laureates( + // Two lists results in a `Laureate` definition that is referenced in the schema. + var laureates1901to1950: List, + var laureates1951to2025: List, + ) + + schema = extractSchema(Laureates::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun fromJsonSuccess() { + @Suppress("unused") class X(val s: String) + + val x = responseTypeFromJson("{\"s\" : \"hello\"}", X::class.java) + + assertThat(x.s).isEqualTo("hello") + } + + @Test + fun fromJsonFailure1() { + @Suppress("unused") class X(val s: String) + + // Well-formed JSON, but it does not match the schema of class `X`. + assertThatThrownBy { responseTypeFromJson("{\"wrong\" : \"hello\"}", X::class.java) } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("Error parsing JSON: {\"wrong\" : \"hello\"}") + } + + @Test + fun fromJsonFailure2() { + @Suppress("unused") class X(val s: String) + + // Malformed JSON. + assertThatThrownBy { responseTypeFromJson("{\"truncated", X::class.java) } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("Error parsing JSON: {\"truncated") + } + + @Test + fun fromClassEnablesStrictAdherenceToSchema() { + @Suppress("unused") class X(val s: String) + + val jsonSchema = responseFormatFromClass(X::class.java) + + // The "strict" flag _must_ be set to ensure that the model's output will _always_ conform + // to the JSON schema. + assertThat(jsonSchema.jsonSchema().strict()).isPresent + assertThat(jsonSchema.jsonSchema().strict().get()).isTrue + } + + @Test + @Suppress("unused") + fun fromClassSuccessWithoutValidation() { + // Exceed the maximum nesting depth, but do not enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatNoException().isThrownBy { + responseFormatFromClass(Z::class.java, JsonSchemaLocalValidation.NO) + } + } + + @Test + fun fromClassSuccessWithValidation() { + @Suppress("unused") class X(val s: String) + + assertThatNoException().isThrownBy { + responseFormatFromClass(X::class.java, JsonSchemaLocalValidation.YES) + } + } + + @Test + @Suppress("unused") + fun fromClassFailureWithValidation() { + // Exceed the maximum nesting depth and enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { responseFormatFromClass(Z::class.java, JsonSchemaLocalValidation.YES) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + @Suppress("unused") + fun fromClassFailureWithValidationDefault() { + // Confirm that the default value of the `localValidation` argument is `true` by expecting + // a validation error when that argument is not given an explicit value. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + // Use default for `localValidation` flag. + assertThatThrownBy { responseFormatFromClass(Z::class.java) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt new file mode 100644 index 00000000..c62eb8ed --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt @@ -0,0 +1,366 @@ +package com.openai.core + +import java.lang.reflect.Method +import java.util.Optional +import kotlin.reflect.KClass +import kotlin.reflect.KFunction +import kotlin.reflect.KVisibility +import kotlin.reflect.full.declaredFunctions +import kotlin.reflect.jvm.javaMethod +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.fail +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +// Constants for values that can be used in many of the tests as sample input or output values. +// +// Where a function returns `Optional`, `JsonField` or `JsonValue` There is no need to provide +// a value that matches the type ``, a simple `String` value of `"a-string"` will work OK. +internal const val STRING = "a-string" +internal val NULLABLE_STRING: String? = null +internal val OPTIONAL = Optional.of(STRING) +internal val JSON_FIELD = JsonField.of(STRING) +internal val JSON_VALUE = JsonValue.from(STRING) +internal val NULLABLE = null +internal const val BOOLEAN: Boolean = true +internal val NULLABLE_BOOLEAN: Boolean? = null +internal const val LONG: Long = 42L +internal val NULLABLE_LONG: Long? = null +internal const val DOUBLE: Double = 42.0 +internal val NULLABLE_DOUBLE: Double? = null +internal val LIST = listOf(STRING) +internal val SET = setOf(STRING) +internal val MAP = mapOf(STRING to STRING) + +/** + * Defines a test case where a function in a delegator returns a value from a corresponding function + * in a delegate. + */ +internal data class DelegationReadTestCase(val functionName: String, val expectedValue: Any) + +/** + * Defines a test case where a function in a delegator passes its parameters through to a + * corresponding function in a delegate. + */ +// Want `vararg`, so cannot use `data class`. Needs a custom `toString`, anyway. +internal class DelegationWriteTestCase( + val functionName: String, + /** + * The values to pass to the function being tested. If the first input value is `null`, it must + * be the only value. Only the first input value may be `null`, all others must be non-`null`. + * This is not enforced by this class, but is assumed by the related utility functions. + */ + vararg val inputValues: Any?, +) { + /** Gets the string representation that identifies the test function when running JUnit. */ + override fun toString(): String = + "$functionName(${inputValues.joinToString(", ") { + it?.javaClass?.simpleName ?: "null" + }})" +} + +/** A basic class used as the generic type when testing. */ +internal class X(val s: String) { + override fun equals(other: Any?) = other is X && other.s == s + + override fun hashCode() = s.hashCode() +} + +/** + * Checks that all functions in one class have a corresponding function with the same name and + * parameter types in another class. A list of function names that should be allowed as exceptions + * can be given. Non-public functions are ignored, as they are considered to be implementation + * details of each class. + * + * Call this function twice, changing the order of the two classes to ensure that both classes + * contain the same set of functions (barring exceptions), should that be the expectation. + * + * @param subsetClass The class whose functions should be a subset of the functions of the other + * class. + * @param supersetClass The class whose functions should be a superset of the functions of the other + * class. + */ +internal fun checkAllDelegation( + subsetClass: KClass<*>, + supersetClass: KClass<*>, + vararg exceptFunctionNames: String, +) { + assertThat(subsetClass != supersetClass) + .describedAs { "The two classes should not be the same." } + .isTrue + + val subsetFunctions = subsetClass.declaredFunctions + val missingFunctions = mutableListOf>() + + for (subsetFunction in subsetFunctions) { + if (subsetFunction.visibility != KVisibility.PUBLIC) { + continue + } + + if (subsetFunction.name in exceptFunctionNames) { + continue + } + + // Drop the first parameter from each function, as it is the implicit "this" object and has + // the type of the class declaring the function, which will never match. + val supersetFunction = + supersetClass.declaredFunctions.find { + it.name == subsetFunction.name && + it.parameters.drop(1).map { it.type } == + subsetFunction.parameters.drop(1).map { it.type } + } + + if (supersetFunction == null) { + missingFunctions.add(subsetFunction) + } + } + + assertThat(missingFunctions) + .describedAs { + "Function(s) not found in ${supersetClass.simpleName}:\n" + + missingFunctions.joinToString("\n") { " - $it" } + } + .isEmpty() +} + +/** + * Checks that the delegator function calls the corresponding delegate function and no other + * functions on the delegate. The test case defines the function name and the sample return value. + * All functions take no arguments. + */ +internal fun checkOneDelegationRead( + delegator: Any, + mockDelegate: Any, + testCase: DelegationReadTestCase, +) { + // Stub the method in the mock delegate using reflection + val delegateMethod = mockDelegate::class.java.getMethod(testCase.functionName) + `when`(delegateMethod.invoke(mockDelegate)).thenReturn(testCase.expectedValue) + + // Call the corresponding method on the delegator using reflection + val delegatorMethod = delegator::class.java.getMethod(testCase.functionName) + val result = delegatorMethod.invoke(delegator) + + // Verify that the corresponding method on the mock delegate was called exactly once + verify(mockDelegate, times(1)).apply { delegateMethod.invoke(mockDelegate) } + verifyNoMoreInteractions(mockDelegate) + + // Assert that the result matches the expected value + assertThat(result).isEqualTo(testCase.expectedValue) +} + +/** + * Checks that the delegator function calls the corresponding delegate function and no other + * functions on the delegate. The test case defines the function name and sample parameter values. + */ +internal fun checkOneDelegationWrite( + delegator: Any, + mockDelegate: Any, + testCase: DelegationWriteTestCase, +) { + invokeMethod(findDelegationMethod(delegator, testCase), delegator, testCase) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockDelegate, times(1)).apply { + invokeMethod(findDelegationMethod(mockDelegate, testCase), mockDelegate, testCase) + } + verifyNoMoreInteractions(mockDelegate) +} + +private fun invokeMethod(method: Method, target: Any, testCase: DelegationWriteTestCase) { + val numParams = testCase.inputValues.size + val inputValue1 = testCase.inputValues[0] + val inputValue2 = testCase.inputValues.getOrNull(1) + + when (numParams) { + 1 -> method.invoke(target, inputValue1) + 2 -> method.invoke(target, inputValue1, inputValue2) + else -> fail { "Unexpected number of function parameters ($numParams)." } + } +} + +/** + * Finds the java method matching the test case's function name and parameter types in the delegator + * or delegate `target`. + */ +internal fun findDelegationMethod(target: Any, testCase: DelegationWriteTestCase): Method { + val numParams = testCase.inputValues.size + val inputValue1: Any? = testCase.inputValues[0] + val inputValue2 = if (numParams > 1) testCase.inputValues[1] else null + + val method = + when (numParams) { + 1 -> + if (inputValue1 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + ) + } else { + // Only the first parameter may be nullable and only if it is the only + // parameter. If the first parameter is nullable, it will be the only function + // of the same name with a nullable first parameter. To handle the potentially + // nullable first parameter, Kotlin reflection is needed. This allows a function + // `f(Boolean)` to be distinguished from `f(Boolean?)`. For the tests, if the + // parameter type is nullable, the parameter value will always be `null` (if + // not, the function with the nullable parameter would not be matched). + // + // Using Kotlin reflection, the first parameter (zero index) is `this` object, + // so start matching from the second parameter onwards. + target::class + .declaredFunctions + .find { + it.name == testCase.functionName && + it.parameters[1].type.isMarkedNullable + } + ?.javaMethod + } + + 2 -> + if (inputValue1 != null && inputValue2 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + toJavaType(inputValue2.javaClass), + ) + } else { + // There are no instances where there are two parameters and one of them is + // nullable. + fail { "Function $testCase second parameter must not be null." } + } + + else -> fail { "Function $testCase has unsupported number of parameters." } + } ?: fail { "Function $testCase cannot be found in $target." } + + // Using `fail` conditionally above, so the compiler knows the code will not continue and can + // infer that `method` is not null. It cannot do that for `assertThat...isNotNull`. + return method +} + +/** Finds a Java method in a class that matches a method name and a list of parameter types. */ +private fun findJavaMethod( + clazz: Class<*>, + methodName: String, + vararg parameterTypes: Class<*>, +): Method? = + clazz.declaredMethods.firstOrNull { method -> + method.name == methodName && + method.parameterTypes.size == parameterTypes.size && + method.parameterTypes.indices.all { index -> + (parameterTypes[index].isPrimitive && + method.parameterTypes[index] == parameterTypes[index]) || + method.parameterTypes[index].isAssignableFrom(parameterTypes[index]) + } + } + +/** + * Returns the Java type to use when matching type parameters for a Java method. The type is the + * type of the input value that will be used when the method is invoked. For most types, the given + * type is returned. However, if the type represents a Kotlin primitive, it will be converted to a + * Java primitive. This allows matching of methods with parameter types that are non-nullable Kotlin + * primitives. If not translated, methods with parameter types that are nullable Kotlin primitives + * would always be matched instead. + */ +private fun toJavaType(type: Class<*>) = + when (type) { + // This only needs to cover the types used in the test cases. + java.lang.Long::class.java -> java.lang.Long.TYPE + java.lang.Boolean::class.java -> java.lang.Boolean.TYPE + java.lang.Double::class.java -> java.lang.Double.TYPE + else -> type + } + +/** + * Checks that all delegating functions in a delegator class have corresponding unit tests. The + * read-only functions should take no parameters; only return a value. + * + * @param delegatorClass The delegator class whose functions are tested. Every named function in + * this class must be identified in one of the given sources of function names or a failure will + * occur. + * @param delegationTestCases The tests cases that identify the names of delegating functions for + * which parameterized unit tests have been defined. + * @param exceptionalTestedFns The names of delegating functions that are tested separately, not as + * parameterized unit tests. This is usually because they require special handling in the test. + * @param nonDelegatingFns The names of functions that do not perform any delegation and for which + * delegation tests are not required. + */ +internal fun checkAllDelegatorReadFunctionsAreTested( + delegatorClass: KClass<*>, + delegationTestCases: List, + exceptionalTestedFns: Set, + nonDelegatingFns: Set, +) { + val testedFns = delegationTestCases.map { it.functionName }.toSet() + exceptionalTestedFns + val delegatorFunctions = delegatorClass.declaredFunctions + val untestedFunctions = + delegatorFunctions.filter { it.name !in testedFns && it.name !in nonDelegatingFns } + + assertThat(untestedFunctions) + .describedAs( + "Delegation is not tested for function(s):\n" + + untestedFunctions.joinToString("\n") { " - $it" } + ) + .isEmpty() +} + +/** + * Checks that all delegating functions in a delegator class have corresponding unit tests. The + * write-only functions should take parameters and return no value. + * + * @param delegatorClass The delegator class whose functions are tested. Every named function in + * this class must be identified in one of the given sources of function names or a failure will + * occur. + * @param delegationTestCases The tests cases that identify the names of delegating functions for + * which parameterized unit tests have been defined. + * @param exceptionalTestedFns The names of delegating functions that are tested separately, not as + * parameterized unit tests. This is usually because they require special handling in the test. + * @param nonDelegatingFns The names of functions that do not perform any delegation and for which + * delegation tests are not required. + */ +internal fun checkAllDelegatorWriteFunctionsAreTested( + delegatorClass: KClass<*>, + delegationTestCases: List, + exceptionalTestedFns: Set, + nonDelegatingFns: Set, +) { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. There are many overloaded functions, so the + // approach here is to build a list (_not_ a set) of all function names and then "subtract" + // those for which tests are defined and see what remains. For example, there could be eight + // `addMessage` functions, so there must be eight tests defined for functions named `addMessage` + // that will be subtracted from the list of functions matching that name. Parameter types are + // not checked, as that is awkward and probably overkill. Therefore, this scheme is not reliable + // if a function is tested more than once. + val testedFns = + (delegationTestCases.map { it.functionName } + exceptionalTestedFns).toMutableList() + // Only interested in the names of the functions (which may contain duplicates): parameters are + // not matched, so any signatures could be misleading when reporting errors. + val delegatorFns = delegatorClass.declaredFunctions.map { it.name }.toMutableList() + + // Making modifications to the list, so clone it with `toList()` before iterating. + for (fnName in delegatorFns.toList()) { + if (fnName in testedFns) { + testedFns.remove(fnName) + delegatorFns.remove(fnName) + } + if (fnName in nonDelegatingFns) { + delegatorFns.remove(fnName) + } + } + + // If there are function names remaining in `delegatorFns`, then there are tests missing. + assertThat(delegatorFns) + .describedAs { "Delegation is not tested for functions $delegatorFns." } + .isEmpty() + + // If there are function names remaining in `testedFns`, then there are more tests than there + // should be. Functions might be tested twice, or there may be tests for functions that have + // since been removed from the delegate (though those tests probably failed). + assertThat(testedFns) + .describedAs { "Unexpected or redundant tests for functions $testedFns." } + .isEmpty() +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt index e47aebd6..6ae11ba1 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt @@ -347,4 +347,36 @@ internal class ChatCompletionCreateParamsTest { ) assertThat(body.model()).isEqualTo(ChatModel.GPT_4_1) } + + @Test + fun structuredOutputsBuilder() { + class X(val s: String) + + // Only interested in a few things: + // - Does the `Builder` type change when `responseFormat(Class)` is called? + // - Are values already set on the "old" `Builder` preserved in the change-over? + // - Can new values be set on the "new" `Builder` alongside the "old" values? + val params = + ChatCompletionCreateParams.builder() + .addDeveloperMessage("dev message") + .model(ChatModel.GPT_4_1) + .responseFormat(X::class.java) // Creates and return a new builder. + .addSystemMessage("sys message") + .build() + + val body = params.rawParams._body() + + assertThat(params).isInstanceOf(StructuredChatCompletionCreateParams::class.java) + assertThat(params.responseType).isEqualTo(X::class.java) + assertThat(body.messages()) + .containsExactly( + ChatCompletionMessageParam.ofDeveloper( + ChatCompletionDeveloperMessageParam.builder().content("dev message").build() + ), + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content("sys message").build() + ), + ) + assertThat(body.model()).isEqualTo(ChatModel.GPT_4_1) + } } diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt new file mode 100644 index 00000000..4b289198 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt @@ -0,0 +1,318 @@ +package com.openai.models.chat.completions + +import com.openai.core.BOOLEAN +import com.openai.core.DOUBLE +import com.openai.core.DelegationWriteTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.LIST +import com.openai.core.LONG +import com.openai.core.MAP +import com.openai.core.NULLABLE +import com.openai.core.NULLABLE_BOOLEAN +import com.openai.core.NULLABLE_DOUBLE +import com.openai.core.NULLABLE_LONG +import com.openai.core.OPTIONAL +import com.openai.core.SET +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorWriteFunctionsAreTested +import com.openai.core.checkOneDelegationWrite +import com.openai.core.findDelegationMethod +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.core.responseFormatFromClass +import com.openai.models.ChatModel +import com.openai.models.FunctionDefinition +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletionCreateParams] class (delegator) and its delegation of + * most functions to a wrapped [ChatCompletionCreateParams] (delegate). It is the `Builder` class of + * each main class that is involved in the delegation. The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionCreateParamsTest { + companion object { + private val CHAT_MODEL = ChatModel.GPT_4 + + private val MESSAGE = + ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() + + private val USER_MESSAGE_PARAM = + ChatCompletionUserMessageParam.builder().content(STRING).build() + private val DEV_MESSAGE_PARAM = + ChatCompletionDeveloperMessageParam.builder().content(STRING).build() + private val SYS_MESSAGE_PARAM = + ChatCompletionSystemMessageParam.builder().content(STRING).build() + private val ASSIST_MESSAGE_PARAM = + ChatCompletionAssistantMessageParam.builder().content(STRING).build() + private val TOOL_MESSAGE_PARAM = + ChatCompletionToolMessageParam.builder().content(STRING).toolCallId(STRING).build() + private val FUNC_MESSAGE_PARAM = + ChatCompletionFunctionMessageParam.builder().content(STRING).name(STRING).build() + private val MESSAGE_PARAM = ChatCompletionMessageParam.ofUser(USER_MESSAGE_PARAM) + + private val DEV_MESSAGE_PARAM_CONTENT = + ChatCompletionDeveloperMessageParam.Content.ofText(STRING) + private val SYS_MESSAGE_PARAM_CONTENT = + ChatCompletionSystemMessageParam.Content.ofText(STRING) + private val USER_MESSAGE_PARAM_CONTENT = + ChatCompletionUserMessageParam.Content.ofText(STRING) + + private val PARAMS_BODY = + ChatCompletionCreateParams.Body.builder() + .messages(listOf(MESSAGE_PARAM)) + .model(CHAT_MODEL) + .build() + private val WEB_SEARCH_OPTIONS = + ChatCompletionCreateParams.WebSearchOptions.builder().build() + + private val FUNCTION_CALL_MODE = + ChatCompletionCreateParams.FunctionCall.FunctionCallMode.AUTO + private val FUNCTION_CALL_OPTION = + ChatCompletionFunctionCallOption.builder().name(STRING).build() + private val FUNCTION_CALL = + ChatCompletionCreateParams.FunctionCall.ofFunctionCallOption(FUNCTION_CALL_OPTION) + + private val FUNCTION = ChatCompletionCreateParams.Function.builder().name(STRING).build() + private val METADATA = ChatCompletionCreateParams.Metadata.builder().build() + private val MODALITY = ChatCompletionCreateParams.Modality.TEXT + private val FUNCTION_DEFINITION = FunctionDefinition.builder().name(STRING).build() + private val TOOL = ChatCompletionTool.builder().function(FUNCTION_DEFINITION).build() + + private val NAMED_TOOL_CHOICE_FUNCTION = + ChatCompletionNamedToolChoice.Function.builder().name(STRING).build() + private val NAMED_TOOL_CHOICE = + ChatCompletionNamedToolChoice.builder().function(NAMED_TOOL_CHOICE_FUNCTION).build() + private val TOOL_CHOICE_OPTION_AUTO = ChatCompletionToolChoiceOption.Auto.AUTO + private val TOOL_CHOICE_OPTION = + ChatCompletionToolChoiceOption.ofAuto(TOOL_CHOICE_OPTION_AUTO) + + private val HEADERS = Headers.builder().build() + private val QUERY_PARAMS = QueryParams.builder().build() + + // The list order follows the declaration order in `ChatCompletionCreateParams.Builder` for + // easier maintenance. + @JvmStatic + private fun builderDelegationTestCases() = + listOf( + DelegationWriteTestCase("body", PARAMS_BODY), + DelegationWriteTestCase("messages", LIST), + DelegationWriteTestCase("messages", JSON_FIELD), + DelegationWriteTestCase("addMessage", MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", DEV_MESSAGE_PARAM), + DelegationWriteTestCase("addDeveloperMessage", DEV_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addDeveloperMessage", STRING), + DelegationWriteTestCase("addDeveloperMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", SYS_MESSAGE_PARAM), + DelegationWriteTestCase("addSystemMessage", SYS_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addSystemMessage", STRING), + DelegationWriteTestCase("addSystemMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", USER_MESSAGE_PARAM), + DelegationWriteTestCase("addUserMessage", USER_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addUserMessage", STRING), + DelegationWriteTestCase("addUserMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", ASSIST_MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", MESSAGE), + DelegationWriteTestCase("addMessage", TOOL_MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", FUNC_MESSAGE_PARAM), + DelegationWriteTestCase("model", CHAT_MODEL), + DelegationWriteTestCase("model", JSON_FIELD), + DelegationWriteTestCase("model", STRING), + DelegationWriteTestCase("audio", NULLABLE), + DelegationWriteTestCase("audio", OPTIONAL), + DelegationWriteTestCase("audio", JSON_FIELD), + DelegationWriteTestCase("frequencyPenalty", NULLABLE_DOUBLE), + DelegationWriteTestCase("frequencyPenalty", DOUBLE), + DelegationWriteTestCase("frequencyPenalty", OPTIONAL), + DelegationWriteTestCase("frequencyPenalty", JSON_FIELD), + DelegationWriteTestCase("functionCall", FUNCTION_CALL), + DelegationWriteTestCase("functionCall", JSON_FIELD), + DelegationWriteTestCase("functionCall", FUNCTION_CALL_MODE), + DelegationWriteTestCase("functionCall", FUNCTION_CALL_OPTION), + DelegationWriteTestCase("functions", LIST), + DelegationWriteTestCase("functions", JSON_FIELD), + DelegationWriteTestCase("addFunction", FUNCTION), + DelegationWriteTestCase("logitBias", NULLABLE), + DelegationWriteTestCase("logitBias", OPTIONAL), + DelegationWriteTestCase("logitBias", JSON_FIELD), + DelegationWriteTestCase("logprobs", NULLABLE_BOOLEAN), + DelegationWriteTestCase("logprobs", BOOLEAN), + DelegationWriteTestCase("logprobs", OPTIONAL), + DelegationWriteTestCase("logprobs", JSON_FIELD), + DelegationWriteTestCase("maxCompletionTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxCompletionTokens", LONG), + DelegationWriteTestCase("maxCompletionTokens", OPTIONAL), + DelegationWriteTestCase("maxCompletionTokens", JSON_FIELD), + DelegationWriteTestCase("maxTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxTokens", LONG), + DelegationWriteTestCase("maxTokens", OPTIONAL), + DelegationWriteTestCase("maxTokens", JSON_FIELD), + DelegationWriteTestCase("metadata", METADATA), + DelegationWriteTestCase("metadata", OPTIONAL), + DelegationWriteTestCase("metadata", JSON_FIELD), + DelegationWriteTestCase("modalities", LIST), + DelegationWriteTestCase("modalities", OPTIONAL), + DelegationWriteTestCase("modalities", JSON_FIELD), + DelegationWriteTestCase("addModality", MODALITY), + DelegationWriteTestCase("n", NULLABLE_LONG), + DelegationWriteTestCase("n", LONG), + DelegationWriteTestCase("n", OPTIONAL), + DelegationWriteTestCase("n", JSON_FIELD), + DelegationWriteTestCase("parallelToolCalls", BOOLEAN), + DelegationWriteTestCase("parallelToolCalls", JSON_FIELD), + DelegationWriteTestCase("prediction", NULLABLE), + DelegationWriteTestCase("prediction", OPTIONAL), + DelegationWriteTestCase("prediction", JSON_FIELD), + DelegationWriteTestCase("presencePenalty", NULLABLE_DOUBLE), + DelegationWriteTestCase("presencePenalty", DOUBLE), + DelegationWriteTestCase("presencePenalty", OPTIONAL), + DelegationWriteTestCase("presencePenalty", JSON_FIELD), + DelegationWriteTestCase("reasoningEffort", NULLABLE), + DelegationWriteTestCase("reasoningEffort", OPTIONAL), + DelegationWriteTestCase("reasoningEffort", JSON_FIELD), + // `responseFormat()` is a special case and has its own unit test. + DelegationWriteTestCase("seed", NULLABLE_LONG), + DelegationWriteTestCase("seed", LONG), + DelegationWriteTestCase("seed", OPTIONAL), + DelegationWriteTestCase("seed", JSON_FIELD), + DelegationWriteTestCase("serviceTier", NULLABLE), + DelegationWriteTestCase("serviceTier", OPTIONAL), + DelegationWriteTestCase("serviceTier", JSON_FIELD), + DelegationWriteTestCase("stop", NULLABLE), + DelegationWriteTestCase("stop", OPTIONAL), + DelegationWriteTestCase("stop", JSON_FIELD), + DelegationWriteTestCase("stop", STRING), + DelegationWriteTestCase("stopOfStrings", LIST), + DelegationWriteTestCase("store", NULLABLE_BOOLEAN), + DelegationWriteTestCase("store", BOOLEAN), + DelegationWriteTestCase("store", OPTIONAL), + DelegationWriteTestCase("store", JSON_FIELD), + DelegationWriteTestCase("streamOptions", NULLABLE), + DelegationWriteTestCase("streamOptions", OPTIONAL), + DelegationWriteTestCase("streamOptions", JSON_FIELD), + DelegationWriteTestCase("temperature", NULLABLE_DOUBLE), + DelegationWriteTestCase("temperature", DOUBLE), + DelegationWriteTestCase("temperature", OPTIONAL), + DelegationWriteTestCase("temperature", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTION), + DelegationWriteTestCase("toolChoice", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTION_AUTO), + DelegationWriteTestCase("toolChoice", NAMED_TOOL_CHOICE), + DelegationWriteTestCase("tools", LIST), + DelegationWriteTestCase("tools", JSON_FIELD), + DelegationWriteTestCase("addTool", TOOL), + DelegationWriteTestCase("topLogprobs", NULLABLE_LONG), + DelegationWriteTestCase("topLogprobs", LONG), + DelegationWriteTestCase("topLogprobs", OPTIONAL), + DelegationWriteTestCase("topLogprobs", JSON_FIELD), + DelegationWriteTestCase("topP", NULLABLE_DOUBLE), + DelegationWriteTestCase("topP", DOUBLE), + DelegationWriteTestCase("topP", OPTIONAL), + DelegationWriteTestCase("topP", JSON_FIELD), + DelegationWriteTestCase("user", STRING), + DelegationWriteTestCase("user", JSON_FIELD), + DelegationWriteTestCase("webSearchOptions", WEB_SEARCH_OPTIONS), + DelegationWriteTestCase("webSearchOptions", JSON_FIELD), + DelegationWriteTestCase("additionalBodyProperties", MAP), + DelegationWriteTestCase("putAdditionalBodyProperty", STRING, JSON_VALUE), + DelegationWriteTestCase("putAllAdditionalBodyProperties", MAP), + DelegationWriteTestCase("removeAdditionalBodyProperty", STRING), + DelegationWriteTestCase("removeAllAdditionalBodyProperties", SET), + DelegationWriteTestCase("additionalHeaders", HEADERS), + DelegationWriteTestCase("additionalHeaders", MAP), + DelegationWriteTestCase("putAdditionalHeader", STRING, STRING), + DelegationWriteTestCase("putAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("putAllAdditionalHeaders", MAP), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("replaceAllAdditionalHeaders", MAP), + DelegationWriteTestCase("removeAdditionalHeaders", STRING), + DelegationWriteTestCase("removeAllAdditionalHeaders", SET), + DelegationWriteTestCase("additionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("additionalQueryParams", MAP), + DelegationWriteTestCase("putAdditionalQueryParam", STRING, STRING), + DelegationWriteTestCase("putAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("putAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("removeAdditionalQueryParams", STRING), + DelegationWriteTestCase("removeAllAdditionalQueryParams", SET), + ) + } + + // New instances of the `mockBuilderDelegate` and `builderDelegator` are required for each test + // case (each test case runs in its own instance of the test class). + private val mockBuilderDelegate: ChatCompletionCreateParams.Builder = + mock(ChatCompletionCreateParams.Builder::class.java) + private val builderDelegator = + StructuredChatCompletionCreateParams.builder().inject(mockBuilderDelegate) + + @Test + fun allBuilderDelegateFunctionsExistInDelegator() { + // The delegator class does not implement the various `responseFormat` functions of the + // delegate class. + checkAllDelegation(mockBuilderDelegate::class, builderDelegator::class, "responseFormat") + } + + @Test + fun allBuilderDelegatorFunctionsExistInDelegate() { + // The delegator implements a different `responseFormat` function from those overloads in + // the delegate class. + checkAllDelegation(builderDelegator::class, mockBuilderDelegate::class, "responseFormat") + } + + @Test + fun allBuilderDelegatorFunctionsAreTested() { + checkAllDelegatorWriteFunctionsAreTested( + builderDelegator::class, + builderDelegationTestCases(), + exceptionalTestedFns = setOf("responseFormat"), + nonDelegatingFns = setOf("build", "wrap", "inject"), + ) + } + + @ParameterizedTest + @MethodSource("builderDelegationTestCases") + fun `delegation of Builder write functions`(testCase: DelegationWriteTestCase) { + checkOneDelegationWrite(builderDelegator, mockBuilderDelegate, testCase) + } + + @Test + fun `delegation of responseFormat`() { + // Special unit test case as the delegator method signature does not match that of the + // delegate method. + val delegatorTestCase = DelegationWriteTestCase("responseFormat", X::class.java) + val delegatorMethod = findDelegationMethod(builderDelegator, delegatorTestCase) + val mockDelegateTestCase = + DelegationWriteTestCase("responseFormat", responseFormatFromClass(X::class.java)) + val mockDelegateMethod = findDelegationMethod(mockBuilderDelegate, mockDelegateTestCase) + + delegatorMethod.invoke(builderDelegator, delegatorTestCase.inputValues[0]) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockBuilderDelegate, times(1)).apply { + mockDelegateMethod.invoke(mockBuilderDelegate, mockDelegateTestCase.inputValues[0]) + } + verifyNoMoreInteractions(mockBuilderDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt new file mode 100644 index 00000000..939f9a53 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt @@ -0,0 +1,129 @@ +package com.openai.models.chat.completions + +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.JsonField +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletionMessage] class (delegator) and its delegation of most + * functions to a wrapped [ChatCompletionMessage] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionMessageTest { + companion object { + private val MESSAGE = + ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() + + // The list order follows the declaration order in `ChatCompletionMessage` for easier + // maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + // `content()` is a special case and has its own test function. + DelegationReadTestCase("refusal", OPTIONAL), + DelegationReadTestCase("_role", JSON_VALUE), + DelegationReadTestCase("annotations", OPTIONAL), + DelegationReadTestCase("audio", OPTIONAL), + DelegationReadTestCase("functionCall", OPTIONAL), + DelegationReadTestCase("toolCalls", OPTIONAL), + // `_content()` is a special case and has its own test function. + DelegationReadTestCase("_refusal", JSON_FIELD), + DelegationReadTestCase("_annotations", JSON_FIELD), + DelegationReadTestCase("_audio", JSON_FIELD), + DelegationReadTestCase("_functionCall", JSON_FIELD), + DelegationReadTestCase("_toolCalls", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + DelegationReadTestCase("validate", MESSAGE), + // For this boolean function, call with both possible values to ensure that any + // hard-coding or default value will not result in a false positive test. + DelegationReadTestCase("isValid", true), + DelegationReadTestCase("isValid", false), + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: ChatCompletionMessage = mock(ChatCompletionMessage::class.java) + private val delegator = StructuredChatCompletionMessage(X::class.java, mockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder", "toParam") + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + checkAllDelegation(delegator::class, mockDelegate::class) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("content", "_content"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @Test + fun `delegation of content`() { + // Input and output are different types, so this test is an exceptional case. + // `content()` (without an underscore) delegates to `_content()` (with an underscore) + // indirectly via the `content` field initializer. + val input = JsonField.of("{\"s\" : \"hello\"}") + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator.content() // Without an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isEqualTo(Optional.of(X("hello"))) + } + + @Test + fun `delegation of _content`() { + // Input and output are different types, so this test is an exceptional case. + // `_content()` delegates to `_content()` indirectly via the `content` field initializer. + val input = JsonField.of("{\"s\" : \"hello\"}") + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator._content() // With an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isEqualTo(JsonField.of(X("hello"))) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt new file mode 100644 index 00000000..40dc509f --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt @@ -0,0 +1,307 @@ +package com.openai.models.chat.completions + +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.JsonField +import com.openai.core.LONG +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import com.openai.errors.OpenAIInvalidDataException +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletion] class (delegator) and its delegation of most + * functions to a wrapped [ChatCompletion] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionTest { + companion object { + private val MESSAGE = + ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() + private val FINISH_REASON = ChatCompletion.Choice.FinishReason.STOP + private val CHOICE = + ChatCompletion.Choice.builder() + .message(MESSAGE) + .index(0L) + .finishReason(FINISH_REASON) + .logprobs( + ChatCompletion.Choice.Logprobs.builder().content(null).refusal(null).build() + ) + .build() + + // The list order follows the declaration order in `ChatCompletion` for easier maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + DelegationReadTestCase("id", STRING), + // `choices()` is a special case and has its own test function. + DelegationReadTestCase("created", LONG), + DelegationReadTestCase("model", STRING), + DelegationReadTestCase("_object_", JSON_VALUE), + DelegationReadTestCase("serviceTier", OPTIONAL), + DelegationReadTestCase("systemFingerprint", OPTIONAL), + DelegationReadTestCase("usage", OPTIONAL), + DelegationReadTestCase("_id", JSON_FIELD), + // `_choices()` is a special case and has its own test function. + DelegationReadTestCase("_created", JSON_FIELD), + DelegationReadTestCase("_model", JSON_FIELD), + DelegationReadTestCase("_serviceTier", JSON_FIELD), + DelegationReadTestCase("_systemFingerprint", JSON_FIELD), + DelegationReadTestCase("_usage", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + + @JvmStatic + private fun choiceDelegationTestCases() = + listOf( + DelegationReadTestCase("finishReason", FINISH_REASON), + DelegationReadTestCase("index", LONG), + DelegationReadTestCase("logprobs", OPTIONAL), + DelegationReadTestCase("_finishReason", JSON_FIELD), + // `message()` is a special case and has its own test function. + DelegationReadTestCase("_index", JSON_FIELD), + DelegationReadTestCase("_logprobs", JSON_FIELD), + // `_message()` is a special case and has its own test function. + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: ChatCompletion = mock(ChatCompletion::class.java) + private val delegator = StructuredChatCompletion(X::class.java, mockDelegate) + + private val mockChoiceDelegate: ChatCompletion.Choice = mock(ChatCompletion.Choice::class.java) + private val choiceDelegator = + StructuredChatCompletion.Choice(X::class.java, mockChoiceDelegate) + + @Test + fun allChatCompletionDelegateFunctionsExistInDelegator() { + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder") + } + + @Test + fun allChatCompletionDelegatorFunctionsExistInDelegate() { + checkAllDelegation(delegator::class, mockDelegate::class) + } + + @Test + fun allChoiceDelegateFunctionsExistInDelegator() { + checkAllDelegation(mockChoiceDelegate::class, choiceDelegator::class, "toBuilder") + } + + @Test + fun allChoiceDelegatorFunctionsExistInDelegate() { + checkAllDelegation(choiceDelegator::class, mockChoiceDelegate::class) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("choices", "_choices", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @Test + fun allChoiceDelegatorFunctionsAreTested() { + checkAllDelegatorReadFunctionsAreTested( + choiceDelegator::class, + choiceDelegationTestCases(), + exceptionalTestedFns = setOf("message", "_message", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @ParameterizedTest + @MethodSource("choiceDelegationTestCases") + fun `delegation of Choice functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(choiceDelegator, mockChoiceDelegate, testCase) + } + + @Test + fun `delegation of choices`() { + // Input and output are different types, so this test is an exceptional case. + // `choices()` (without an underscore) delegates to `_choices()` (with an underscore) + // indirectly via the `choices` field initializer. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.choices() // Without an underscore. + + verify(mockDelegate, times(1))._choices() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output[0].rawChoice).isEqualTo(CHOICE) + } + + @Test + fun `delegation of _choices`() { + // Input and output are different types, so this test is an exceptional case. + // `_choices()` delegates to `_choices()` indirectly via the `choices` field initializer. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator._choices() // With an underscore. + + verify(mockDelegate, times(1))._choices() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.getRequired("_choices")[0].rawChoice).isEqualTo(CHOICE) + } + + @Test + fun `delegation of validate`() { + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.validate() + + // `validate()` calls `choices()` on the delegator which triggers the lazy initializer which + // calls `_choices()` on the delegate before `validate()` also calls `validate()` on the + // delegate. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isSameAs(delegator) + } + + @Test + fun `delegation of isValid when true`() { + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of isValid when false`() { + // Try with a `false` value to make sure `isValid()` is not just hard-coded to `true`. Do + // this by making `validate()` on the delegate throw an exception. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + `when`(mockDelegate.validate()).thenThrow(OpenAIInvalidDataException("test")) + val output = delegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isFalse + } + + @Test + fun `delegation of Choice-message`() { + // Input and output are different types, so this test is an exceptional case. + // `message()` (without an underscore) delegates to `_message()` (with an underscore) + // indirectly via the `message` field initializer. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.message() // Without an underscore. + + verify(mockChoiceDelegate, times(1))._message() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output.rawMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of Choice-_message`() { + // Input and output are different types, so this test is an exceptional case. + // `_message()` delegates to `_message()` indirectly via the `message` field initializer. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator._message() // With an underscore. + + verify(mockChoiceDelegate, times(1))._message() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output.getRequired("_message").rawMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of Choice-validate`() { + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.validate() + + // `validate()` calls `message()` on the delegator which triggers the lazy initializer which + // calls `_message()` on the delegate before `validate()` also calls `validate()` on the + // delegate. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isSameAs(choiceDelegator) + } + + @Test + fun `delegation of Choice-isValid when true`() { + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of Choice-isValid when false`() { + // Try with a `false` value to make sure `isValid()` is not just hard-coded to `true`. Do + // this by making `validate()` on the delegate throw an exception. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + `when`(mockChoiceDelegate.validate()).thenThrow(OpenAIInvalidDataException("test")) + val output = choiceDelegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isFalse + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt new file mode 100644 index 00000000..3d205072 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt @@ -0,0 +1,243 @@ +package com.openai.models.responses + +import com.openai.core.BOOLEAN +import com.openai.core.DOUBLE +import com.openai.core.DelegationWriteTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.LIST +import com.openai.core.LONG +import com.openai.core.MAP +import com.openai.core.NULLABLE +import com.openai.core.NULLABLE_BOOLEAN +import com.openai.core.NULLABLE_DOUBLE +import com.openai.core.NULLABLE_LONG +import com.openai.core.NULLABLE_STRING +import com.openai.core.OPTIONAL +import com.openai.core.SET +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorWriteFunctionsAreTested +import com.openai.core.checkOneDelegationWrite +import com.openai.core.findDelegationMethod +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.core.textConfigFromClass +import com.openai.models.ChatModel +import com.openai.models.Reasoning +import com.openai.models.ResponsesModel +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponseCreateParams] class (delegator) and its delegation of most + * functions to a wrapped [ResponseCreateParams] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseCreateParamsTest { + companion object { + private val CHAT_MODEL = ChatModel.GPT_4O + private val RESPONSES_MODEL = ResponsesModel.ofChat(CHAT_MODEL) + private val RESPONSES_ONLY_MODEL = ResponsesModel.ResponsesOnlyModel.O1_PRO + private val PARAMS_INPUT = ResponseCreateParams.Input.ofText(STRING) + + private val INCLUDABLE = ResponseIncludable.of(STRING) + private val METADATA = ResponseCreateParams.Metadata.builder().build() + private val SERVICE_TIER = ResponseCreateParams.ServiceTier.AUTO + private val REASONING = Reasoning.builder().build() + + private val TOOL_CHOICE_TYPE = ToolChoiceTypes.Type.FILE_SEARCH + private val TOOL_CHOICE_TYPES = ToolChoiceTypes.builder().type(TOOL_CHOICE_TYPE).build() + private val TOOL_CHOICE = ResponseCreateParams.ToolChoice.ofTypes(TOOL_CHOICE_TYPES) + private val TOOL_CHOICE_OPTIONS = ToolChoiceOptions.AUTO + private val TOOL_CHOICE_FUNCTION = ToolChoiceFunction.builder().name(STRING).build() + + private val FUNCTION_TOOL = + FunctionTool.builder().name(STRING).parameters(NULLABLE).strict(BOOLEAN).build() + private val FILE_SEARCH_TOOL = FileSearchTool.builder().vectorStoreIds(LIST).build() + private val WEB_SEARCH_TOOL = + WebSearchTool.builder().type(WebSearchTool.Type.WEB_SEARCH_PREVIEW).build() + private val COMPUTER_TOOL = + ComputerTool.builder() + .displayWidth(LONG) + .displayHeight(LONG) + .environment(ComputerTool.Environment.LINUX) + .build() + private val TOOL = Tool.ofFunction(FUNCTION_TOOL) + + private val HEADERS = Headers.builder().build() + private val QUERY_PARAMS = QueryParams.builder().build() + + // The list order follows the declaration order in `ResponseCreateParams.Builder` for + // easier maintenance. + @JvmStatic + private fun builderDelegationTestCases() = + listOf( + // The `body(...)` function is deliberately not supported: too messy. + DelegationWriteTestCase("input", PARAMS_INPUT), + DelegationWriteTestCase("input", JSON_FIELD), + DelegationWriteTestCase("input", STRING), + DelegationWriteTestCase("inputOfResponse", LIST), + DelegationWriteTestCase("model", RESPONSES_MODEL), + DelegationWriteTestCase("model", JSON_FIELD), + DelegationWriteTestCase("model", STRING), + DelegationWriteTestCase("model", CHAT_MODEL), + DelegationWriteTestCase("model", RESPONSES_ONLY_MODEL), + DelegationWriteTestCase("include", LIST), + DelegationWriteTestCase("include", OPTIONAL), + DelegationWriteTestCase("include", JSON_FIELD), + DelegationWriteTestCase("addInclude", INCLUDABLE), + DelegationWriteTestCase("instructions", NULLABLE_STRING), + DelegationWriteTestCase("instructions", OPTIONAL), + DelegationWriteTestCase("instructions", JSON_FIELD), + DelegationWriteTestCase("maxOutputTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxOutputTokens", LONG), + DelegationWriteTestCase("maxOutputTokens", OPTIONAL), + DelegationWriteTestCase("maxOutputTokens", JSON_FIELD), + DelegationWriteTestCase("metadata", METADATA), + DelegationWriteTestCase("metadata", OPTIONAL), + DelegationWriteTestCase("metadata", JSON_FIELD), + DelegationWriteTestCase("parallelToolCalls", NULLABLE_BOOLEAN), + DelegationWriteTestCase("parallelToolCalls", BOOLEAN), + DelegationWriteTestCase("parallelToolCalls", OPTIONAL), + DelegationWriteTestCase("parallelToolCalls", JSON_FIELD), + DelegationWriteTestCase("previousResponseId", NULLABLE_STRING), + DelegationWriteTestCase("previousResponseId", OPTIONAL), + DelegationWriteTestCase("previousResponseId", JSON_FIELD), + DelegationWriteTestCase("reasoning", REASONING), + DelegationWriteTestCase("reasoning", OPTIONAL), + DelegationWriteTestCase("reasoning", JSON_FIELD), + DelegationWriteTestCase("serviceTier", SERVICE_TIER), + DelegationWriteTestCase("serviceTier", OPTIONAL), + DelegationWriteTestCase("serviceTier", JSON_FIELD), + DelegationWriteTestCase("store", NULLABLE_BOOLEAN), + DelegationWriteTestCase("store", BOOLEAN), + DelegationWriteTestCase("store", OPTIONAL), + DelegationWriteTestCase("store", JSON_FIELD), + DelegationWriteTestCase("temperature", NULLABLE_DOUBLE), + DelegationWriteTestCase("temperature", DOUBLE), + DelegationWriteTestCase("temperature", OPTIONAL), + DelegationWriteTestCase("temperature", JSON_FIELD), + // `text()` is a special case and has its own unit tests. + DelegationWriteTestCase("toolChoice", TOOL_CHOICE), + DelegationWriteTestCase("toolChoice", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTIONS), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_TYPES), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_FUNCTION), + DelegationWriteTestCase("tools", LIST), + DelegationWriteTestCase("tools", JSON_FIELD), + DelegationWriteTestCase("addTool", TOOL), + DelegationWriteTestCase("addTool", FILE_SEARCH_TOOL), + DelegationWriteTestCase("addFileSearchTool", LIST), + DelegationWriteTestCase("addTool", FUNCTION_TOOL), + DelegationWriteTestCase("addTool", WEB_SEARCH_TOOL), + DelegationWriteTestCase("addTool", COMPUTER_TOOL), + DelegationWriteTestCase("topP", NULLABLE_DOUBLE), + DelegationWriteTestCase("topP", DOUBLE), + DelegationWriteTestCase("topP", OPTIONAL), + DelegationWriteTestCase("topP", JSON_FIELD), + DelegationWriteTestCase("truncation", NULLABLE), + DelegationWriteTestCase("truncation", OPTIONAL), + DelegationWriteTestCase("truncation", JSON_FIELD), + DelegationWriteTestCase("user", STRING), + DelegationWriteTestCase("user", JSON_FIELD), + DelegationWriteTestCase("additionalBodyProperties", MAP), + DelegationWriteTestCase("putAdditionalBodyProperty", STRING, JSON_VALUE), + DelegationWriteTestCase("putAllAdditionalBodyProperties", MAP), + DelegationWriteTestCase("removeAdditionalBodyProperty", STRING), + DelegationWriteTestCase("removeAllAdditionalBodyProperties", SET), + DelegationWriteTestCase("additionalHeaders", HEADERS), + DelegationWriteTestCase("additionalHeaders", MAP), + DelegationWriteTestCase("putAdditionalHeader", STRING, STRING), + DelegationWriteTestCase("putAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("putAllAdditionalHeaders", MAP), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("replaceAllAdditionalHeaders", MAP), + DelegationWriteTestCase("removeAdditionalHeaders", STRING), + DelegationWriteTestCase("removeAllAdditionalHeaders", SET), + DelegationWriteTestCase("additionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("additionalQueryParams", MAP), + DelegationWriteTestCase("putAdditionalQueryParam", STRING, STRING), + DelegationWriteTestCase("putAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("putAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("removeAdditionalQueryParams", STRING), + DelegationWriteTestCase("removeAllAdditionalQueryParams", SET), + ) + } + + // New instances of the `mockBuilderDelegate` and `builderDelegator` are required for each test + // case (each test case runs in its own instance of the test class). + private val mockBuilderDelegate: ResponseCreateParams.Builder = + mock(ResponseCreateParams.Builder::class.java) + private val builderDelegator = + StructuredResponseCreateParams.builder().inject(mockBuilderDelegate) + + @Test + fun allBuilderDelegateFunctionsExistInDelegator() { + // The delegator class does not implement the various `text` functions or the `body` + // function of the delegate class. + checkAllDelegation(mockBuilderDelegate::class, builderDelegator::class, "body", "text") + } + + @Test + fun allBuilderDelegatorFunctionsExistInDelegate() { + // The delegator implements a different `text` function from those overloads in the delegate + // class. + checkAllDelegation(builderDelegator::class, mockBuilderDelegate::class, "text") + } + + @Test + fun allBuilderDelegatorFunctionsAreTested() { + checkAllDelegatorWriteFunctionsAreTested( + builderDelegator::class, + builderDelegationTestCases(), + exceptionalTestedFns = setOf("text"), + nonDelegatingFns = setOf("build", "wrap", "inject"), + ) + } + + @ParameterizedTest + @MethodSource("builderDelegationTestCases") + fun `delegation of Builder write functions`(testCase: DelegationWriteTestCase) { + checkOneDelegationWrite(builderDelegator, mockBuilderDelegate, testCase) + } + + @Test + fun `delegation of text`() { + // Special unit test case as the delegator method signature does not match that of the + // delegate method. + val delegatorTestCase = DelegationWriteTestCase("text", X::class.java) + val delegatorMethod = findDelegationMethod(builderDelegator, delegatorTestCase) + val mockDelegateTestCase = + DelegationWriteTestCase("text", textConfigFromClass(X::class.java)) + val mockDelegateMethod = findDelegationMethod(mockBuilderDelegate, mockDelegateTestCase) + + delegatorMethod.invoke(builderDelegator, delegatorTestCase.inputValues[0]) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockBuilderDelegate, times(1)).apply { + mockDelegateMethod.invoke(mockBuilderDelegate, mockDelegateTestCase.inputValues[0]) + } + verifyNoMoreInteractions(mockBuilderDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputItemTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputItemTest.kt new file mode 100644 index 00000000..2ba6fa03 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputItemTest.kt @@ -0,0 +1,204 @@ +package com.openai.models.responses + +import com.openai.core.DelegationReadTestCase +import com.openai.core.LIST +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponseOutputItem] class (delegator) and its delegation of most + * functions to a wrapped [ResponseOutputItem] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseOutputItemTest { + companion object { + private val FILE_SEARCH_TOOL_CALL = + ResponseFileSearchToolCall.builder() + .id(STRING) + .queries(LIST) + .status(ResponseFileSearchToolCall.Status.COMPLETED) + .build() + private val FUNCTION_TOOL_CALL = + ResponseFunctionToolCall.builder().arguments(STRING).callId(STRING).name(STRING).build() + private val FUNCTION_WEB_SEARCH = + ResponseFunctionWebSearch.builder() + .id(STRING) + .status(ResponseFunctionWebSearch.Status.COMPLETED) + .build() + private val COMPUTER_TOOL_CALL = + ResponseComputerToolCall.builder() + .id(STRING) + .action(ResponseComputerToolCall.Action.ofWait()) + .callId(STRING) + .pendingSafetyChecks(listOf()) + .status(ResponseComputerToolCall.Status.COMPLETED) + .type(ResponseComputerToolCall.Type.COMPUTER_CALL) + .build() + private val REASONING_ITEM = + ResponseReasoningItem.builder().id(STRING).summary(listOf()).build() + private val MESSAGE = + ResponseOutputMessage.builder() + .id(STRING) + .content(listOf()) + .status(ResponseOutputMessage.Status.COMPLETED) + .build() + + // The list order follows the declaration order in `ResponseOutputItem` for easier + // maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + // `message()` is a special case and has its own test function. + DelegationReadTestCase("fileSearchCall", OPTIONAL), + DelegationReadTestCase("functionCall", OPTIONAL), + DelegationReadTestCase("webSearchCall", OPTIONAL), + DelegationReadTestCase("computerCall", OPTIONAL), + DelegationReadTestCase("reasoning", OPTIONAL), + // `isMessage()` is a special case and has its own test function. + // For the Boolean functions, call each in turn with both `true` and `false` to + // ensure that a return value is not hard-coded. + DelegationReadTestCase("isFileSearchCall", true), + DelegationReadTestCase("isFileSearchCall", false), + DelegationReadTestCase("isFunctionCall", true), + DelegationReadTestCase("isFunctionCall", false), + DelegationReadTestCase("isWebSearchCall", true), + DelegationReadTestCase("isWebSearchCall", false), + DelegationReadTestCase("isComputerCall", true), + DelegationReadTestCase("isComputerCall", false), + DelegationReadTestCase("isReasoning", true), + DelegationReadTestCase("isReasoning", false), + // `asMessage()` is a special case and has its own test function. + DelegationReadTestCase("asFileSearchCall", FILE_SEARCH_TOOL_CALL), + DelegationReadTestCase("asFunctionCall", FUNCTION_TOOL_CALL), + DelegationReadTestCase("asWebSearchCall", FUNCTION_WEB_SEARCH), + DelegationReadTestCase("asComputerCall", COMPUTER_TOOL_CALL), + DelegationReadTestCase("asReasoning", REASONING_ITEM), + DelegationReadTestCase("_json", OPTIONAL), + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: ResponseOutputItem = mock(ResponseOutputItem::class.java) + private val delegator = StructuredResponseOutputItem(X::class.java, mockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + // `toBuilder()` is deliberately not implemented. `accept()` has a different signature. + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder", "accept") + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + // `accept()` has a different signature. + checkAllDelegation(delegator::class, mockDelegate::class, "accept") + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = + setOf("message", "asMessage", "isMessage", "validate", "isValid", "accept"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @Test + fun `delegation of message`() { + // Input and output are different types, so this test is an exceptional case. + // The delegator's `message()` delegates to the delegate's `message()` indirectly via the + // delegator's `message` field initializer. + val input = Optional.of(MESSAGE) + `when`(mockDelegate.message()).thenReturn(input) + val output = delegator.message() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.get().rawMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of asMessage`() { + // Delegation function names do not match, so this test is an exceptional case. + // The delegator's `asMessage()` delegates to the delegate's `message()` (without the "as") + // indirectly via the delegator's `message` field initializer. + val input = Optional.of(MESSAGE) + `when`(mockDelegate.message()).thenReturn(input) + val output = delegator.asMessage() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.rawMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of isMessage`() { + // Delegation function names do not match, so this test is an exceptional case. + // The delegator's `isMessage()` delegates to the delegate's `message()` (without the "is") + // indirectly via the delegator's `message` field initializer. + val input = Optional.of(MESSAGE) + `when`(mockDelegate.message()).thenReturn(input) + val output = delegator.isMessage() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of validate`() { + `when`(mockDelegate.message()).thenReturn(Optional.of(MESSAGE)) + + delegator.validate() + + // Delegator's `validate()` does not call delegate's `validate()`. `message()` is called + // indirectly via the `message` field initializer. + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of isValid`() { + // `isValid` calls `validate()`, so the test is similar to that for `validate()`. + `when`(mockDelegate.message()).thenReturn(Optional.of(MESSAGE)) + + delegator.isValid() + + verify(mockDelegate, times(1)).message() + verifyNoMoreInteractions(mockDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt new file mode 100644 index 00000000..2b093cf5 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseOutputMessageTest.kt @@ -0,0 +1,285 @@ +package com.openai.models.responses + +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.JsonField +import com.openai.core.MAP +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import com.openai.errors.OpenAIInvalidDataException +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponseOutputMessage] class (delegator) and its delegation of most + * functions to a wrapped [ResponseOutputMessage] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseOutputMessageTest { + companion object { + private val MESSAGE_STATUS = ResponseOutputMessage.Status.COMPLETED + private val OUTPUT_TEXT = + ResponseOutputText.builder().annotations(listOf()).text(STRING).build() + private val OUTPUT_REFUSAL = ResponseOutputRefusal.builder().refusal(STRING).build() + private val CONTENT = ResponseOutputMessage.Content.ofOutputText(OUTPUT_TEXT) + + // The list order follows the declaration order in `ResponseOutputMessage` for easier + // maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + DelegationReadTestCase("id", STRING), + // `content()` is a special case and has its own test function. + DelegationReadTestCase("_role", JSON_VALUE), + DelegationReadTestCase("status", MESSAGE_STATUS), + DelegationReadTestCase("_type", JSON_VALUE), + DelegationReadTestCase("_id", JSON_FIELD), + // `_content()` is a special case and has its own test function. + DelegationReadTestCase("_status", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", MAP), + ) + + // The list order follows the declaration order in `ResponseOutputMessage.Content` for + // easier maintenance. + @JvmStatic + private fun contentDelegationTestCases() = + listOf( + // `outputText()` is a special case and has its own test function. + DelegationReadTestCase("refusal", OPTIONAL), + // For the Boolean functions, pass both `true` and `false` to ensure that one value + // is not hard-coded. + DelegationReadTestCase("isOutputText", true), + DelegationReadTestCase("isOutputText", false), + DelegationReadTestCase("isRefusal", true), + DelegationReadTestCase("isRefusal", false), + // `asOutputText()` is a special case and has its own test function. + DelegationReadTestCase("asRefusal", OUTPUT_REFUSAL), + DelegationReadTestCase("_json", OPTIONAL), + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: ResponseOutputMessage = mock(ResponseOutputMessage::class.java) + private val delegator = StructuredResponseOutputMessage(X::class.java, mockDelegate) + + private val contentMockDelegate: ResponseOutputMessage.Content = + mock(ResponseOutputMessage.Content::class.java) + private val contentDelegator = + StructuredResponseOutputMessage.Content(X::class.java, contentMockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder") + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + checkAllDelegation(delegator::class, mockDelegate::class) + } + + @Test + fun allContentDelegateFunctionsExistInDelegator() { + // The `Content.accept()` function in the delegator takes a different type than that in the + // delegate, so there is no delegation from the former to the latter. `Content.toBuilder` is + // deliberately not implemented. + checkAllDelegation( + contentMockDelegate::class, + contentDelegator::class, + "toBuilder", + "accept", + ) + } + + @Test + fun allContentDelegatorFunctionsExistInDelegate() { + // The `Content.accept()` function in the delegator takes a different type than that in the + // delegate, so there is no delegation from the former to the latter. + checkAllDelegation(contentDelegator::class, contentMockDelegate::class, "accept") + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("content", "_content", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @Test + fun allContentDelegatorFunctionsAreTested() { + checkAllDelegatorReadFunctionsAreTested( + contentDelegator::class, + contentDelegationTestCases(), + exceptionalTestedFns = + setOf("outputText", "asOutputText", "validate", "isValid", "accept"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @ParameterizedTest + @MethodSource("contentDelegationTestCases") + fun `delegation of Content functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(contentDelegator, contentMockDelegate, testCase) + } + + @Test + fun `delegation of content`() { + // Input and output are different types, so this test is an exceptional case. + // `content()` (without an underscore) delegates to `_content()` (with an underscore) + // indirectly via the `content` field initializer. + val input = JsonField.of(listOf(CONTENT)) + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator.content() // Without an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output[0].rawContent).isEqualTo(CONTENT) + } + + @Test + fun `delegation of _content`() { + // Input and output are different types, so this test is an exceptional case. + // `_content()` (with an underscore) delegates to `_content()` (with an underscore) + // indirectly via the `content` field initializer. + val input = JsonField.of(listOf(CONTENT)) + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator._content() // With an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.getRequired("content")[0].rawContent).isEqualTo(CONTENT) + } + + @Test + fun `delegation of validate`() { + `when`(mockDelegate._content()).thenReturn(JsonField.of(listOf(CONTENT))) + + delegator.validate() + + // Delegator's `validate()` calls delegate's `validate()`. `_content` is called indirectly + // via the `content` field initializer. + verify(mockDelegate, times(1))._content() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of isValid`() { + // `isValid` calls `validate()`, so the test is similar to that for `validate()`. + `when`(mockDelegate._content()).thenReturn(JsonField.of(listOf(CONTENT))) + + delegator.isValid() + + verify(mockDelegate, times(1))._content() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of Content outputText`() { + // Input and output are different types, so this test is an exceptional case. The + // delegator's `outputText()` delegates to the delegate's `outputText()` indirectly via the + // `outputText` field initializer. + val input = + Optional.of( + ResponseOutputText.builder() + .annotations(listOf()) + .text("{\"s\" : \"hello\"}") + .build() + ) + `when`(contentMockDelegate.outputText()).thenReturn(input) + val output = contentDelegator.outputText() + + verify(contentMockDelegate, times(1)).outputText() + verifyNoMoreInteractions(contentMockDelegate) + + assertThat(output).isEqualTo(Optional.of(X("hello"))) + } + + @Test + fun `delegation of Content asOutputText`() { + // Input and output are different types, so this test is an exceptional case. The + // delegator's `asOutputText()` delegates to the delegate's `outputText()` indirectly via + // the + // `outputText` field initializer. + val input = + Optional.of( + ResponseOutputText.builder() + .annotations(listOf()) + .text("{\"s\" : \"hello\"}") + .build() + ) + `when`(contentMockDelegate.outputText()).thenReturn(input) + val output = contentDelegator.asOutputText() + + verify(contentMockDelegate, times(1)).outputText() + verifyNoMoreInteractions(contentMockDelegate) + + assertThat(output).isEqualTo(X("hello")) + } + + @Test + fun `delegation of Content asOutputText missing`() { + val input = Optional.ofNullable(null) + `when`(contentMockDelegate.outputText()).thenReturn(input) + + assertThatThrownBy { contentDelegator.asOutputText() } + .isInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("`outputText` is not present") + + verify(contentMockDelegate, times(1)).outputText() + verifyNoMoreInteractions(contentMockDelegate) + } + + @Test + fun `delegation of Content validate`() { + // No values or passed and only `this` is returned. + contentDelegator.validate() + + verify(contentMockDelegate, times(1)).validate() + verifyNoMoreInteractions(contentMockDelegate) + } + + @Test + fun `delegation of Content isValid`() { + contentDelegator.isValid() + + // `isValid()` calls `validate`, which then calls the mock delegate. + verify(contentMockDelegate, times(1)).validate() + verifyNoMoreInteractions(contentMockDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt new file mode 100644 index 00000000..9a1b76cd --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseTest.kt @@ -0,0 +1,197 @@ +package com.openai.models.responses + +import com.openai.core.BOOLEAN +import com.openai.core.DOUBLE +import com.openai.core.DelegationReadTestCase +import com.openai.core.JSON_FIELD +import com.openai.core.JSON_VALUE +import com.openai.core.JsonField +import com.openai.core.LIST +import com.openai.core.MAP +import com.openai.core.OPTIONAL +import com.openai.core.STRING +import com.openai.core.X +import com.openai.core.checkAllDelegation +import com.openai.core.checkAllDelegatorReadFunctionsAreTested +import com.openai.core.checkOneDelegationRead +import com.openai.models.ChatModel +import com.openai.models.ResponsesModel +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredResponse] class (delegator) and its delegation of most functions to + * a wrapped [Response] (delegate). The tests include confirmation of the following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredResponseTest { + companion object { + private val RESPONSES_MODEL = ResponsesModel.ofChat(ChatModel.GPT_4O) + + private val TOOL_CHOICE = + Response.ToolChoice.ofFunction(ToolChoiceFunction.builder().name(STRING).build()) + // A reasoning item is probably the simplest one to create. + private val OUTPUT_ITEM = + ResponseOutputItem.ofReasoning( + ResponseReasoningItem.builder() + .id(STRING) + .summary(listOf(ResponseReasoningItem.Summary.builder().text(STRING).build())) + .build() + ) + + // The list order follows the declaration order in `Response` for easier maintenance. + @JvmStatic + private fun delegationTestCases() = + listOf( + DelegationReadTestCase("id", STRING), + DelegationReadTestCase("createdAt", DOUBLE), + DelegationReadTestCase("error", OPTIONAL), + DelegationReadTestCase("incompleteDetails", OPTIONAL), + DelegationReadTestCase("instructions", OPTIONAL), + DelegationReadTestCase("metadata", OPTIONAL), + DelegationReadTestCase("model", RESPONSES_MODEL), + DelegationReadTestCase("_object_", JSON_VALUE), + // `output()` is a special case and has its own test function. + DelegationReadTestCase("parallelToolCalls", BOOLEAN), + DelegationReadTestCase("temperature", OPTIONAL), + DelegationReadTestCase("toolChoice", TOOL_CHOICE), + DelegationReadTestCase("tools", LIST), + DelegationReadTestCase("topP", OPTIONAL), + DelegationReadTestCase("maxOutputTokens", OPTIONAL), + DelegationReadTestCase("previousResponseId", OPTIONAL), + DelegationReadTestCase("reasoning", OPTIONAL), + DelegationReadTestCase("serviceTier", OPTIONAL), + DelegationReadTestCase("status", OPTIONAL), + DelegationReadTestCase("text", OPTIONAL), + DelegationReadTestCase("truncation", OPTIONAL), + DelegationReadTestCase("usage", OPTIONAL), + DelegationReadTestCase("user", OPTIONAL), + DelegationReadTestCase("_id", JSON_FIELD), + DelegationReadTestCase("_createdAt", JSON_FIELD), + DelegationReadTestCase("_error", JSON_FIELD), + DelegationReadTestCase("_incompleteDetails", JSON_FIELD), + DelegationReadTestCase("_instructions", JSON_FIELD), + DelegationReadTestCase("_metadata", JSON_FIELD), + DelegationReadTestCase("_model", JSON_FIELD), + // `_output()` is a special case and has its own test function. + DelegationReadTestCase("_parallelToolCalls", JSON_FIELD), + DelegationReadTestCase("_temperature", JSON_FIELD), + DelegationReadTestCase("_toolChoice", JSON_FIELD), + DelegationReadTestCase("_tools", JSON_FIELD), + DelegationReadTestCase("_topP", JSON_FIELD), + DelegationReadTestCase("_maxOutputTokens", JSON_FIELD), + DelegationReadTestCase("_previousResponseId", JSON_FIELD), + DelegationReadTestCase("_reasoning", JSON_FIELD), + DelegationReadTestCase("_serviceTier", JSON_FIELD), + DelegationReadTestCase("_status", JSON_FIELD), + DelegationReadTestCase("_text", JSON_FIELD), + DelegationReadTestCase("_truncation", JSON_FIELD), + DelegationReadTestCase("_usage", JSON_FIELD), + DelegationReadTestCase("_user", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", MAP), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + private val mockDelegate: Response = mock(Response::class.java) + private val delegator = StructuredResponse(X::class.java, mockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + checkAllDelegation(mockDelegate::class, delegator::class, "toBuilder") + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + checkAllDelegation(delegator::class, mockDelegate::class) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. A few delegator functions do not delegate, so + // no test function is necessary. + checkAllDelegatorReadFunctionsAreTested( + delegator::class, + delegationTestCases(), + exceptionalTestedFns = setOf("output", "_output", "validate", "isValid"), + nonDelegatingFns = setOf("equals", "hashCode", "toString"), + ) + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @Test + fun `delegation of output`() { + // Input and output are different types, so this test is an exceptional case. + // `output()` (without an underscore) delegates to `_output()` (with an underscore) + // indirectly via the `output` field initializer. + val input = JsonField.of(listOf(OUTPUT_ITEM)) + `when`(mockDelegate._output()).thenReturn(input) + val output = delegator.output() // Without an underscore. + + verify(mockDelegate, times(1))._output() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output[0].rawOutputItem).isEqualTo(OUTPUT_ITEM) + } + + @Test + fun `delegation of _output`() { + // Input and output are different types, so this test is an exceptional case. + // `_output()` delegates to `_output()` indirectly via the `output` field initializer. + val input = JsonField.of(listOf(OUTPUT_ITEM)) + `when`(mockDelegate._output()).thenReturn(input) + val output = delegator._output() // With an underscore. + + verify(mockDelegate, times(1))._output() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.getRequired("_output")[0].rawOutputItem).isEqualTo(OUTPUT_ITEM) + } + + @Test + fun `delegation of validate`() { + `when`(mockDelegate._output()).thenReturn(JsonField.of(listOf(OUTPUT_ITEM))) + + delegator.validate() + + // Delegator's `validate()` calls the delegate's `validate()`. Delegate's `_output()` is + // called indirectly via the `output` field initializer. + verify(mockDelegate, times(1))._output() // Indirect + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + } + + @Test + fun `delegation of isValid`() { + // `isValid()` calls `validate()` which delegates to `validate()`, so the test is + // more-or-less the same as for `validate()`. + `when`(mockDelegate._output()).thenReturn(JsonField.of(listOf(OUTPUT_ITEM))) + + delegator.isValid() + + verify(mockDelegate, times(1))._output() // Indirect + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + } +} diff --git a/openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java b/openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java new file mode 100644 index 00000000..9ef89183 --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/ResponsesStructuredOutputsExample.java @@ -0,0 +1,73 @@ +package com.openai.example; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import java.util.List; + +public final class ResponsesStructuredOutputsExample { + + public static class Person { + @JsonPropertyDescription("The first name and surname of the person.") + public String name; + + public int birthYear; + + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; + + @Override + public String toString() { + return name + " (" + birthYear + '-' + deathYear + ')'; + } + } + + public static class Book { + public String title; + + public Person author; + + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; + + public String genre; + + @JsonIgnore + public String isbn; + + @Override + public String toString() { + return '"' + title + "\" (" + publicationYear + ") [" + genre + "] by " + author; + } + } + + public static class BookList { + public List books; + } + + private ResponsesStructuredOutputsExample() {} + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + + StructuredResponseCreateParams createParams = ResponseCreateParams.builder() + .input("List some famous late twentieth century novels.") + .text(BookList.class) + .model(ChatModel.GPT_4O) + .build(); + + client.responses().create(createParams).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(" - " + book)); + } +} diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java index 9d3ce9da..b4af9999 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsExample.java @@ -1,15 +1,54 @@ package com.openai.example; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; -import com.openai.core.JsonValue; import com.openai.models.ChatModel; -import com.openai.models.ResponseFormatJsonSchema; -import com.openai.models.ResponseFormatJsonSchema.JsonSchema; import com.openai.models.chat.completions.ChatCompletionCreateParams; -import java.util.Map; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; +import java.util.List; public final class StructuredOutputsExample { + + public static class Person { + @JsonPropertyDescription("The first name and surname of the person.") + public String name; + + public int birthYear; + + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; + + @Override + public String toString() { + return name + " (" + birthYear + '-' + deathYear + ')'; + } + } + + public static class Book { + public String title; + + public Person author; + + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; + + public String genre; + + @JsonIgnore + public String isbn; + + @Override + public String toString() { + return '"' + title + "\" (" + publicationYear + ") [" + genre + "] by " + author; + } + } + + public static class BookList { + public List books; + } + private StructuredOutputsExample() {} public static void main(String[] args) { @@ -18,26 +57,16 @@ public static void main(String[] args) { // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables OpenAIClient client = OpenAIOkHttpClient.fromEnv(); - // TODO: Update this once we support extracting JSON schemas from Java classes - JsonSchema.Schema schema = JsonSchema.Schema.builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty( - "properties", JsonValue.from(Map.of("employees", Map.of("items", Map.of("type", "string"))))) - .build(); - ChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() .model(ChatModel.GPT_4O_MINI) .maxCompletionTokens(2048) - .responseFormat(ResponseFormatJsonSchema.builder() - .jsonSchema(JsonSchema.builder() - .name("employee-list") - .schema(schema) - .build()) - .build()) - .addUserMessage("Who works at OpenAI?") + .responseFormat(BookList.class) + .addUserMessage("List some famous late twentieth century novels.") .build(); client.chat().completions().create(createParams).choices().stream() .flatMap(choice -> choice.message().content().stream()) - .forEach(System.out::println); + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(" - " + book)); } } diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsAsyncExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawAsyncExample.java similarity index 91% rename from openai-java-example/src/main/java/com/openai/example/StructuredOutputsAsyncExample.java rename to openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawAsyncExample.java index a645f6cb..f726e0cf 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsAsyncExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawAsyncExample.java @@ -9,8 +9,8 @@ import com.openai.models.chat.completions.ChatCompletionCreateParams; import java.util.Map; -public final class StructuredOutputsAsyncExample { - private StructuredOutputsAsyncExample() {} +public final class StructuredOutputsRawAsyncExample { + private StructuredOutputsRawAsyncExample() {} public static void main(String[] args) { // Configures using one of: @@ -18,7 +18,6 @@ public static void main(String[] args) { // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables OpenAIClientAsync client = OpenAIOkHttpClientAsync.fromEnv(); - // TODO: Update this once we support extracting JSON schemas from Java classes JsonSchema.Schema schema = JsonSchema.Schema.builder() .putAdditionalProperty("type", JsonValue.from("object")) .putAdditionalProperty( diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawExample.java new file mode 100644 index 00000000..3b9ff03a --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsRawExample.java @@ -0,0 +1,42 @@ +package com.openai.example; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonValue; +import com.openai.models.ChatModel; +import com.openai.models.ResponseFormatJsonSchema; +import com.openai.models.ResponseFormatJsonSchema.JsonSchema; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import java.util.Map; + +public final class StructuredOutputsRawExample { + private StructuredOutputsRawExample() {} + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + + JsonSchema.Schema schema = JsonSchema.Schema.builder() + .putAdditionalProperty("type", JsonValue.from("object")) + .putAdditionalProperty( + "properties", JsonValue.from(Map.of("employees", Map.of("items", Map.of("type", "string"))))) + .build(); + ChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .maxCompletionTokens(2048) + .responseFormat(ResponseFormatJsonSchema.builder() + .jsonSchema(JsonSchema.builder() + .name("employee-list") + .schema(schema) + .build()) + .build()) + .addUserMessage("Who works at OpenAI?") + .build(); + + client.chat().completions().create(createParams).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .forEach(System.out::println); + } +}