diff --git a/marklogic-client-api/src/main/java/com/marklogic/client/util/VectorUtil.java b/marklogic-client-api/src/main/java/com/marklogic/client/util/VectorUtil.java
new file mode 100644
index 000000000..bd07255e8
--- /dev/null
+++ b/marklogic-client-api/src/main/java/com/marklogic/client/util/VectorUtil.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2010-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ */
+package com.marklogic.client.util;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Base64;
+import java.util.Objects;
+
+/**
+ * Supports encoding and decoding vectors using the same approach as the vec:base64-encode and vec:base64-decode
+ * functions supported by the MarkLogic server.
+ *
+ * Before base64 encoding, the binary layout written and read here is little-endian: a 4-byte integer version
+ * (always 0), a 4-byte integer dimension count, and then one 4-byte float per dimension.
+ *
+ * @since 7.2.0
+ */
+public interface VectorUtil {
+
+    /**
+     * Encodes a vector as a base64 string.
+     *
+     * @param vector the values to encode; must not be null, may be empty
+     * @return a base64-encoded string representing the vector and using the same approach as the vec:base64-encode
+     * function supported by the MarkLogic server.
+     * @throws NullPointerException if {@code vector} is null
+     */
+    static String base64Encode(float... vector) {
+        Objects.requireNonNull(vector, "vector must not be null");
+        final int dimensions = vector.length;
+        ByteBuffer buffer = ByteBuffer.allocate(8 + 4 * dimensions);
+        buffer.order(ByteOrder.LITTLE_ENDIAN);
+        buffer.putInt(0); // version
+        buffer.putInt(dimensions);
+        for (float v : vector) {
+            buffer.putFloat(v);
+        }
+        return Base64.getEncoder().encodeToString(buffer.array());
+    }
+
+    /**
+     * Decodes a base64 string that uses the layout written by {@link #base64Encode(float...)}.
+     *
+     * @param encodedVector the base64-encoded vector; must not be null
+     * @return a vector represented by the base64-encoded string and using the same approach as the vec:base64-decode
+     * function supported by the MarkLogic server.
+     * @throws NullPointerException     if {@code encodedVector} is null
+     * @throws IllegalArgumentException if the input is not valid base64, is too short to contain the header, has an
+     *                                  unsupported version, or declares more dimensions than the payload contains
+     */
+    static float[] base64Decode(String encodedVector) {
+        Objects.requireNonNull(encodedVector, "encodedVector must not be null");
+        ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(encodedVector));
+        buffer.order(ByteOrder.LITTLE_ENDIAN);
+
+        // The header is two 4-byte integers; reject shorter input instead of surfacing BufferUnderflowException.
+        if (buffer.remaining() < 8) {
+            throw new IllegalArgumentException(
+                "Encoded vector is too short; expected at least 8 bytes but got " + buffer.remaining());
+        }
+
+        final int version = buffer.getInt();
+        if (version != 0) {
+            throw new IllegalArgumentException("Unsupported vector version: " + version);
+        }
+
+        final int dimensions = buffer.getInt();
+        // Validate before allocating so a corrupt count cannot trigger NegativeArraySizeException or a huge
+        // allocation; long arithmetic avoids int overflow. Trailing bytes are still tolerated, as before.
+        if (dimensions < 0 || 4L * dimensions > buffer.remaining()) {
+            throw new IllegalArgumentException("Invalid vector dimensions: " + dimensions
+                + "; bytes available for values: " + buffer.remaining());
+        }
+
+        float[] vector = new float[dimensions];
+        for (int i = 0; i < dimensions; i++) {
+            vector[i] = buffer.getFloat();
+        }
+        return vector;
+    }
+}
diff --git a/marklogic-client-api/src/test/java/com/marklogic/client/util/VectorUtilTest.java b/marklogic-client-api/src/test/java/com/marklogic/client/util/VectorUtilTest.java
new file mode 100644
index 000000000..42379eca7
--- /dev/null
+++ b/marklogic-client-api/src/test/java/com/marklogic/client/util/VectorUtilTest.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2010-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ */
+package com.marklogic.client.util;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.marklogic.client.test.Common;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Base64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+class VectorUtilTest {
+
+    private static final float[] VECTOR = new float[]{3.14f, 1.59f, 2.65f};
+    private static final String ENCODED_VECTOR = "AAAAAAMAAADD9UhAH4XLP5qZKUA=";
+    private static final double ACCEPTABLE_DELTA = 0.0001;
+
+    @Test
+    void encodeAndDecodeWithJavaClient() {
+        String encoded = VectorUtil.base64Encode(VECTOR);
+        assertEquals(ENCODED_VECTOR, encoded);
+
+        assertVectorMatches(VectorUtil.base64Decode(encoded));
+    }
+
+    @Test
+    void encodeAndDecodeWithServer() {
+        String encoded = VectorUtil.base64Encode(VECTOR);
+        assertEquals(ENCODED_VECTOR, encoded);
+
+        ArrayNode decoded = (ArrayNode) Common.newEvalClient().newServerEval()
+            .xquery("vec:base64-decode('%s')".formatted(encoded))
+            .evalAs(JsonNode.class);
+
+        assertEquals(3, decoded.size());
+        assertEquals(3.14f, decoded.get(0).asDouble(), ACCEPTABLE_DELTA);
+        assertEquals(1.59f, decoded.get(1).asDouble(), ACCEPTABLE_DELTA);
+        assertEquals(2.65f, decoded.get(2).asDouble(), ACCEPTABLE_DELTA);
+    }
+
+    @Test
+    void encodeWithServerAndDecodeWithJavaClient() {
+        String encoded = Common.newEvalClient().newServerEval()
+            .xquery("vec:base64-encode(vec:vector((3.14, 1.59, 2.65)))")
+            .evalAs(String.class);
+        assertEquals(ENCODED_VECTOR, encoded);
+
+        assertVectorMatches(VectorUtil.base64Decode(encoded));
+    }
+
+    @Test
+    void encodeAndDecodeEmptyVector() {
+        float[] decoded = VectorUtil.base64Decode(VectorUtil.base64Encode());
+        assertEquals(0, decoded.length);
+    }
+
+    @Test
+    void decodeRejectsMalformedInput() {
+        // Fewer than the 8 header bytes.
+        assertThrows(IllegalArgumentException.class, () -> VectorUtil.base64Decode(encodeInts(0)));
+        // Unsupported version.
+        assertThrows(IllegalArgumentException.class, () -> VectorUtil.base64Decode(encodeInts(1, 0)));
+        // Negative dimension count.
+        assertThrows(IllegalArgumentException.class, () -> VectorUtil.base64Decode(encodeInts(0, -1)));
+        // Header declares more values than the payload holds.
+        assertThrows(IllegalArgumentException.class, () -> VectorUtil.base64Decode(encodeInts(0, 3, 0)));
+    }
+
+    /** Asserts that the decoded values match VECTOR within ACCEPTABLE_DELTA. */
+    private static void assertVectorMatches(float[] decoded) {
+        assertEquals(VECTOR.length, decoded.length);
+        for (int i = 0; i < VECTOR.length; i++) {
+            assertEquals(VECTOR[i], decoded[i], ACCEPTABLE_DELTA);
+        }
+    }
+
+    /** Base64-encodes the given ints as consecutive little-endian 4-byte values. */
+    private static String encodeInts(int... values) {
+        ByteBuffer buffer = ByteBuffer.allocate(4 * values.length).order(ByteOrder.LITTLE_ENDIAN);
+        for (int value : values) {
+            buffer.putInt(value);
+        }
+        return Base64.getEncoder().encodeToString(buffer.array());
+    }
+}