diff --git a/movement/kinematics.py b/movement/kinematics.py index f74b50c13..d64676db0 100644 --- a/movement/kinematics.py +++ b/movement/kinematics.py @@ -948,3 +948,59 @@ def _compute_scaled_path_length( valid_proportion = valid_segments / (data.sizes["time"] - 1) # return scaled path length return compute_norm(displacement).sum(dim="time") / valid_proportion + + +def detect_u_turns( + data: xr.DataArray, + use_direction: Literal["forward_vector", "displacement"] = "displacement", + u_turn_threshold: float = np.pi * 5 / 6, # 150 degrees in radians + camera_view: Literal["top_down", "bottom_up"] = "bottom_up", +) -> xr.DataArray: + """Detect U-turn behavior in a trajectory. + + This function computes the directional change between consecutive time + frames and accumulates the rotation angles. If the accumulated angle + exceeds a specified threshold, a U-turn is detected. + + Parameters + ---------- + data : xarray.DataArray + The trajectory data, which must contain the 'time' and 'space' (x, y). + use_direction : Literal["forward_vector", "displacement"], optional + Method to compute direction vectors, default is `"displacement"`: + - `"forward_vector"`: Computes the forward direction vector. + - `"displacement"`: Computes displacement vectors. + u_turn_threshold : float, optional + The angle threshold (in radians) to detect U-turn. Default is (`5π/6`). + + camera_view : Literal["top_down", "bottom_up"], optional + Specifies the camera perspective used for computing direction vectors. + + Returns + ------- + xarray.DataArray + Indicating whether a U-turn has occurred (`True` for a U-turn). + + """ + # Compute direction vectors based on the chosen method + if use_direction == "forward_vector": + direction_vectors = compute_forward_vector( + data, "left_ear", "right_ear", camera_view=camera_view + ) + elif use_direction == "displacement": + direction_vectors = compute_displacement(data) + else: + raise ValueError( + "The parameter `use_direction` must be one of `forward_vector` " + f" or `displacement`, but got {use_direction}." + ) + + angles = compute_signed_angle_2d( + direction_vectors.shift(time=1), direction_vectors + ) + cumulative_rotation = angles.cumsum(dim="time") + rotation_range = cumulative_rotation.max( + dim="time" + ) - cumulative_rotation.min(dim="time") + u_turn_detected = rotation_range >= u_turn_threshold + return u_turn_detected diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 30493a5fc..6a7df0852 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -913,3 +913,80 @@ def test_casts_from_tuple( xr.testing.assert_allclose(pass_numpy, pass_tuple) xr.testing.assert_allclose(pass_numpy, pass_list) + + +@pytest.fixture +def valid_data_array_for_u_turn_detection(): + """Return a position data array for an individual with 3 keypoints + (left ear, right ear, and nose), tracked for 4 frames, in x-y space. + """ + time = [0, 1, 2, 3] + keypoints = ["left_ear", "right_ear", "nose"] + space = ["x", "y"] + + ds = xr.DataArray( + [ + [[-1, 0], [1, 0], [0, 1]], # time 0 + [[0, 2], [0, 0], [1, 1]], # time 1 + [[2, 1], [0, 1], [1, 0]], # time 2 + [[1, -1], [1, 1], [0, 0]], # time 3 + ], + dims=["time", "keypoints", "space"], + coords={ + "time": time, + "keypoints": keypoints, + "space": space, + }, + ) + return ds + + +def test_detect_u_turns(valid_data_array_for_u_turn_detection): + """Test that U-turn detection works correctly using a mock dataset.""" + u_turn_forward_vector = kinematics.detect_u_turns( + valid_data_array_for_u_turn_detection, use_direction="forward_vector" + ) + nose_data = valid_data_array_for_u_turn_detection.sel( + keypoints="nose" + ).drop("keypoints") + u_turn_displacement = kinematics.detect_u_turns( + nose_data, use_direction="displacement" + ) + + # Known expected U-turn detection results + known_u_turn_displacement = np.array( + [True] + ) # Example expected result for displacement + known_u_turn_forward_vector = np.array( + [True] + ) # Example expected result for forward_vector + + assert np.all(u_turn_displacement.values == known_u_turn_displacement) + assert np.all(u_turn_forward_vector.values == known_u_turn_forward_vector) + + u_turn_forward_vector = kinematics.detect_u_turns( + valid_data_array_for_u_turn_detection, + use_direction="forward_vector", + u_turn_threshold=np.pi * 7 / 6, + ) + nose_data = valid_data_array_for_u_turn_detection.sel( + keypoints="nose" + ).drop("keypoints") + u_turn_displacement = kinematics.detect_u_turns( + nose_data, use_direction="displacement", u_turn_threshold=np.pi * 7 / 6 + ) + known_u_turn_displacement = np.array([False]) + known_u_turn_forward_vector = np.array([True]) + assert np.all(u_turn_displacement.values == known_u_turn_displacement) + assert np.all(u_turn_forward_vector.values == known_u_turn_forward_vector) + with pytest.raises( + ValueError, + match=( + "The parameter `use_direction` must be one of `forward_vector` " + " or `displacement`, but got invalid_direction." + ), + ): + kinematics.detect_u_turns( + valid_data_array_for_u_turn_detection, + use_direction="invalid_direction", + )