Skip to content

feat(client): support base64 embeddings and use as default #519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import com.openai.core.JsonMissing
import com.openai.core.JsonValue
import com.openai.core.checkKnown
import com.openai.core.checkRequired
import com.openai.core.toImmutable
import com.openai.errors.OpenAIInvalidDataException
import java.util.Collections
import java.util.Objects
Expand All @@ -21,7 +20,7 @@ import kotlin.jvm.optionals.getOrNull
/** Represents an embedding vector returned by embedding endpoint. */
class Embedding
private constructor(
private val embedding: JsonField<List<Float>>,
private val embedding: JsonField<EmbeddingValue>,
private val index: JsonField<Long>,
private val object_: JsonValue,
private val additionalProperties: MutableMap<String, JsonValue>,
Expand All @@ -31,7 +30,7 @@ private constructor(
private constructor(
@JsonProperty("embedding")
@ExcludeMissing
embedding: JsonField<List<Float>> = JsonMissing.of(),
embedding: JsonField<EmbeddingValue> = JsonMissing.of(),
@JsonProperty("index") @ExcludeMissing index: JsonField<Long> = JsonMissing.of(),
@JsonProperty("object") @ExcludeMissing object_: JsonValue = JsonMissing.of(),
) : this(embedding, index, object_, mutableMapOf())
Expand All @@ -43,7 +42,16 @@ private constructor(
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is
* unexpectedly missing or null (e.g. if the server responded with an unexpected value).
*/
fun embedding(): List<Float> = embedding.getRequired("embedding")
fun embedding(): List<Float> = embeddingValue().asFloats()

/**
* The embedding data in its original format (either float list or base64 string). This method
* provides efficient access to the embedding data without unnecessary conversions.
*
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is
* unexpectedly missing or null (e.g. if the server responded with an unexpected value).
*/
fun embeddingValue(): EmbeddingValue = embedding.getRequired("embedding")

/**
* The index of the embedding in the list of embeddings.
Expand Down Expand Up @@ -71,7 +79,16 @@ private constructor(
*
* Unlike [embedding], this method doesn't throw if the JSON field has an unexpected type.
*/
@JsonProperty("embedding") @ExcludeMissing fun _embedding(): JsonField<List<Float>> = embedding
fun _embedding(): JsonField<List<Float>> = embedding.map { it.asFloats() }

/**
* Returns the raw JSON value of [embedding].
*
* Unlike [embeddingValue], this method doesn't throw if the JSON field has an unexpected type.
*/
@JsonProperty("embedding")
@ExcludeMissing
fun _embeddingValue(): JsonField<EmbeddingValue> = embedding

/**
* Returns the raw JSON value of [index].
Expand Down Expand Up @@ -109,25 +126,38 @@ private constructor(
/** A builder for [Embedding]. */
class Builder internal constructor() {

private var embedding: JsonField<MutableList<Float>>? = null
private var embeddingFloats: MutableList<Float>? = null
private var embedding: JsonField<EmbeddingValue>? = null
private var index: JsonField<Long>? = null
private var object_: JsonValue = JsonValue.from("embedding")
private var additionalProperties: MutableMap<String, JsonValue> = mutableMapOf()

@JvmSynthetic
internal fun from(embedding: Embedding) = apply {
this.embedding = embedding.embedding.map { it.toMutableList() }
this.embedding = embedding.embedding
index = embedding.index
object_ = embedding.object_
additionalProperties = embedding.additionalProperties.toMutableMap()
}

/**
* The embedding vector. The length of vector depends on the model as listed in the
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
fun embedding(embedding: EmbeddingValue) = embedding(JsonField.of(embedding))

/**
* The embedding vector, which is a list of floats. The length of vector depends on the
* model as listed in the
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
fun embedding(embedding: List<Float>) = embedding(JsonField.of(embedding))
fun embedding(floats: List<Float>) = embedding(EmbeddingValue.ofFloats(floats))

/**
* The embedding vector, which is a base64 string. The length of vector depends on the model
* as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*/
fun embedding(base64: String) = embedding(EmbeddingValue.ofBase64(base64))

/**
* Sets [Builder.embedding] to an arbitrary JSON value.
Expand All @@ -136,8 +166,9 @@ private constructor(
* instead. This method is primarily for setting the field to an undocumented or not yet
* supported value.
*/
fun embedding(embedding: JsonField<List<Float>>) = apply {
this.embedding = embedding.map { it.toMutableList() }
fun embedding(embedding: JsonField<EmbeddingValue>) = apply {
embeddingFloats = null
this.embedding = embedding
}

/**
Expand All @@ -146,10 +177,12 @@ private constructor(
* @throws IllegalStateException if the field was previously set to a non-list.
*/
fun addEmbedding(embedding: Float) = apply {
this.embedding =
(this.embedding ?: JsonField.of(mutableListOf())).also {
checkKnown("embedding", it).add(embedding)
}
embeddingFloats =
(this.embedding?.let { checkKnown("embedding", it) }?.asFloats()?.toMutableList()
?: embeddingFloats
?: mutableListOf())
.apply { add(embedding) }
this.embedding = null
}

/** The index of the embedding in the list of embeddings. */
Expand Down Expand Up @@ -211,7 +244,10 @@ private constructor(
*/
fun build(): Embedding =
Embedding(
checkRequired("embedding", embedding).map { it.toImmutable() },
checkRequired(
"embedding",
embedding ?: embeddingFloats?.let { JsonField.of(EmbeddingValue.ofFloats(it)) },
),
checkRequired("index", index),
object_,
additionalProperties.toMutableMap(),
Expand All @@ -225,7 +261,7 @@ private constructor(
return@apply
}

embedding()
embeddingValue().validate()
index()
_object_().let {
if (it != JsonValue.from("embedding")) {
Expand All @@ -250,7 +286,7 @@ private constructor(
*/
@JvmSynthetic
internal fun validity(): Int =
(embedding.asKnown().getOrNull()?.size ?: 0) +
(embedding.asKnown().getOrNull()?.validity() ?: 0) +
(if (index.asKnown().isPresent) 1 else 0) +
object_.let { if (it == JsonValue.from("embedding")) 1 else 0 }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,8 @@ private constructor(
private var input: JsonField<Input>? = null
private var model: JsonField<EmbeddingModel>? = null
private var dimensions: JsonField<Long> = JsonMissing.of()
private var encodingFormat: JsonField<EncodingFormat> = JsonMissing.of()
private var encodingFormat: JsonField<EncodingFormat> =
JsonField.of(EncodingFormat.BASE64)
private var user: JsonField<String> = JsonMissing.of()
private var additionalProperties: MutableMap<String, JsonValue> = mutableMapOf()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
// File generated from our OpenAPI spec by Stainless.

package com.openai.models.embeddings

import com.fasterxml.jackson.core.JsonGenerator
import com.fasterxml.jackson.core.ObjectCodec
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.SerializerProvider
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.fasterxml.jackson.databind.annotation.JsonSerialize
import com.fasterxml.jackson.module.kotlin.jacksonTypeRef
import com.openai.core.BaseDeserializer
import com.openai.core.BaseSerializer
import com.openai.core.JsonValue
import com.openai.core.allMaxBy
import com.openai.core.toImmutable
import com.openai.errors.OpenAIInvalidDataException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.Base64
import java.util.Objects
import java.util.Optional

/**
 * Represents embedding data that can be either a list of floats or a base64-encoded string. This
 * union type allows for efficient handling of both formats.
 *
 * This class is immutable - all instances are thread-safe and cannot be modified after creation.
 */
@JsonDeserialize(using = EmbeddingValue.Deserializer::class)
@JsonSerialize(using = EmbeddingValue.Serializer::class)
class EmbeddingValue
private constructor(
    private val floats: List<Float>? = null,
    private val base64: String? = null,
    private val _json: JsonValue? = null,
) {

    /** Returns the float list if this value holds one, otherwise an empty [Optional]. */
    fun floats(): Optional<List<Float>> = Optional.ofNullable(floats)

    /** Returns the base64 string if this value holds one, otherwise an empty [Optional]. */
    fun base64(): Optional<String> = Optional.ofNullable(base64)

    /** Returns true if this value contains a list of floats. */
    fun isFloats(): Boolean = floats != null

    /** Returns true if this value contains base64 string data. */
    fun isBase64(): Boolean = base64 != null

    /**
     * Returns the embedding data as a list of floats.
     *
     * If this value represents base64 string data, then it's decoded into floats. The decoding is
     * performed on every call; the result is not cached.
     *
     * @throws IllegalArgumentException if the base64 string is malformed or doesn't decode to a
     *   whole number of 32-bit floats.
     * @throws IllegalStateException if this value holds neither floats nor base64 data.
     */
    fun asFloats(): List<Float> =
        when {
            floats != null -> floats
            base64 != null -> decodeBase64ToFloats(base64)
            else -> throw IllegalStateException("Invalid EmbeddingValue")
        }

    /**
     * Returns the embedding data as a base64 string.
     *
     * If this value represents a list of floats, then it's encoded into a base64 string.
     *
     * @throws IllegalStateException if this value holds neither floats nor base64 data.
     */
    fun asBase64(): String =
        when {
            base64 != null -> base64
            floats != null -> encodeFloatsAsBase64(floats)
            else -> throw IllegalStateException("Invalid EmbeddingValue")
        }

    /**
     * Returns the raw JSON this value was deserialized from, or null if it was created through
     * [ofFloats] or [ofBase64].
     */
    fun _json(): JsonValue? = _json

    /**
     * Dispatches to the [visitor] method matching the variant held by this value.
     *
     * Floats take precedence over base64; if neither is present, [Visitor.unknown] is invoked with
     * the raw JSON.
     */
    fun <T> accept(visitor: Visitor<T>): T =
        when {
            floats != null -> visitor.visitFloats(floats)
            base64 != null -> visitor.visitBase64(base64)
            else -> visitor.unknown(_json)
        }

    /**
     * Checks that this value holds a known variant and returns this instance.
     *
     * The contents of the float list or base64 string are not inspected; only an unknown variant
     * is rejected, through the default [Visitor.unknown] implementation.
     *
     * @throws OpenAIInvalidDataException if this value holds neither floats nor base64 data.
     */
    fun validate() = apply {
        accept(
            object : Visitor<Unit> {
                override fun visitFloats(floats: List<Float>) {}

                override fun visitBase64(base64: String) {}
            }
        )
    }

    /** Returns true if [validate] completes without throwing [OpenAIInvalidDataException]. */
    fun isValid(): Boolean =
        try {
            validate()
            true
        } catch (e: OpenAIInvalidDataException) {
            false
        }

    /**
     * Returns a score indicating how many valid values are contained in this object.
     *
     * Used for best match union deserialization.
     */
    @JvmSynthetic
    internal fun validity(): Int =
        when {
            floats != null -> floats.size
            base64 != null -> 1
            else -> 0
        }

    // Equality and hashing consider only the known variants; the raw JSON is not compared.
    override fun equals(other: Any?): Boolean {
        if (this === other) {
            return true
        }

        return /* spotless:off */ other is EmbeddingValue && floats == other.floats && base64 == other.base64 /* spotless:on */
    }

    override fun hashCode(): Int = /* spotless:off */ Objects.hash(floats, base64) /* spotless:on */

    override fun toString(): String =
        when {
            floats != null -> "EmbeddingValue{floats=$floats}"
            base64 != null -> "EmbeddingValue{base64=$base64}"
            _json != null -> "EmbeddingValue{_unknown=$_json}"
            else -> throw IllegalStateException("Invalid EmbeddingValue")
        }

    companion object {

        /** Creates a value holding an immutable copy of the given [floats]. */
        @JvmStatic fun ofFloats(floats: List<Float>) = EmbeddingValue(floats = floats.toImmutable())

        /** Creates a value holding the given [base64] string as-is, without decoding it. */
        @JvmStatic fun ofBase64(base64: String) = EmbeddingValue(base64 = base64)

        /**
         * Decodes a base64 string to a list of floats. Assumes the base64 string represents an
         * array of 32-bit IEEE 754 floats in little-endian format.
         *
         * @throws IllegalArgumentException if [base64] is not valid base64, or if the decoded byte
         *   count is not a multiple of the size of a float.
         */
        private fun decodeBase64ToFloats(base64: String): List<Float> {
            val bytes = Base64.getDecoder().decode(base64)
            // asFloatBuffer() silently ignores trailing bytes that don't form a whole float, so
            // reject such payloads instead of returning a truncated vector.
            require(bytes.size % Float.SIZE_BYTES == 0) {
                "Base64 embedding must decode to a multiple of ${Float.SIZE_BYTES} bytes, " +
                    "but decoded to ${bytes.size} bytes"
            }
            val floats = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer()
            return buildList(floats.remaining()) {
                while (floats.hasRemaining()) {
                    add(floats.get())
                }
            }
        }

        /**
         * Encodes a list of floats to a base64 string. Encodes the floats as an array of 32-bit
         * IEEE 754 floats in little-endian format.
         */
        private fun encodeFloatsAsBase64(floats: List<Float>): String {
            val buffer =
                ByteBuffer.allocate(floats.size * Float.SIZE_BYTES).order(ByteOrder.LITTLE_ENDIAN)
            floats.forEach { buffer.putFloat(it) }
            return Base64.getEncoder().encodeToString(buffer.array())
        }
    }

    /**
     * An interface that defines how to map each variant of [EmbeddingValue] to a value of type [T].
     */
    interface Visitor<out T> {

        /** Maps the float list variant to a value of type [T]. */
        fun visitFloats(floats: List<Float>): T

        /** Maps the base64 string variant to a value of type [T]. */
        fun visitBase64(base64: String): T

        /**
         * Maps an unknown variant of [EmbeddingValue] to a value of type [T].
         *
         * An instance of [EmbeddingValue] can contain an unknown variant if it was deserialized
         * from data that doesn't match any known variant. For example, if the SDK is on an older
         * version than the API, then the API may respond with new variants that the SDK is unaware
         * of.
         *
         * @throws OpenAIInvalidDataException in the default implementation.
         */
        fun unknown(json: JsonValue?): T {
            throw OpenAIInvalidDataException("Unknown EmbeddingValue: $json")
        }
    }

    /**
     * Deserializes an [EmbeddingValue] by attempting each known variant (float list, then string)
     * and choosing among the successful candidates via `allMaxBy` over [validity]. When no
     * candidate is produced, the result is an unknown variant that carries only the raw JSON.
     */
    internal class Deserializer : BaseDeserializer<EmbeddingValue>(EmbeddingValue::class) {
        override fun ObjectCodec.deserialize(node: JsonNode): EmbeddingValue {
            val json = JsonValue.fromJsonNode(node)

            val bestMatches =
                sequenceOf(
                        tryDeserialize(node, jacksonTypeRef<List<Float>>())?.let {
                            EmbeddingValue(floats = it, _json = json)
                        },
                        tryDeserialize(node, jacksonTypeRef<String>())?.let {
                            EmbeddingValue(base64 = it, _json = json)
                        },
                    )
                    .filterNotNull()
                    .allMaxBy { it.validity() }
                    .toList()

            return when (bestMatches.size) {
                // Nothing matched: keep the raw JSON so it can still be inspected or re-serialized.
                0 -> EmbeddingValue(_json = json)
                1 -> bestMatches.single()
                // Several candidates tied: prefer the first valid one, else the first candidate.
                else -> bestMatches.firstOrNull { it.isValid() } ?: bestMatches.first()
            }
        }
    }

    /**
     * Serializes an [EmbeddingValue] by writing whichever representation it holds: the float list,
     * the base64 string, or the raw JSON of an unknown variant, checked in that order.
     */
    internal class Serializer : BaseSerializer<EmbeddingValue>(EmbeddingValue::class) {
        override fun serialize(
            value: EmbeddingValue,
            generator: JsonGenerator,
            provider: SerializerProvider,
        ) {
            when {
                value.floats != null -> generator.writeObject(value.floats)
                value.base64 != null -> generator.writeObject(value.base64)
                value._json != null -> generator.writeObject(value._json)
                else -> throw IllegalStateException("Invalid EmbeddingValue")
            }
        }
    }
}
Loading
Loading