diff --git a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt index 8192bc190..3a786e287 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt @@ -12,7 +12,6 @@ import com.openai.core.JsonMissing import com.openai.core.JsonValue import com.openai.core.checkKnown import com.openai.core.checkRequired -import com.openai.core.toImmutable import com.openai.errors.OpenAIInvalidDataException import java.util.Collections import java.util.Objects @@ -21,7 +20,7 @@ import kotlin.jvm.optionals.getOrNull /** Represents an embedding vector returned by embedding endpoint. */ class Embedding private constructor( - private val embedding: JsonField>, + private val embedding: JsonField, private val index: JsonField, private val object_: JsonValue, private val additionalProperties: MutableMap, @@ -31,7 +30,7 @@ private constructor( private constructor( @JsonProperty("embedding") @ExcludeMissing - embedding: JsonField> = JsonMissing.of(), + embedding: JsonField = JsonMissing.of(), @JsonProperty("index") @ExcludeMissing index: JsonField = JsonMissing.of(), @JsonProperty("object") @ExcludeMissing object_: JsonValue = JsonMissing.of(), ) : this(embedding, index, object_, mutableMapOf()) @@ -43,7 +42,16 @@ private constructor( * @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is * unexpectedly missing or null (e.g. if the server responded with an unexpected value). */ - fun embedding(): List = embedding.getRequired("embedding") + fun embedding(): List = embeddingValue().asFloats() + + /** + * The embedding data in its original format (either float list or base64 string). This method + * provides efficient access to the embedding data without unnecessary conversions. 
+ * + * @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is + * unexpectedly missing or null (e.g. if the server responded with an unexpected value). + */ + fun embeddingValue(): EmbeddingValue = embedding.getRequired("embedding") /** * The index of the embedding in the list of embeddings. @@ -71,7 +79,16 @@ private constructor( * * Unlike [embedding], this method doesn't throw if the JSON field has an unexpected type. */ - @JsonProperty("embedding") @ExcludeMissing fun _embedding(): JsonField> = embedding + fun _embedding(): JsonField> = embedding.map { it.asFloats() } + + /** + * Returns the raw JSON value of [embedding]. + * + * Unlike [embeddingValue], this method doesn't throw if the JSON field has an unexpected type. + */ + @JsonProperty("embedding") + @ExcludeMissing + fun _embeddingValue(): JsonField = embedding /** * Returns the raw JSON value of [index]. @@ -109,25 +126,38 @@ private constructor( /** A builder for [Embedding]. */ class Builder internal constructor() { - private var embedding: JsonField>? = null + private var embeddingFloats: MutableList? = null + private var embedding: JsonField? = null private var index: JsonField? = null private var object_: JsonValue = JsonValue.from("embedding") private var additionalProperties: MutableMap = mutableMapOf() @JvmSynthetic internal fun from(embedding: Embedding) = apply { - this.embedding = embedding.embedding.map { it.toMutableList() } + this.embedding = embedding.embedding index = embedding.index object_ = embedding.object_ additionalProperties = embedding.additionalProperties.toMutableMap() } + /** + * The embedding vector. The length of vector depends on the model as listed in the + * [embedding guide](https://platform.openai.com/docs/guides/embeddings). + */ + fun embedding(embedding: EmbeddingValue) = embedding(JsonField.of(embedding)) + /** * The embedding vector, which is a list of floats. 
The length of vector depends on the * model as listed in the * [embedding guide](https://platform.openai.com/docs/guides/embeddings). */ - fun embedding(embedding: List) = embedding(JsonField.of(embedding)) + fun embedding(floats: List) = embedding(EmbeddingValue.ofFloats(floats)) + + /** + * The embedding vector, which is a base64 string. The length of vector depends on the model + * as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings). + */ + fun embedding(base64: String) = embedding(EmbeddingValue.ofBase64(base64)) /** * Sets [Builder.embedding] to an arbitrary JSON value. @@ -136,8 +166,9 @@ private constructor( * instead. This method is primarily for setting the field to an undocumented or not yet * supported value. */ - fun embedding(embedding: JsonField>) = apply { - this.embedding = embedding.map { it.toMutableList() } + fun embedding(embedding: JsonField) = apply { + embeddingFloats = null + this.embedding = embedding } /** @@ -146,10 +177,12 @@ private constructor( * @throws IllegalStateException if the field was previously set to a non-list. */ fun addEmbedding(embedding: Float) = apply { - this.embedding = - (this.embedding ?: JsonField.of(mutableListOf())).also { - checkKnown("embedding", it).add(embedding) - } + embeddingFloats = + (this.embedding?.let { checkKnown("embedding", it) }?.asFloats()?.toMutableList() + ?: embeddingFloats + ?: mutableListOf()) + .apply { add(embedding) } + this.embedding = null } /** The index of the embedding in the list of embeddings. 
*/ @@ -211,7 +244,10 @@ private constructor( */ fun build(): Embedding = Embedding( - checkRequired("embedding", embedding).map { it.toImmutable() }, + checkRequired( + "embedding", + embedding ?: embeddingFloats?.let { JsonField.of(EmbeddingValue.ofFloats(it)) }, + ), checkRequired("index", index), object_, additionalProperties.toMutableMap(), @@ -225,7 +261,7 @@ private constructor( return@apply } - embedding() + embeddingValue().validate() index() _object_().let { if (it != JsonValue.from("embedding")) { @@ -250,7 +286,7 @@ private constructor( */ @JvmSynthetic internal fun validity(): Int = - (embedding.asKnown().getOrNull()?.size ?: 0) + + (embedding.asKnown().getOrNull()?.validity() ?: 0) + (if (index.asKnown().isPresent) 1 else 0) + object_.let { if (it == JsonValue.from("embedding")) 1 else 0 } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt index 6ea30e456..c9ebb4fee 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt @@ -581,7 +581,8 @@ private constructor( private var input: JsonField? = null private var model: JsonField? 
= null private var dimensions: JsonField = JsonMissing.of() - private var encodingFormat: JsonField = JsonMissing.of() + private var encodingFormat: JsonField = + JsonField.of(EncodingFormat.BASE64) private var user: JsonField = JsonMissing.of() private var additionalProperties: MutableMap = mutableMapOf() diff --git a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingValue.kt b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingValue.kt new file mode 100644 index 000000000..1bd26ded0 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingValue.kt @@ -0,0 +1,225 @@ +// File generated from our OpenAPI spec by Stainless. + +package com.openai.models.embeddings + +import com.fasterxml.jackson.core.JsonGenerator +import com.fasterxml.jackson.core.ObjectCodec +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.SerializerProvider +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.databind.annotation.JsonSerialize +import com.fasterxml.jackson.module.kotlin.jacksonTypeRef +import com.openai.core.BaseDeserializer +import com.openai.core.BaseSerializer +import com.openai.core.JsonValue +import com.openai.core.allMaxBy +import com.openai.core.toImmutable +import com.openai.errors.OpenAIInvalidDataException +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.Base64 +import java.util.Objects +import java.util.Optional + +/** + * Represents embedding data that can be either a list of floats or base64-encoded string. This + * union type allows for efficient handling of both formats. + * + * This class is immutable - all instances are thread-safe and cannot be modified after creation. + */ +@JsonDeserialize(using = EmbeddingValue.Deserializer::class) +@JsonSerialize(using = EmbeddingValue.Serializer::class) +class EmbeddingValue +private constructor( + private val floats: List? 
= null, + private val base64: String? = null, + private val _json: JsonValue? = null, +) { + + fun floats(): Optional<List<Float>> = Optional.ofNullable(floats) + + fun base64(): Optional<String> = Optional.ofNullable(base64) + + /** Returns true if this value contains a list of floats. */ + fun isFloats(): Boolean = floats != null + + /** Returns true if this value contains base64 string data. */ + fun isBase64(): Boolean = base64 != null + + /** + * Returns the embedding data as a list of floats, decoding base64 string data if necessary. + * + * @throws OpenAIInvalidDataException if this value is an unknown variant. + */ + fun asFloats(): List<Float> = + when { + floats != null -> floats + base64 != null -> decodeBase64ToFloats(base64) + else -> throw OpenAIInvalidDataException("Unknown EmbeddingValue: $_json") + } + + /** + * Returns the embedding data as a base64 string, encoding a list of floats if necessary. + * + * @throws OpenAIInvalidDataException if this value is an unknown variant. + */ + fun asBase64(): String = + when { + base64 != null -> base64 + floats != null -> encodeFloatsAsBase64(floats) + else -> throw OpenAIInvalidDataException("Unknown EmbeddingValue: $_json") + } + + fun _json(): JsonValue? = _json + + fun <T> accept(visitor: Visitor<T>): T = + when { + floats != null -> visitor.visitFloats(floats) + base64 != null -> visitor.visitBase64(base64) + else -> visitor.unknown(_json) + } + + fun validate() = apply { + accept( + object : Visitor<Unit> { + override fun visitFloats(floats: List<Float>) {} + + override fun visitBase64(base64: String) {} + } + ) + } + + fun isValid(): Boolean = + try { + validate() + true + } catch (e: OpenAIInvalidDataException) { + false + } + + /** + * Returns a score indicating how many valid values are contained in this object. + * + * Used for best match union deserialization. 
+ */ + @JvmSynthetic + internal fun validity(): Int = + when { + floats != null -> floats.size + base64 != null -> 1 + else -> 0 + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return /* spotless:off */ other is EmbeddingValue && floats == other.floats && base64 == other.base64 /* spotless:on */ + } + + override fun hashCode(): Int = /* spotless:off */ Objects.hash(floats, base64) /* spotless:on */ + + override fun toString(): String = + when { + floats != null -> "EmbeddingValue{floats=$floats}" + base64 != null -> "EmbeddingValue{base64=$base64}" + _json != null -> "EmbeddingValue{_unknown=$_json}" + else -> throw IllegalStateException("Invalid EmbeddingValue") + } + + companion object { + + @JvmStatic fun ofFloats(floats: List<Float>) = EmbeddingValue(floats = floats.toImmutable()) + + @JvmStatic fun ofBase64(base64: String) = EmbeddingValue(base64 = base64) + + /** + * Decodes a base64 string of little-endian 32-bit IEEE 754 floats to a list of floats. + * Throws [OpenAIInvalidDataException] if the decoded byte count isn't a multiple of 4. + */ + private fun decodeBase64ToFloats(base64: String): List<Float> { + val bytes = Base64.getDecoder().decode(base64) + // asFloatBuffer() silently drops trailing bytes, so reject misaligned input up front. + if (bytes.size % Float.SIZE_BYTES != 0) { + throw OpenAIInvalidDataException("Invalid base64 embedding: ${bytes.size} bytes") + } + val floats = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer() + return List(floats.remaining()) { floats.get() } + } + + /** + * Encodes a list of floats to a base64 string. Encodes the floats as an array of 32-bit + * IEEE 754 floats in little-endian format. + */ + private fun encodeFloatsAsBase64(floats: List<Float>): String { + val buffer = ByteBuffer.allocate(floats.size * Float.SIZE_BYTES).order(ByteOrder.LITTLE_ENDIAN) + floats.forEach { buffer.putFloat(it) } + return Base64.getEncoder().encodeToString(buffer.array()) + } + } + + /** + * An interface that defines how to map each variant of [EmbeddingValue] to a value of type [T]. 
+ */ + interface Visitor { + + fun visitFloats(floats: List): T + + fun visitBase64(base64: String): T + + /** + * Maps an unknown variant of [EmbeddingValue] to a value of type [T]. + * + * An instance of [EmbeddingValue] can contain an unknown variant if it was deserialized + * from data that doesn't match any known variant. For example, if the SDK is on an older + * version than the API, then the API may respond with new variants that the SDK is unaware + * of. + * + * @throws OpenAIInvalidDataException in the default implementation. + */ + fun unknown(json: JsonValue?): T { + throw OpenAIInvalidDataException("Unknown EmbeddingValue: $json") + } + } + + internal class Deserializer : BaseDeserializer(EmbeddingValue::class) { + override fun ObjectCodec.deserialize(node: JsonNode): EmbeddingValue { + val json = JsonValue.fromJsonNode(node) + + val bestMatches = + sequenceOf( + tryDeserialize(node, jacksonTypeRef>())?.let { + EmbeddingValue(floats = it, _json = json) + }, + tryDeserialize(node, jacksonTypeRef())?.let { + EmbeddingValue(base64 = it, _json = json) + }, + ) + .filterNotNull() + .allMaxBy { it.validity() } + .toList() + + return when (bestMatches.size) { + 0 -> EmbeddingValue(_json = json) + 1 -> bestMatches.single() + else -> bestMatches.firstOrNull { it.isValid() } ?: bestMatches.first() + } + } + } + + internal class Serializer : BaseSerializer(EmbeddingValue::class) { + override fun serialize( + value: EmbeddingValue, + generator: JsonGenerator, + provider: SerializerProvider, + ) { + when { + value.floats != null -> generator.writeObject(value.floats) + value.base64 != null -> generator.writeObject(value.base64) + value._json != null -> generator.writeObject(value._json) + else -> throw IllegalStateException("Invalid EmbeddingValue") + } + } + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingTest.kt index 393c05cd4..57311de18 
100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingTest.kt @@ -17,6 +17,32 @@ internal class EmbeddingTest { assertThat(embedding.index()).isEqualTo(0L) } + @Test + fun create_setThenAdd() { + val embedding = + Embedding.builder() + .embedding(EmbeddingValue.ofFloats(listOf(1.0f, 2.0f))) + .addEmbedding(3.0f) + .index(0L) + .build() + + assertThat(embedding.embedding()).containsExactly(1.0f, 2.0f, 3.0f) + assertThat(embedding.index()).isEqualTo(0L) + } + + @Test + fun create_addThenSet() { + val embedding = + Embedding.builder() + .addEmbedding(3.0f) + .embedding(EmbeddingValue.ofFloats(listOf(1.0f, 2.0f))) + .index(0L) + .build() + + assertThat(embedding.embedding()).containsExactly(1.0f, 2.0f) + assertThat(embedding.index()).isEqualTo(0L) + } + @Test fun roundtrip() { val jsonMapper = jsonMapper() diff --git a/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingValueTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingValueTest.kt new file mode 100644 index 000000000..663237712 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingValueTest.kt @@ -0,0 +1,35 @@ +package com.openai.models.embeddings + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +internal class EmbeddingValueTest { + + @Test + fun ofFloats() { + val floats = listOf(1.0f, 2.0f, 3.0f, 4.0f) + + val embeddingValue = EmbeddingValue.ofFloats(floats) + + assertThat(embeddingValue.isFloats()).isTrue() + assertThat(embeddingValue.isBase64()).isFalse() + assertThat(embeddingValue.floats()).hasValue(floats) + assertThat(embeddingValue.base64()).isEmpty + assertThat(embeddingValue.asFloats()).isEqualTo(floats) + assertThat(embeddingValue.asBase64()).isEqualTo("AACAPwAAAEAAAEBAAACAQA==") + } + + @Test + fun ofBase64() { + val base64 = "AACAPwAAAEAAAEBAAACAQA==" + + val 
embeddingValue = EmbeddingValue.ofBase64(base64) + + assertThat(embeddingValue.isFloats()).isFalse() + assertThat(embeddingValue.isBase64()).isTrue() + assertThat(embeddingValue.floats()).isEmpty + assertThat(embeddingValue.base64()).hasValue(base64) + assertThat(embeddingValue.asFloats()).containsExactly(1.0f, 2.0f, 3.0f, 4.0f) + assertThat(embeddingValue.asBase64()).isEqualTo(base64) + } +}