@@ -173,7 +173,8 @@ public int insertSorted(int newNode, float newScore) {
if (size == nodes.length) {
growArrays();
}
- int insertionPoint = descSortFindRightMostInsertionPoint(newScore);
+
+ int insertionPoint = descSortFindLeftMostInsertionPoint(newScore);
if (duplicateExistsNear(insertionPoint, newNode, newScore)) {
return -1;
}
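
(The duplicateExistsNear helper itself is not part of this diff; the sketch below is only a guess at the shape such a check could take once the insertion point is the left-most one, where existing entries with the same score sit at and after the returned index.)

```java
// Hypothetical sketch only -- the real duplicateExistsNear is not shown in this diff.
// With a left-most insertion point, tied entries sit at and after 'insertionPoint', so the
// check walks forward through the tied region (and backward, defensively) looking for the node.
static boolean duplicateExistsNearSketch(int[] nodes, float[] scores, int size,
                                         int insertionPoint, int newNode, float newScore) {
  for (int i = insertionPoint; i < size && scores[i] == newScore; i++) {
    if (nodes[i] == newNode) return true;
  }
  for (int i = insertionPoint - 1; i >= 0 && scores[i] == newScore; i--) {
    if (nodes[i] == newNode) return true;
  }
  return false;
}
```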
@@ -282,12 +283,12 @@ public String toString() {
return sb.toString();
}

- protected final int descSortFindRightMostInsertionPoint(float newScore) {
+ protected final int descSortFindLeftMostInsertionPoint(float newScore) {
int start = 0;
int end = size - 1;
while (start <= end) {
int mid = (start + end) / 2;
- if (scores[mid] < newScore) end = mid - 1;
+ if (scores[mid] <= newScore) end = mid - 1;
else start = mid + 1;
}
return start;
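
A minimal standalone sketch (plain arrays, not the project's NodeArray internals) of the left-most search over descending scores: changing the comparison from '<' to '<=' makes an equal score land in front of the existing ties, which is why the updated test expectations below list the most recently inserted node first among equal scores (e.g. node 4 before node 0 at score 1).

```java
// Standalone sketch of the left-most insertion-point search; assumes a plain float[]
// kept in descending order, not the project's NodeArray class.
class LeftMostInsertionSketch {
  static int descSortFindLeftMostInsertionPoint(float[] scores, int size, float newScore) {
    int start = 0;
    int end = size - 1;
    while (start <= end) {
      int mid = (start + end) / 2;
      if (scores[mid] <= newScore) end = mid - 1;   // '<=' keeps moving left past equal scores
      else start = mid + 1;
    }
    return start;
  }

  public static void main(String[] args) {
    float[] scores = {1.1f, 1.0f, 0.9f, 0.8f};
    // Equal score 1.0f: the left-most variant returns 1 (in front of the existing 1.0f);
    // the old right-most variant ('<' instead of '<=') would have returned 2.
    System.out.println(descSortFindLeftMostInsertionPoint(scores, scores.length, 1.0f)); // prints 1
  }
}
```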
@@ -100,22 +100,21 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) {
}

/**
- * Encodes the node ID and its similarity score as long. If two scores are equals,
- * the smaller node ID wins.
+ * Encodes the node ID and its similarity score as long.
*
* <p>The most significant 32 bits represent the float score, encoded as a sortable int.
*
* <p>The less significant 32 bits represent the node ID.
*
- * <p>The bits representing the node ID are complemented to guarantee the win for the smaller node
- * ID.
+ * <p>The bits representing the node ID are reversed to ensure no bias towards smaller or greater IDs
+ * when scores are equal.
*
* <p>The AND with 0xFFFFFFFFL (a long with first 32 bit as 1) is necessary to obtain a long that
* has
*
* <p>The most significant 32 bits to 0
*
- * <p>The less significant 32 bits represent the node ID.
+ * <p>The less significant 32 bits represent the encoded node ID.
*
* @param node the node ID
* @param score the node score
@@ -124,15 +123,15 @@ public void pushAll(NodeScoreIterator nodeScoreIterator, int count) {
private long encode(int node, float score) {
assert node >= 0 : node;
return order.apply(
- (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
+ (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & Integer.reverse(node)));
}

private float decodeScore(long heapValue) {
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
}

private int decodeNodeId(long heapValue) {
- return (int) ~(order.apply(heapValue));
+ return Integer.reverse((int) order.apply(heapValue));
}

/** Removes the top element and returns its node id. */
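
A self-contained sketch of the packed-long encoding and of what the tie-break change does. The class and helper names below are stand-ins, the sortable-int conversion mirrors the standard NumericUtils bit trick, and the heap-order transform (order.apply) is omitted. The point: the old ~node complement meant the smaller ID always produced the larger packed value among equal scores, while Integer.reverse(node) orders ties by the reversed bit pattern of the ID, removing that systematic bias.

```java
// Sketch only: standalone re-implementation for illustration, not the project's class.
public class PackedScoreNodeSketch {
  // Standard sortable-int trick for floats (same idea as NumericUtils.floatToSortableInt).
  static int floatToSortableInt(float value) {
    int bits = Float.floatToIntBits(value);
    return bits ^ ((bits >> 31) & 0x7fffffff);
  }

  static float sortableIntToFloat(int encoded) {
    return Float.intBitsToFloat(encoded ^ ((encoded >> 31) & 0x7fffffff));
  }

  // Score in the upper 32 bits, reversed node id in the lower 32 bits (order.apply omitted).
  static long encode(int node, float score) {
    return (((long) floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & Integer.reverse(node));
  }

  static float decodeScore(long packed) {
    return sortableIntToFloat((int) (packed >> 32));
  }

  static int decodeNodeId(long packed) {
    return Integer.reverse((int) packed);
  }

  public static void main(String[] args) {
    long a = encode(3, 0.8f);
    long b = encode(6, 0.8f);
    System.out.println(decodeNodeId(a) + " @ " + decodeScore(a)); // 3 @ 0.8 -- round trip is lossless
    // Equal scores: relative order now depends on the reversed id bits rather than on which
    // id is numerically smaller (the old ~node always made the smaller id sort higher).
    System.out.println(Long.compare(a, b));
  }
}
```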
@@ -60,48 +60,50 @@ public void testScoresDescOrder() {

neighbors.insertSorted(4, 1f);
assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
+ assertNodesEqual(new int[] {4, 0, 3, 1}, neighbors);

neighbors.insertSorted(5, 1.1f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
+ assertNodesEqual(new int[] {5, 4, 0, 3, 1}, neighbors);

neighbors.insertSorted(6, 0.8f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
+ assertNodesEqual(new int[] {5, 4, 0, 3, 6, 1}, neighbors);

neighbors.insertSorted(7, 0.8f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {5, 4, 0, 3, 7, 6, 1}, neighbors);

neighbors.removeIndex(2);
assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {5, 4, 3, 7, 6, 1}, neighbors);

neighbors.removeIndex(0);
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {4, 3, 7, 6, 1}, neighbors);

neighbors.removeIndex(4);
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
+ assertNodesEqual(new int[] {4, 3, 7, 6}, neighbors);

neighbors.removeLast();
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {0, 3, 1}, neighbors);
+ assertNodesEqual(new int[] {4, 3, 7}, neighbors);

neighbors.insertSorted(8, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
- assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
+ assertNodesEqual(new int[] {4, 8, 3, 7}, neighbors);
}

private void assertScoresEqual(float[] scores, NodeArray neighbors) {
assertEquals(scores.length, neighbors.size(), "Number of scores differs");
for (int i = 0; i < scores.length; i++) {
assertEquals(scores[i], neighbors.getScore(i), 0.01f);
}
}

private void assertNodesEqual(int[] nodes, NodeArray neighbors) {
assertEquals(nodes.length, neighbors.size(), "Number of nodes differs");
for (int i = 0; i < nodes.length; i++) {
assertEquals(nodes[i], neighbors.getNode(i));
}
@@ -181,7 +183,7 @@ public void testNoDuplicatesSameScores() {
cna.insertSorted(3, 10.0f);
cna.insertSorted(1, 10.0f); // This is a duplicate and should be ignored
cna.insertSorted(3, 10.0f); // This is also a duplicate
- assertArrayEquals(new int[] {1, 2, 3}, cna.copyDenseNodes());
+ assertArrayEquals(new int[] {3, 2, 1}, cna.copyDenseNodes());
assertArrayEquals(new float[] {10.0f, 10.0f, 10.0f}, cna.copyDenseScores(), 0.01f);
validateSortedByScore(cna);
}
@@ -700,6 +700,47 @@ public void testZeroCentroid(boolean addHierarchy) {
}
}

@Test
public void testSameScoreWithCosineSimilarity() {
testSameScoreWithCosineSimilarity(10);
testSameScoreWithCosineSimilarity(20);
testSameScoreWithCosineSimilarity(50);
testSameScoreWithCosineSimilarity(100);
testSameScoreWithCosineSimilarity(200);
testSameScoreWithCosineSimilarity(500);
testSameScoreWithCosineSimilarity(1000);
}

private void testSameScoreWithCosineSimilarity(final int N) {
// Create N vectors that differ in magnitude but share the same direction, so they all have
// exactly the same cosine similarity to the query vector.
Random rand = getRandom();
VectorFloat<?>[] vectors = new VectorFloat<?>[N];
for (int i = 0; i < N; i++) {
float x = 0.01f + rand.nextFloat();
vectors[i] = vectorTypeSupport.createFloatVector(new float[]{x, x});
}
MockVectorValues vectorValues = MockVectorValues.fromValues(vectors);

similarityFunction = VectorSimilarityFunction.COSINE;
GraphIndexBuilder builder = new GraphIndexBuilder(vectorValues, similarityFunction, 10, 20, 1.0f, 1.0f, false);
OnHeapGraphIndex graph = builder.build(vectorValues);

VectorFloat<?> query = vectorTypeSupport.createFloatVector(new float[]{0.5f, 0.5f});
SearchResult result = GraphSearcher.search(query, N, vectorValues, similarityFunction, graph, Bits.ALL);

// In a perfect world we would return all N vectors, but that is hard to guarantee because the
// graph is built with a semi-randomized algorithm, and this is already an edge case. We don't
// want to make graph construction more complex or less performant just to satisfy this test.
// In a typical scenario there are many more vectors in the graph than the query limit, so
// missing some results is fine; we would fall back to brute-force search anyway if the limit
// were the same order of magnitude as the graph size.
int minExpected = (int) (N * 0.5);
assertTrue("Should return almost all vectors, expected at least: " + minExpected + ", got: " + result.getNodes().length,
result.getNodes().length >= minExpected);
}
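
As a standalone sanity check of the premise in the comment above (this snippet is not part of the test class): cosine similarity is v·q / (|v||q|), so scaling v by any positive constant cancels out, and every vector of the form (x, x) scores identically against the query (0.5, 0.5).

```java
// Standalone check that cosine similarity ignores vector magnitude.
class CosineSketch {
  static float cosine(float[] a, float[] b) {
    float dot = 0, na = 0, nb = 0;
    for (int i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      na += a[i] * a[i];
      nb += b[i] * b[i];
    }
    return (float) (dot / (Math.sqrt(na) * Math.sqrt(nb)));
  }

  public static void main(String[] args) {
    float[] query = {0.5f, 0.5f};
    System.out.println(cosine(new float[] {0.02f, 0.02f}, query)); // ~1.0
    System.out.println(cosine(new float[] {0.97f, 0.97f}, query)); // ~1.0 as well
  }
}
```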

/**
* Returns vectors evenly distributed around the upper unit semicircle.
*/