diff --git a/projectaria_tools/core/mps/utils.py b/projectaria_tools/core/mps/utils.py index 1248559a2..2d3845a80 100644 --- a/projectaria_tools/core/mps/utils.py +++ b/projectaria_tools/core/mps/utils.py @@ -32,26 +32,38 @@ def bisection_timestamp_search(timed_data, query_timestamp_ns: int) -> int: Returns index of the element closest to the query timestamp else returns None if not found (out of time range) """ # Deal with border case - if timed_data and len(timed_data) > 1: - first_timestamp = timed_data[0].tracking_timestamp.total_seconds() * 1e9 - last_timestamp = timed_data[-1].tracking_timestamp.total_seconds() * 1e9 - if query_timestamp_ns <= first_timestamp: - return None - elif query_timestamp_ns >= last_timestamp: - return None - # If this is safe we perform the Bisection search - start = 0 - end = len(timed_data) - 1 + if not timed_data or len(timed_data) < 2: + return None + + # Convert the first and last timestamps for range checks + first_timestamp = timed_data[0].tracking_timestamp.total_seconds() * 1e9 + last_timestamp = timed_data[-1].tracking_timestamp.total_seconds() * 1e9 + + # Handle out-of-range cases + if query_timestamp_ns <= first_timestamp or query_timestamp_ns >= last_timestamp: + return None + + # Perform binary search + start, end = 0, len(timed_data) - 1 while start < end: mid = (start + end) // 2 mid_timestamp = timed_data[mid].tracking_timestamp.total_seconds() * 1e9 + if mid_timestamp == query_timestamp_ns: return mid - if mid_timestamp < query_timestamp_ns: + elif mid_timestamp < query_timestamp_ns: start = mid + 1 else: end = mid - 1 - return start + + # Post-loop adjustment: Compare start and its neighbor + prev_index = max(0, start - 1) + next_index = min(len(timed_data) - 1, start) + + prev_diff = abs(timed_data[prev_index].tracking_timestamp.total_seconds() * 1e9 - query_timestamp_ns) + next_diff = abs(timed_data[next_index].tracking_timestamp.total_seconds() * 1e9 - query_timestamp_ns) + + return prev_index if prev_diff <= next_diff else next_index def get_nearest_eye_gaze(eye_gazes: List[EyeGaze], query_timestamp_ns: int) -> EyeGaze: