Skip to content

Commit 2ca2ea1

Browse files
Add bulk query APIs (#342)
* Reduce overhead. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Reduce overhead. * Add support for array-based bulk insertion. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Consistency checks. * Add support for bulk-query APIs. * Formatting. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7775fdd commit 2ca2ea1

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

rtree/core.py

+38
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,44 @@ def free_error_msg_ptr(result, func, cargs):
238238
rt.Index_NearestNeighbors_id.restype = ctypes.c_int
239239
rt.Index_NearestNeighbors_id.errcheck = check_return # type: ignore
240240

241+
try:
242+
rt.Index_NearestNeighbors_id_v.argtypes = [
243+
ctypes.c_void_p,
244+
ctypes.c_int64,
245+
ctypes.c_int64,
246+
ctypes.c_uint32,
247+
ctypes.c_uint64,
248+
ctypes.c_uint64,
249+
ctypes.c_uint64,
250+
ctypes.c_void_p,
251+
ctypes.c_void_p,
252+
ctypes.c_void_p,
253+
ctypes.c_void_p,
254+
ctypes.c_void_p,
255+
ctypes.POINTER(ctypes.c_int64),
256+
]
257+
rt.Index_NearestNeighbors_id_v.restype = ctypes.c_int
258+
rt.Index_NearestNeighbors_id_v.errcheck = check_return # type: ignore
259+
260+
rt.Index_Intersects_id_v.argtypes = [
261+
ctypes.c_void_p,
262+
ctypes.c_int64,
263+
ctypes.c_uint32,
264+
ctypes.c_uint64,
265+
ctypes.c_uint64,
266+
ctypes.c_uint64,
267+
ctypes.c_void_p,
268+
ctypes.c_void_p,
269+
ctypes.c_void_p,
270+
ctypes.c_void_p,
271+
ctypes.POINTER(ctypes.c_int64),
272+
]
273+
rt.Index_Intersects_id_v.restype = ctypes.c_int
274+
rt.Index_Intersects_id_v.errcheck = check_return # type: ignore
275+
except AttributeError:
276+
pass
277+
278+
241279
rt.Index_GetLeaves.argtypes = [
242280
ctypes.c_void_p,
243281
ctypes.POINTER(ctypes.c_uint32),

rtree/index.py

+109
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,108 @@ def nearest(
10461046

10471047
return self._get_ids(it, p_num_results.contents.value)
10481048

1049+
def intersection_v(self, mins, maxs):
1050+
import numpy as np
1051+
1052+
assert mins.shape == maxs.shape
1053+
assert mins.strides == maxs.strides
1054+
1055+
# Cast
1056+
mins = mins.astype(np.float64)
1057+
maxs = maxs.astype(np.float64)
1058+
1059+
# Extract counts
1060+
n, d = mins.shape
1061+
1062+
# Compute strides
1063+
d_i_stri = mins.strides[0] // mins.itemsize
1064+
d_j_stri = mins.strides[1] // mins.itemsize
1065+
1066+
ids = np.empty(2 * n, dtype=np.int64)
1067+
counts = np.empty(n, dtype=np.uint64)
1068+
nr = ctypes.c_int64(0)
1069+
offn, offi = 0, 0
1070+
1071+
while True:
1072+
core.rt.Index_Intersects_id_v(
1073+
self.handle,
1074+
n - offn,
1075+
d,
1076+
len(ids),
1077+
d_i_stri,
1078+
d_j_stri,
1079+
mins[offn:].ctypes.data,
1080+
maxs[offn:].ctypes.data,
1081+
ids[offi:].ctypes.data,
1082+
counts[offn:].ctypes.data,
1083+
ctypes.byref(nr),
1084+
)
1085+
1086+
# If we got the expected nuber of results then return
1087+
if nr.value == n - offn:
1088+
return ids[: counts.sum()], counts
1089+
# Otherwise, if our array is too small then resize
1090+
else:
1091+
offi += counts[offn : offn + nr.value].sum()
1092+
offn += nr.value
1093+
1094+
ids = ids.resize(2 * len(ids), refcheck=False)
1095+
1096+
def nearest_v(
1097+
self, mins, maxs, num_results=1, strict=False, return_max_dists=False
1098+
):
1099+
import numpy as np
1100+
1101+
assert mins.shape == maxs.shape
1102+
assert mins.strides == maxs.strides
1103+
1104+
# Cast
1105+
mins = mins.astype(np.float64)
1106+
maxs = maxs.astype(np.float64)
1107+
1108+
# Extract counts
1109+
n, d = mins.shape
1110+
1111+
# Compute strides
1112+
d_i_stri = mins.strides[0] // mins.itemsize
1113+
d_j_stri = mins.strides[1] // mins.itemsize
1114+
1115+
ids = np.empty(n * num_results, dtype=np.int64)
1116+
counts = np.empty(n, dtype=np.uint64)
1117+
dists = np.empty(n) if return_max_dists else None
1118+
nr = ctypes.c_int64(0)
1119+
offn, offi = 0, 0
1120+
1121+
while True:
1122+
core.rt.Index_NearestNeighbors_id_v(
1123+
self.handle,
1124+
num_results if not strict else -num_results,
1125+
n - offn,
1126+
d,
1127+
len(ids),
1128+
d_i_stri,
1129+
d_j_stri,
1130+
mins[offn:].ctypes.data,
1131+
maxs[offn:].ctypes.data,
1132+
ids[offi:].ctypes.data,
1133+
counts[offn:].ctypes.data,
1134+
dists[offn:].ctypes.data if return_max_dists else None,
1135+
ctypes.byref(nr),
1136+
)
1137+
1138+
# If we got the expected nuber of results then return
1139+
if nr.value == n - offn:
1140+
if return_max_dists:
1141+
return ids[: counts.sum()], counts, dists
1142+
else:
1143+
return ids[: counts.sum()], counts
1144+
# Otherwise, if our array is too small then resize
1145+
else:
1146+
offi += counts[offn : offn + nr.value].sum()
1147+
offn += nr.value
1148+
1149+
ids = ids.resize(2 * len(ids), refcheck=False)
1150+
10491151
def _nearestTP(self, coordinates, velocities, times, num_results=1, objects=False):
10501152
p_mins, p_maxs = self.get_coordinate_pointers(coordinates)
10511153
pv_mins, pv_maxs = self.get_coordinate_pointers(velocities)
@@ -1538,6 +1640,13 @@ def initialize_from_dict(self, state: dict[str, Any]) -> None:
15381640
if v is not None:
15391641
setattr(self, k, v)
15401642

1643+
# Consistency checks
1644+
if "near_minimum_overlap_factor" not in state:
1645+
nmof = self.near_minimum_overlap_factor
1646+
ilc = min(self.index_capacity, self.leaf_capacity)
1647+
if nmof >= ilc:
1648+
self.near_minimum_overlap_factor = ilc // 3 + 1
1649+
15411650
def __getstate__(self) -> dict[Any, Any]:
15421651
return self.as_dict()
15431652

0 commit comments

Comments
 (0)