From ea0df61c72c13c63cd9790931518317e28fc02f6 Mon Sep 17 00:00:00 2001 From: Zbynek Vyskovsky Date: Thu, 12 Sep 2024 21:45:35 -0700 Subject: [PATCH] Sorted List binarySearch functions --- .../commons/collections4/ListUtils.java | 229 ++++++++++++++++++ .../ListUtilsBinarySearchTest.java | 147 +++++++++++ 2 files changed, 376 insertions(+) create mode 100644 src/test/java/org/apache/commons/collections4/ListUtilsBinarySearchTest.java diff --git a/src/main/java/org/apache/commons/collections4/ListUtils.java b/src/main/java/org/apache/commons/collections4/ListUtils.java index 04744234ac..8ecd9e2ff4 100644 --- a/src/main/java/org/apache/commons/collections4/ListUtils.java +++ b/src/main/java/org/apache/commons/collections4/ListUtils.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; +import java.util.function.Function; import org.apache.commons.collections4.bag.HashBag; import org.apache.commons.collections4.functors.DefaultEquator; @@ -132,6 +133,221 @@ public int size() { } } + /** + * Searches element in list sorted by key. If there are multiple elements matching, it returns first occurrence. + * If the list is not sorted, the result is undefined. + * + * @param list + * list sorted by key field + * @param key + * key to search for + * @param keyExtractor + * function to extract key from element + * @param comparator + * comparator for keys + * + * @return + * index of the first occurrence of search key, if it is contained in the list; otherwise, + * (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements + * are lower, the first_greater is defined as list.size(). + * + * @param + * type of list element + * @param + * type of key + */ + public static int binarySearchFirst( + List list, + K key, + Function keyExtractor, Comparator comparator + ) { + return binarySearchFirst0(list, 0, list.size(), key, keyExtractor, comparator); + } + + /** + * Searches element in list sorted by key, within range fromIndex (inclusive) - toIndex (exclusive). If there are + * multiple elements matching, it returns first occurrence. If the list is not sorted, the result is undefined. + * + * @param list + * list sorted by key field + * @param fromIndex + * start index (inclusive) + * @param toIndex + * end index (exclusive) + * @param key + * key to search for + * @param keyExtractor + * function to extract key from element + * @param comparator + * comparator for keys + * + * @return + * index of the first occurrence of search key, if it is contained in the list within specified range; + * otherwise, (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if + * all elements are lower, the first_greater is defined as toIndex. + * + * @throws ArrayIndexOutOfBoundsException + * when fromIndex or toIndex is out of array range + * @throws IllegalArgumentException + * when fromIndex is greater than toIndex + * + * @param + * type of list element + * @param + * type of key + */ + public static int binarySearchFirst( + List list, + int fromIndex, int toIndex, + K key, + Function keyExtractor, Comparator comparator + ) { + checkRange(list.size(), fromIndex, toIndex); + + return binarySearchFirst0(list, fromIndex, toIndex, key, keyExtractor, comparator); + } + + // common implementation for binarySearch methods, with same semantics: + private static int binarySearchFirst0( + List list, + int fromIndex, int toIndex, + K key, + Function keyExtractor, Comparator comparator + ) { + int l = fromIndex; + int h = toIndex - 1; + + while (l <= h) { + final int m = (l + h) >>> 1; // unsigned shift to avoid overflow + final K value = keyExtractor.apply(list.get(m)); + final int c = comparator.compare(value, key); + if (c < 0) { + l = m + 1; + } else if (c > 0) { + h = m - 1; + } else if (l < h) { + // possibly multiple matching items remaining: + h = m; + } else { + // single matching item remaining: + return m; + } + } + + // not found, the l points to the lowest higher match: + return -l - 1; + } + + /** + * Searches element in list sorted by key. If there are multiple elements matching, it returns last occurrence. + * If the list is not sorted, the result is undefined. + * + * @param list + * list sorted by key field + * @param key + * key to search for + * @param keyExtractor + * function to extract key from element + * @param comparator + * comparator for keys + * + * @return + * index of the last occurrence of search key, if it is contained in the list; otherwise, + * (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements + * are lower, the first_greater is defined as list.size() . + * + * @param + * type of array element + * @param + * type of key + */ + public static int binarySearchLast( + List list, + K key, + Function keyExtractor, Comparator comparator + ) { + return binarySearchLast0(list, 0, list.size(), key, keyExtractor, comparator); + } + + /** + * Searches element in list sorted by key, within range fromIndex (inclusive) - toIndex (exclusive). If there are + * multiple elements matching, it returns last occurrence. If the list is not sorted, the result is undefined. + * + * @param list + * list sorted by key field + * @param fromIndex + * start index (inclusive) + * @param toIndex + * end index (exclusive) + * @param key + * key to search for + * @param keyExtractor + * function to extract key from element + * @param comparator + * comparator for keys + * + * @return + * index of the last occurrence of search key, if it is contained in the list within specified range; + * otherwise, (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if + * all elements are lower, the first_greater is defined as toIndex. + * + * @throws ArrayIndexOutOfBoundsException + * when fromIndex or toIndex is out of array range + * @throws IllegalArgumentException + * when fromIndex is greater than toIndex + * + * @param + * type of array element + * @param + * type of key + */ + public static int binarySearchLast( + List list, + int fromIndex, int toIndex, + K key, + Function keyExtractor, Comparator comparator + ) { + checkRange(list.size(), fromIndex, toIndex); + + return binarySearchLast0(list, fromIndex, toIndex, key, keyExtractor, comparator); + } + + // common implementation for binarySearch methods, with same semantics: + private static int binarySearchLast0( + List list, + int fromIndex, int toIndex, + K key, + Function keyExtractor, Comparator comparator + ) { + int l = fromIndex; + int h = toIndex - 1; + + while (l <= h) { + final int m = (l + h) >>> 1; // unsigned shift to avoid overflow + final K value = keyExtractor.apply(list.get(m)); + final int c = comparator.compare(value, key); + if (c < 0) { + l = m + 1; + } else if (c > 0) { + h = m - 1; + } else if (m + 1 < h) { + // matching, more than two items remaining: + l = m; + } else if (m + 1 == h) { + // two items remaining, next loops would result in unchanged l and h, we have to choose m or h: + final K valueH = keyExtractor.apply(list.get(h)); + final int cH = comparator.compare(valueH, key); + return cH == 0 ? h : m; + } else { + // one item remaining, single match: + return m; + } + } + + // not found, the l points to the lowest higher match: + return -l - 1; + } + /** * Returns either the passed in list, or if the list is {@code null}, * the value of {@code defaultList}. @@ -741,6 +957,19 @@ public static List unmodifiableList(final List list) { return UnmodifiableList.unmodifiableList(list); } + static void checkRange(int length, int fromIndex, int toIndex) { + if (fromIndex > toIndex) { + throw new IllegalArgumentException( + "fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")"); + } + if (fromIndex < 0) { + throw new ArrayIndexOutOfBoundsException(fromIndex); + } + if (toIndex > length) { + throw new ArrayIndexOutOfBoundsException(toIndex); + } + } + /** * Don't allow instances. */ diff --git a/src/test/java/org/apache/commons/collections4/ListUtilsBinarySearchTest.java b/src/test/java/org/apache/commons/collections4/ListUtilsBinarySearchTest.java new file mode 100644 index 0000000000..c203fc35b3 --- /dev/null +++ b/src/test/java/org/apache/commons/collections4/ListUtilsBinarySearchTest.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.collections4; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrowsExactly; + +/** + * Unit tests {@link ListUtils} binarySearch functions. + */ +public class ListUtilsBinarySearchTest { + + @Test + public void binarySearchFirst_whenLowHigherThanEnd_throw() { + final List list = createList(0, 1); + assertThrowsExactly(IllegalArgumentException.class, () -> + ListUtils.binarySearchFirst(list, 1, 0, 0, Data::getValue, Integer::compare)); + } + + @Test + public void binarySearchFirst_whenLowNegative_throw() { + final List list = createList(0, 1); + assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> + ListUtils.binarySearchFirst(list, -1, 0, 0, Data::getValue, Integer::compare)); + } + + @Test + public void binarySearchFirst_whenEndBeyondLength_throw() { + final List list = createList(0, 1); + assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> + ListUtils.binarySearchFirst(list, 0, 3, 0, Data::getValue, Integer::compare)); + } + + @Test + public void binarySearchLast_whenLowHigherThanEnd_throw() { + final List list = createList(0, 1); + assertThrowsExactly(IllegalArgumentException.class, () -> + ListUtils.binarySearchLast(list, 1, 0, 0, Data::getValue, Integer::compare)); + } + + @Test + public void binarySearchFirst_whenEmpty_returnM1() { + final List list = createList(); + final int found = ListUtils.binarySearchFirst(list, 0, Data::getValue, Integer::compare); + assertEquals(-1, found); + } + + @Test + public void binarySearchFirst_whenExists_returnIndex() { + final List list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); + final int found = ListUtils.binarySearchFirst(list, 9, Data::getValue, Integer::compare); + assertEquals(5, found); + } + + @Test + @Timeout(10) + public void binarySearchFirst_whenMultiple_returnFirst() { + final List list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9); + for (int i = 0; i < list.size(); ++i) { + if (i > 0 && list.get(i).value == list.get(i - 1).value) { + continue; + } + final int found = ListUtils.binarySearchFirst(list, list.get(i).value, Data::getValue, Integer::compare); + assertEquals(i, found); + } + } + + @Test + @Timeout(10) + public void binarySearchLast_whenMultiple_returnFirst() { + final List list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9); + for (int i = 0; i < list.size(); ++i) { + if (i < list.size() - 1 && list.get(i).value == list.get(i + 1).value) { + continue; + } + final int found = ListUtils.binarySearchLast(list, list.get(i).value, Data::getValue, Integer::compare); + assertEquals(i, found); + } + } + + @Test + public void binarySearchFirst_whenNotExistsMiddle_returnMinusInsertion() { + final List list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); + final int found = ListUtils.binarySearchFirst(list, 8, Data::getValue, Integer::compare); + assertEquals(-6, found); + } + + @Test + public void binarySearchFirst_whenNotExistsBeginning_returnMinus1() { + final List list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); + final int found = ListUtils.binarySearchFirst(list, -3, Data::getValue, Integer::compare); + assertEquals(-1, found); + } + + @Test + public void binarySearchFirst_whenNotExistsEnd_returnMinusLength() { + final List list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); + final int found = ListUtils.binarySearchFirst(list, 29, Data::getValue, Integer::compare); + assertEquals(-(list.size() + 1), found); + } + + @Test + @Timeout(10) + public void binarySearchFirst_whenUnsorted_dontInfiniteLoop() { + final List list = createList(7, 1, 4, 9, 11, 8); + final int found = ListUtils.binarySearchFirst(list, 10, Data::getValue, Integer::compare); + } + + private List createList(int... values) { + return IntStream.of(values).mapToObj(Data::new) + .collect(Collectors.toList()); + } + + static class Data { + + private final int value; + + Data(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + } +}