Skip to content

Commit

Permalink
Add binarySearchFirst and binarySearchLast for stable search with dup…
Browse files Browse the repository at this point in the history
…licate elements
  • Loading branch information
kvr000 committed Sep 13, 2024
1 parent 76e2a2c commit f56bf0a
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 41 deletions.
166 changes: 143 additions & 23 deletions src/main/java/org/apache/commons/lang3/ArrayUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,8 @@ public static <T> T arraycopy(final T source, final int sourcePos, final T dest,
}

/**
* Searches element in array sorted by key.
* Searches for an element in an array sorted by key. If multiple elements match, the index of the first occurrence
* is returned. If the array is not sorted, the result is undefined.
*
* @param array
* array sorted by key field
Expand All @@ -1445,25 +1446,26 @@ public static <T> T arraycopy(final T source, final int sourcePos, final T dest,
* comparator for keys
*
* @return
* index of the search key, if it is contained in the array; otherwise, (-first_greater - 1).
* The first_greater is the index of lowest greater element in the list - if all elements are lower, the
* first_greater is defined as array.length.
* index of the first occurrence of search key, if it is contained in the array; otherwise,
* (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements
* are lower, the first_greater is defined as array.length.
*
* @param <T>
* type of array element
* @param <K>
* type of key
*/
public static <K, T> int binarySearch(
public static <K, T> int binarySearchFirst(
T[] array,
K key,
Function<T, K> keyExtractor, Comparator<? super K> comparator
) {
return binarySearch0(array, 0, array.length, key, keyExtractor, comparator);
return binarySearchFirst0(array, 0, array.length, key, keyExtractor, comparator);
}

/**
* Searches element in array sorted by key, within range fromIndex (inclusive) - toIndex (exclusive).
* Searches for an element in an array sorted by key, within the range fromIndex (inclusive) to toIndex (exclusive).
* If multiple elements match, the index of the first occurrence is returned. If the array is not sorted, the result
* is undefined.
*
* @param array
* array sorted by key field
Expand All @@ -1479,9 +1481,9 @@ public static <K, T> int binarySearch(
* comparator for keys
*
* @return
* index of the search key, if it is contained in the array within specified range; otherwise,
* (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements
* are lower, the first_greater is defined as toIndex.
* index of the first occurrence of search key, if it is contained in the array within specified range;
* otherwise, (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if
* all elements are lower, the first_greater is defined as toIndex.
*
* @throws ArrayIndexOutOfBoundsException
* when fromIndex or toIndex is out of array range
Expand All @@ -1493,28 +1495,124 @@ public static <K, T> int binarySearch(
* @param <K>
* type of key
*/
public static <T, K> int binarySearch(
public static <T, K> int binarySearchFirst(
T[] array,
int fromIndex, int toIndex,
K key,
Function<T, K> keyExtractor, Comparator<? super K> comparator
) {
if (fromIndex > toIndex) {
throw new IllegalArgumentException(
"fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")");
}
if (fromIndex < 0) {
throw new ArrayIndexOutOfBoundsException(fromIndex);
}
if (toIndex > array.length) {
throw new ArrayIndexOutOfBoundsException(toIndex);
checkRange(array.length, fromIndex, toIndex);

return binarySearchFirst0(array, fromIndex, toIndex, key, keyExtractor, comparator);
}

/**
 * Common implementation of the {@code binarySearchFirst} methods. Performs no range checking; callers are
 * expected to have validated {@code fromIndex} and {@code toIndex} already.
 *
 * <p>The search keeps an inclusive window {@code [l, h]}. When the middle element matches and more than one
 * element is still in the window, the window is shrunk to {@code [l, m]} rather than returning immediately,
 * so the loop converges on the lowest matching index. Every iteration either returns or strictly shrinks the
 * window, so the loop terminates even for unsorted input (for which the result is undefined).</p>
 *
 * @param array
 *      array sorted by key field
 * @param fromIndex
 *      start index (inclusive)
 * @param toIndex
 *      end index (exclusive)
 * @param key
 *      key to search for
 * @param keyExtractor
 *      function to extract key from element
 * @param comparator
 *      comparator for keys
 * @param <T>
 *      type of array element
 * @param <K>
 *      type of key
 *
 * @return
 *      index of the first occurrence of the key within the range, or {@code (-insertionPoint - 1)} when the
 *      key is absent, where {@code insertionPoint} is the index of the lowest greater element
 *      ({@code toIndex} if all elements in the range are lower)
 */
private static <T, K> int binarySearchFirst0(
        T[] array,
        int fromIndex, int toIndex,
        K key,
        Function<T, K> keyExtractor, Comparator<? super K> comparator
) {
    int l = fromIndex;
    int h = toIndex - 1;

    while (l <= h) {
        final int m = (l + h) >>> 1; // unsigned shift to avoid overflow
        final K value = keyExtractor.apply(array[m]);
        final int c = comparator.compare(value, key);
        if (c < 0) {
            l = m + 1;
        } else if (c > 0) {
            h = m - 1;
        } else if (l < h) {
            // match, but earlier matches may exist: keep m as the upper bound (m < h here, so we progress)
            h = m;
        } else {
            // single matching item remaining:
            return m;
        }
    }

    // not found, l points to the lowest greater element (the insertion point):
    return -l - 1;
}

/**
 * Looks up {@code key} in the whole of an array that is sorted by the extracted key and, when several
 * elements carry that key, reports the last of them. The result is undefined for an unsorted array.
 *
 * @param <K>
 *      type of key
 * @param <T>
 *      type of array element
 * @param array
 *      array sorted by key field
 * @param key
 *      key to search for
 * @param keyExtractor
 *      function to extract key from element
 * @param comparator
 *      comparator for keys
 *
 * @return
 *      index of the last occurrence of search key, if it is contained in the array; otherwise,
 *      (-first_greater - 1), where first_greater is the index of the lowest greater element, or
 *      array.length if all elements are lower
 *
 * @throws NullPointerException
 *      when array is null
 */
public static <K, T> int binarySearchLast(
        T[] array,
        K key,
        Function<T, K> keyExtractor, Comparator<? super K> comparator
) {
    // search the complete array: [0, array.length)
    final int end = array.length;
    return binarySearchLast0(array, 0, end, key, keyExtractor, comparator);
}

/**
 * Looks up {@code key} in the range fromIndex (inclusive) to toIndex (exclusive) of an array that is sorted
 * by the extracted key and, when several elements carry that key, reports the last of them. The result is
 * undefined for an unsorted array.
 *
 * @param <T>
 *      type of array element
 * @param <K>
 *      type of key
 * @param array
 *      array sorted by key field
 * @param fromIndex
 *      start index (inclusive)
 * @param toIndex
 *      end index (exclusive)
 * @param key
 *      key to search for
 * @param keyExtractor
 *      function to extract key from element
 * @param comparator
 *      comparator for keys
 *
 * @return
 *      index of the last occurrence of search key, if it is contained in the array within specified range;
 *      otherwise, (-first_greater - 1), where first_greater is the index of the lowest greater element, or
 *      toIndex if all elements are lower
 *
 * @throws ArrayIndexOutOfBoundsException
 *      when fromIndex or toIndex is out of array range
 * @throws IllegalArgumentException
 *      when fromIndex is greater than toIndex
 * @throws NullPointerException
 *      when array is null
 */
public static <T, K> int binarySearchLast(
        T[] array,
        int fromIndex, int toIndex,
        K key,
        Function<T, K> keyExtractor, Comparator<? super K> comparator
) {
    // validate the requested window against the array before touching any element
    final int length = array.length;
    checkRange(length, fromIndex, toIndex);
    return binarySearchLast0(array, fromIndex, toIndex, key, keyExtractor, comparator);
}

// common implementation for binarySearch methods, with same semantics:
private static <T, K> int binarySearch0(
private static <T, K> int binarySearchLast0(
T[] array,
int fromIndex, int toIndex,
K key,
Expand All @@ -1531,8 +1629,16 @@ private static <T, K> int binarySearch0(
l = m + 1;
} else if (c > 0) {
h = m - 1;
} else if (m + 1 < h) {
// matching, more than two items remaining:
l = m;
} else if (m + 1 == h) {
// two items remaining, next loops would result in unchanged l and h, we have to choose m or h:
final K valueH = keyExtractor.apply(array[h]);
final int cH = comparator.compare(valueH, key);
return cH == 0 ? h : m;
} else {
// 0, found
// one item remaining, single match:
return m;
}
}
Expand Down Expand Up @@ -9573,4 +9679,18 @@ public static String[] toStringArray(final Object[] array, final String valueFor
public ArrayUtils() {
// empty
}

/**
 * Validates that {@code [fromIndex, toIndex)} is a well-formed range inside an array of the given length.
 * The checks run in a fixed order: inverted range first, then the lower bound, then the upper bound.
 *
 * @param length
 *      length of the array the range refers to
 * @param fromIndex
 *      start index (inclusive)
 * @param toIndex
 *      end index (exclusive)
 *
 * @throws IllegalArgumentException
 *      when fromIndex is greater than toIndex
 * @throws ArrayIndexOutOfBoundsException
 *      when fromIndex is negative or toIndex is greater than length
 */
static void checkRange(int length, int fromIndex, int toIndex) {
    if (fromIndex > toIndex) {
        final String message = "fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")";
        throw new IllegalArgumentException(message);
    }

    // the lower bound is reported in preference to the upper bound when both are violated
    final boolean belowStart = fromIndex < 0;
    if (belowStart || toIndex > length) {
        throw new ArrayIndexOutOfBoundsException(belowStart ? fromIndex : toIndex);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,63 +30,99 @@
public class ArrayUtilsBinarySearchTest extends AbstractLangTest {

@Test
public void binarySearch_whenLowHigherThanEnd_throw() {
public void binarySearchFirst_whenLowHigherThanEnd_throw() {
final Data[] list = createList(0, 1);
assertThrowsExactly(IllegalArgumentException.class, () -> ArrayUtils.binarySearch(list, 1, 0, 0, Data::getValue, Integer::compare));
assertThrowsExactly(IllegalArgumentException.class, () ->
ArrayUtils.binarySearchFirst(list, 1, 0, 0, Data::getValue, Integer::compare));
}

@Test
public void binarySearch_whenLowNegative_throw() {
public void binarySearchFirst_whenLowNegative_throw() {
final Data[] list = createList(0, 1);
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> ArrayUtils.binarySearch(list, -1, 0, 0, Data::getValue, Integer::compare));
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () ->
ArrayUtils.binarySearchFirst(list, -1, 0, 0, Data::getValue, Integer::compare));
}

@Test
public void binarySearch_whenEndBeyondLength_throw() {
public void binarySearchFirst_whenEndBeyondLength_throw() {
final Data[] list = createList(0, 1);
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> ArrayUtils.binarySearch(list, 0, 3, 0, Data::getValue, Integer::compare));
assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () ->
ArrayUtils.binarySearchFirst(list, 0, 3, 0, Data::getValue, Integer::compare));
}

@Test
public void binarySearch_whenEmpty_returnM1() {
public void binarySearchLast_whenLowHigherThanEnd_throw() {
    // fromIndex (1) is greater than toIndex (0): the ranged overload must reject the inverted range
    final Data[] data = createList(0, 1);
    assertThrowsExactly(
            IllegalArgumentException.class,
            () -> {
                ArrayUtils.binarySearchLast(data, 1, 0, 0, Data::getValue, Integer::compare);
            });
}

@Test
public void binarySearchFirst_whenEmpty_returnM1() {
final Data[] list = createList();
final int found = ArrayUtils.binarySearch(list, 0, Data::getValue, Integer::compare);
final int found = ArrayUtils.binarySearchFirst(list, 0, Data::getValue, Integer::compare);
assertEquals(-1, found);
}

@Test
public void binarySearch_whenExists_returnIndex() {
public void binarySearchFirst_whenExists_returnIndex() {
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
final int found = ArrayUtils.binarySearch(list, 9, Data::getValue, Integer::compare);
final int found = ArrayUtils.binarySearchFirst(list, 9, Data::getValue, Integer::compare);
assertEquals(5, found);
}

@Test
public void binarySearch_whenNotExistsMiddle_returnMinusInsertion() {
@Timeout(10)
public void binarySearchFirst_whenMultiple_returnFirst() {
    // For every distinct key, the search must land on the first index of its run of equal keys.
    final Data[] list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9);
    for (int i = 0; i < list.length; i++) {
        final boolean firstOfRun = i == 0 || list[i].value != list[i - 1].value;
        if (firstOfRun) {
            final int found =
                    ArrayUtils.binarySearchFirst(list, list[i].value, Data::getValue, Integer::compare);
            assertEquals(i, found);
        }
    }
}

@Test
@Timeout(10)
public void binarySearchLast_whenMultiple_returnFirst() {
    // NOTE(review): despite the "returnFirst" suffix, this checks that the search lands on the LAST
    // index of each run of equal keys.
    final Data[] list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9);
    final int lastIndex = list.length - 1;
    for (int i = 0; i <= lastIndex; i++) {
        final boolean lastOfRun = i == lastIndex || list[i].value != list[i + 1].value;
        if (lastOfRun) {
            final int found =
                    ArrayUtils.binarySearchLast(list, list[i].value, Data::getValue, Integer::compare);
            assertEquals(i, found);
        }
    }
}

@Test
public void binarySearchFirst_whenNotExistsMiddle_returnMinusInsertion() {
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
final int found = ArrayUtils.binarySearch(list, 8, Data::getValue, Integer::compare);
final int found = ArrayUtils.binarySearchFirst(list, 8, Data::getValue, Integer::compare);
assertEquals(-6, found);
}

@Test
public void binarySearch_whenNotExistsBeginning_returnMinus1() {
public void binarySearchFirst_whenNotExistsBeginning_returnMinus1() {
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
final int found = ArrayUtils.binarySearch(list, -3, Data::getValue, Integer::compare);
final int found = ArrayUtils.binarySearchFirst(list, -3, Data::getValue, Integer::compare);
assertEquals(-1, found);
}

@Test
public void binarySearch_whenNotExistsEnd_returnMinusLength() {
public void binarySearchFirst_whenNotExistsEnd_returnMinusLength() {
final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25);
final int found = ArrayUtils.binarySearch(list, 29, Data::getValue, Integer::compare);
final int found = ArrayUtils.binarySearchFirst(list, 29, Data::getValue, Integer::compare);
assertEquals(-(list.length + 1), found);
}

@Test
@Timeout(10)
public void binarySearch_whenUnsorted_dontInfiniteLoop() {
public void binarySearchFirst_whenUnsorted_dontInfiniteLoop() {
final Data[] list = createList(7, 1, 4, 9, 11, 8);
final int found = ArrayUtils.binarySearch(list, 10, Data::getValue, Integer::compare);
final int found = ArrayUtils.binarySearchFirst(list, 10, Data::getValue, Integer::compare);
}

private Data[] createList(int... values) {
Expand Down

0 comments on commit f56bf0a

Please sign in to comment.