Skip to content

Commit

Permalink
Optimise nearest neighbour matching when there is only 0 or 1 match.
Browse files Browse the repository at this point in the history
This situation often occurs during single molecule tracing.
  • Loading branch information
aherbert committed Aug 24, 2023
1 parent 06ddc4c commit 622ebc7
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,22 @@ public static <T, U> int nearestNeighbour(List<T> verticesA, List<U> verticesB,
final List<ImmutableAssignment> neighbours =
findNeighbours(verticesA, verticesB, edges, threshold, sizeA, sizeB);

if (neighbours.isEmpty()) {
consume(verticesA, unmatchedA);
consume(verticesB, unmatchedB);
return 0;
}
if (neighbours.size() == 1) {
int a = neighbours.get(0).getTargetId();
int b = neighbours.get(0).getPredictedId();
if (matched != null) {
matched.accept(verticesA.get(a), verticesB.get(b));
}
consumeUnmatched(verticesA, unmatchedA, a);
consumeUnmatched(verticesB, unmatchedB, b);
return 1;
}

AssignmentComparator.sort(neighbours);

final boolean[] matchedA = new boolean[sizeA];
Expand Down Expand Up @@ -427,6 +443,25 @@ private static <T> void consumeUnmatched(List<T> list, Consumer<T> consumer, boo
}
}

/**
* Pass items from the list that are not the matched item to the consumer (null-safe).
*
* @param <T> the generic type
* @param list the list
* @param consumer the consumer (can be null)
* @param matched the matched item
*/
private static <T> void consumeUnmatched(List<T> list, Consumer<T> consumer, int matched) {
if (consumer != null) {
final int length = list.size();
for (int i = 0; i < length; i++) {
if (i != matched) {
consumer.accept(list.get(i));
}
}
}
}

/**
* Calculates a matching of a bipartite graph between two sets of weighted vertices using the
* minimum sum of distances. The distance must be less than or equal to the threshold to be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ void testNearestNeighbourWithNoEdges() {
assertMatchingFunctionWithNoEdges(NearestNeighbourMatchingFunction.instance());
}

@Test
void testNearestNeighbourWithOneVertexA() {
assertMatchingFunctionWithOneVertexA(NearestNeighbourMatchingFunction.instance());
}

@Test
void testNearestNeighbourWithOneVertexB() {
assertMatchingFunctionWithOneVertexB(NearestNeighbourMatchingFunction.instance());
}

@Test
void testNearestNeighbourWithMaxCardinality() {
assertMatchingFunctionWithMaxCardinality(NearestNeighbourMatchingFunction.instance());
Expand Down Expand Up @@ -305,6 +315,16 @@ void testMinimumDistanceWithNoEdges() {
assertMatchingFunctionWithNoEdges(MinimumDistanceMatchingFunction.instance());
}

@Test
void testMinimumDistanceWithOneVertexA() {
assertMatchingFunctionWithOneVertexA(MinimumDistanceMatchingFunction.instance());
}

@Test
void testMinimumDistanceWithOneVertexB() {
assertMatchingFunctionWithOneVertexB(MinimumDistanceMatchingFunction.instance());
}

@Test
void testMinimumDistanceWithMaxCardinality() {
assertMatchingFunctionWithMaxCardinality(MinimumDistanceMatchingFunction.instance());
Expand Down Expand Up @@ -481,16 +501,57 @@ void testMinimumDistanceThrowsWithInfiniteRange() {
assertMatchingFunction(function, connections, 1, expected);
}

private static void
assertMatchingFunctionWithOneVertexA(MatchingFunction<Integer, Integer> function) {
final double[][] connections = new double[1][];
for (int i = 0; i < connections.length; i++) {
connections[i] = SimpleArrayUtils.newDoubleArray(6, Double.MAX_VALUE);
}

final int[][] expected = new int[0][0];
assertMatchingFunction(function, connections, 1, expected);

connections[0][3] = 0.125;
connections[0][4] = 0;
connections[0][5] = 0.25;
final int[][] expected2 = new int[][] {{0, 4}};
assertMatchingFunction(function, connections, 1, expected2);
}

private static void
assertMatchingFunctionWithOneVertexB(MatchingFunction<Integer, Integer> function) {
final double[][] connections = new double[6][];
for (int i = 0; i < connections.length; i++) {
connections[i] = SimpleArrayUtils.newDoubleArray(1, Double.MAX_VALUE);
}

final int[][] expected = new int[0][0];
assertMatchingFunction(function, connections, 1, expected);

connections[3][0] = 0.125;
connections[4][0] = 0;
connections[5][0] = 0.25;
final int[][] expected2 = new int[][] {{4, 0}};
assertMatchingFunction(function, connections, 1, expected2);
}

private static void
assertMatchingFunctionWithMaxCardinality(MatchingFunction<Integer, Integer> function) {
final double[][] connections = new double[6][];
for (int i = 0; i < connections.length; i++) {
connections[i] = SimpleArrayUtils.newDoubleArray(2, Double.MAX_VALUE);
}
connections[1][0] = 0;

int[][] expected = new int[0][0];
assertMatchingFunction(function, connections, 1, expected);

connections[4][1] = 0.75;
expected = new int[][] {{4, 1}};
assertMatchingFunction(function, connections, 1, expected);

final int[][] expected = new int[][] {{1, 0}, {4, 1}};
connections[1][0] = 0;
connections[4][1] = 0.75;
expected = new int[][] {{1, 0}, {4, 1}};
assertMatchingFunction(function, connections, 1, expected);
}

Expand Down

0 comments on commit 622ebc7

Please sign in to comment.