diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index fd0d4f792..318af8122 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -173,7 +173,8 @@ public int insertSorted(int newNode, float newScore) { if (size == nodes.length) { growArrays(); } - int insertionPoint = descSortFindRightMostInsertionPoint(newScore); + + int insertionPoint = descSortFindLeftMostInsertionPoint(newScore); if (duplicateExistsNear(insertionPoint, newNode, newScore)) { return -1; } @@ -282,12 +283,12 @@ public String toString() { return sb.toString(); } - protected final int descSortFindRightMostInsertionPoint(float newScore) { + protected final int descSortFindLeftMostInsertionPoint(float newScore) { int start = 0; int end = size - 1; while (start <= end) { int mid = (start + end) / 2; - if (scores[mid] < newScore) end = mid - 1; + if (scores[mid] <= newScore) end = mid - 1; else start = mid + 1; } return start; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java index 4db8e32b8..7b228eb35 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java @@ -100,22 +100,21 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) { } /** - * Encodes the node ID and its similarity score as long. If two scores are equals, - * the smaller node ID wins. + * Encodes the node ID and its similarity score as long. * *

<p>The most significant 32 bits represent the float score, encoded as a sortable int. * *

<p>The less significant 32 bits represent the node ID. * - *

<p>The bits representing the node ID are complemented to guarantee the win for the smaller node - * ID. + *

<p>The bits representing the node ID are reversed to ensure no bias towards smaller or greater IDs + * when scores are equal. * *

<p>The AND with 0xFFFFFFFFL (a long with first 32 bit as 1) is necessary to obtain a long that * has * *

<p>The most significant 32 bits to 0 * - *

<p>The less significant 32 bits represent the node ID. + *

The less significant 32 bits represent the encoded node ID. * * @param node the node ID * @param score the node score @@ -124,7 +123,7 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) { private long encode(int node, float score) { assert node >= 0 : node; return order.apply( - (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); + (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & Integer.reverse(node))); } private float decodeScore(long heapValue) { @@ -132,7 +131,7 @@ private float decodeScore(long heapValue) { } private int decodeNodeId(long heapValue) { - return (int) ~(order.apply(heapValue)); + return Integer.reverse((int) order.apply(heapValue)); } /** Removes the top element and returns its node id. */ diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java index 0f3586200..f4652d33a 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestNodeArray.java @@ -60,48 +60,50 @@ public void testScoresDescOrder() { neighbors.insertSorted(4, 1f); assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {4, 0, 3, 1}, neighbors); neighbors.insertSorted(5, 1.1f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {5, 4, 0, 3, 1}, neighbors); neighbors.insertSorted(6, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {5, 4, 0, 3, 6, 1}, neighbors); neighbors.insertSorted(7, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - 
assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 4, 0, 3, 7, 6, 1}, neighbors); neighbors.removeIndex(2); assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 4, 3, 7, 6, 1}, neighbors); neighbors.removeIndex(0); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {4, 3, 7, 6, 1}, neighbors); neighbors.removeIndex(4); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {4, 3, 7, 6}, neighbors); neighbors.removeLast(); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 1}, neighbors); + assertNodesEqual(new int[] {4, 3, 7}, neighbors); neighbors.insertSorted(8, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors); - assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors); + assertNodesEqual(new int[] {4, 8, 3, 7}, neighbors); } private void assertScoresEqual(float[] scores, NodeArray neighbors) { + assertEquals(scores.length, neighbors.size(), "Number of scores differs"); for (int i = 0; i < scores.length; i++) { assertEquals(scores[i], neighbors.getScore(i), 0.01f); } } private void assertNodesEqual(int[] nodes, NodeArray neighbors) { + assertEquals(nodes.length, neighbors.size(), "Number of nodes differs"); for (int i = 0; i < nodes.length; i++) { assertEquals(nodes[i], neighbors.getNode(i)); } @@ -181,7 +183,7 @@ public void testNoDuplicatesSameScores() { cna.insertSorted(3, 10.0f); cna.insertSorted(1, 10.0f); // This is a duplicate and should be ignored cna.insertSorted(3, 10.0f); // This is also a duplicate - assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes()); + assertArrayEquals(new int[] {3, 2, 1}, cna.copyDenseNodes()); 
assertArrayEquals(new float[] {10.0f, 10.0f, 10.0f}, cna.copyDenseScores(), 0.01f); validateSortedByScore(cna); } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java index c11b04ebd..53381019d 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java @@ -700,6 +700,47 @@ public void testZeroCentroid(boolean addHierarchy) { } } + @Test + public void testSameScoreWithCosineSimilarity() + { + testSameScoreWithCosineSimilarity(10); + testSameScoreWithCosineSimilarity(20); + testSameScoreWithCosineSimilarity(50); + testSameScoreWithCosineSimilarity(100); + testSameScoreWithCosineSimilarity(200); + testSameScoreWithCosineSimilarity(500); + testSameScoreWithCosineSimilarity(1000); + } + + private void testSameScoreWithCosineSimilarity(final int N) { + // Create N vectors which differ in their magnitude but have the same direction, so they would + // all have the exactly same cosine similarity to the query vector. 
+ Random rand = getRandom(); + VectorFloat[] vectors = new VectorFloat[N]; + for (int i = 0; i < N; i++) { + float x = 0.01f + rand.nextFloat(); + vectors[i] = vectorTypeSupport.createFloatVector(new float[]{x, x}); + } + MockVectorValues vectorValues = MockVectorValues.fromValues(vectors); + + similarityFunction = VectorSimilarityFunction.COSINE; + GraphIndexBuilder builder = new GraphIndexBuilder(vectorValues, similarityFunction, 10, 20, 1.0f, 1.0f, false); + OnHeapGraphIndex graph = builder.build(vectorValues); + + VectorFloat query = vectorTypeSupport.createFloatVector(new float[]{0.5f, 0.5f}); + SearchResult result = GraphSearcher.search(query, N, vectorValues, similarityFunction, graph, Bits.ALL); + + // In perfect world, we should return all N vectors, but this is hard to guarantee considering + // the graph is built with a semi-randomized algorithm. And this is an edge case already, so + // we don't want to make the graph building algorithm more complex or less performant in order to satisfy + // this test. In a typical scenario we'll have many more vectors in the graph than the query limit, + // so missing some results is fine. We'd fall back to brute force search anyway if limit + // is the same order of magnitude as the graph size. + int minExpected = (int) (N * 0.5); + assertTrue("Should return almost all vectors, expected at least: " + minExpected + ", got: " + result.getNodes().length, + result.getNodes().length >= minExpected); + } + /** * Returns vectors evenly distributed around the upper unit semicircle. */