diff --git a/examples/trajectory_complexity.py b/examples/trajectory_complexity.py new file mode 100644 index 000000000..e1da2eddc --- /dev/null +++ b/examples/trajectory_complexity.py @@ -0,0 +1,208 @@ +"""Trajectory complexity measures for animal movement paths. +========================================================== + +Compute and visualize various trajectory complexity measures +including straightness index, sinuosity, tortuosity, and more. +""" + +# %% +# Imports +# ------- + +# For interactive plots: install ipympl with `pip install ipympl` and uncomment +# the following line in your notebook +# %matplotlib widget +import numpy as np +from matplotlib import pyplot as plt + +from movement import sample_data +from movement.plots import plot_centroid_trajectory +from movement.trajectory_complexity import ( + compute_angular_velocity, + compute_directional_change, + compute_sinuosity, + compute_straightness_index, + compute_tortuosity, +) + +# %% +# Load sample dataset +# ------------------ +# First, we load an example dataset. In this case, we select the +# ``SLEAP_three-mice_Aeon_proofread`` sample data. +ds = sample_data.fetch_dataset( + "SLEAP_three-mice_Aeon_proofread.analysis.h5", +) + +print(ds) + +# %% +# We'll use the position data for our trajectory complexity analysis +position = ds.position + +# %% +# Plot trajectories +# ---------------- +# First, let's visualize the trajectories of the mice in the XY plane, +# to get a sense of their movement patterns. + +fig, ax = plt.subplots(1, 1, figsize=(10, 8)) +plot_centroid_trajectory( + ds, + ax=ax, + color_by="individual", + plot_markers=False, + alpha=0.7, +) +ax.set_title("Mouse Trajectories") +ax.invert_yaxis() # Make y-axis match image coordinates (0 at top) +plt.tight_layout() + +# %% +# Straightness Index +# ---------------- +# The straightness index is a simple measure of path complexity, defined as the +# ratio of the straight-line distance between start and end points to the total +# path length. Values closer to 1 indicate straighter paths. + +straightness = compute_straightness_index(position) +print("Straightness index by individual:") +for ind in straightness.individual.values: + print(f" {ind}: {straightness.sel(individual=ind).item():.3f}") + +# %% +# Sinuosity +# -------- +# Sinuosity provides a local measure of path complexity using a sliding window. +# It's essentially the inverse of straightness - higher values indicate more +# tortuous paths. + +sinuosity = compute_sinuosity(position, window_size=20) + +# Plot sinuosity over time for each individual +fig, ax = plt.subplots(1, 1, figsize=(12, 6)) +for ind in sinuosity.individual.values: + ax.plot(sinuosity.time, sinuosity.sel(individual=ind), label=ind) +ax.set_xlabel("Time (s)") +ax.set_ylabel("Sinuosity") +ax.set_title("Sinuosity over time (window size = 20 frames)") +ax.legend() +plt.tight_layout() + +# %% +# Angular Velocity +# -------------- +# Angular velocity measures the rate of change in direction. Higher values +# indicate sharper turns or changes in direction. + +ang_vel = compute_angular_velocity(position, in_degrees=True) + +# Plot angular velocity over time for each individual +fig, ax = plt.subplots(1, 1, figsize=(12, 6)) +for ind in ang_vel.individual.values: + ax.plot( + ang_vel.time, + ang_vel.sel(individual=ind), + label=ind, + alpha=0.7, + ) +ax.set_xlabel("Time (s)") +ax.set_ylabel("Angular velocity (degrees)") +ax.set_title("Angular velocity over time") +ax.legend() +plt.tight_layout() + +# %% +# Tortuosity +# --------- +# Tortuosity measures the degree of winding or twisting of a path. +# Here we use two different methods: fractal dimension and angular variance. + +# Compute tortuosity using angular variance method +tort_ang = compute_tortuosity(position, method="angular_variance") +print("Tortuosity (angular variance) by individual:") +for ind in tort_ang.individual.values: + print(f" {ind}: {tort_ang.sel(individual=ind).item():.3f}") + +# Compute tortuosity using fractal dimension method +tort_frac = compute_tortuosity(position, method="fractal") +print("\nTortuosity (fractal dimension) by individual:") +for ind in tort_frac.individual.values: + print(f" {ind}: {tort_frac.sel(individual=ind).item():.3f}") + +# %% +# Directional Change +# ---------------- +# Directional change measures the total amount of turning within a window. +# Higher values indicate more meandering behavior. + +dir_change = compute_directional_change( + position, window_size=20, in_degrees=True +) + +# Plot directional change over time for each individual +fig, ax = plt.subplots(1, 1, figsize=(12, 6)) +for ind in dir_change.individual.values: + ax.plot( + dir_change.time, + dir_change.sel(individual=ind), + label=ind, + alpha=0.7, + ) +ax.set_xlabel("Time (s)") +ax.set_ylabel("Directional change (degrees)") +ax.set_title("Directional change over time (window size = 20 frames)") +ax.legend() +plt.tight_layout() + +# %% +# Compare measures across individuals +# --------------------------------- +# Let's create a summary bar plot to compare different trajectory complexity measures +# across individuals. + +# Collect measures for each individual +individuals = position.individual.values +measures = { + "Straightness Index": straightness, + "Tortuosity (Angular)": tort_ang, + "Tortuosity (Fractal)": tort_frac, +} + +# Create bar plot +fig, ax = plt.subplots(1, 1, figsize=(10, 6)) +x = np.arange(len(individuals)) +width = 0.25 +multiplier = 0 + +for measure_name, measure_data in measures.items(): + offset = width * multiplier + rects = ax.bar( + x + offset, + [measure_data.sel(individual=ind).item() for ind in individuals], + width, + label=measure_name, + ) + multiplier += 1 + +ax.set_xticks(x + width, individuals) +ax.set_ylabel("Value") +ax.set_title("Comparison of Trajectory Complexity Measures") +ax.legend(loc="upper left", bbox_to_anchor=(1, 1)) +plt.tight_layout() + +# %% +# Conclusion +# --------- +# These trajectory complexity measures provide different ways to quantify and +# compare animal movement patterns. The choice of measure depends on the specific +# research question: +# +# - **Straightness Index**: Best for overall path directness +# - **Sinuosity**: Good for local path complexity that varies over time +# - **Angular Velocity**: Useful for identifying sharp turns or directional changes +# - **Tortuosity**: Captures the overall winding nature of the path +# - **Directional Change**: Quantifies turning behavior within a time window +# +# By combining these measures, researchers can gain insights into various aspects +# of animal movement behavior. diff --git a/movement/__init__.py b/movement/__init__.py index bf5d4a2d2..385598e6e 100644 --- a/movement/__init__.py +++ b/movement/__init__.py @@ -15,3 +15,12 @@ # initialize logger upon import configure_logging() + +# Import trajectory complexity functions to make them available at package level +from movement.trajectory_complexity import ( + compute_straightness_index, + compute_sinuosity, + compute_tortuosity, + compute_angular_velocity, + compute_directional_change, +) diff --git a/movement/trajectory_complexity.py b/movement/trajectory_complexity.py new file mode 100644 index 000000000..5e18ab80d --- /dev/null +++ b/movement/trajectory_complexity.py @@ -0,0 +1,487 @@ +"""Compute trajectory complexity measures. + +This module provides functions to compute various measures of trajectory +complexity, which quantify how straight or tortuous a path is. These metrics +are useful for analyzing animal movement patterns across space. +""" + +from typing import Literal + +import numpy as np +import xarray as xr + +from movement.kinematics import compute_displacement, compute_path_length +from movement.utils.logging import log_error, log_to_attrs, log_warning +from movement.utils.vector import compute_norm +from movement.validators.arrays import validate_dims_coords + + +@log_to_attrs +def compute_straightness_index( + data: xr.DataArray, + start: float | None = None, + stop: float | None = None, +) -> xr.DataArray: + """Compute the straightness index of a trajectory. + + The straightness index is defined as the ratio of the Euclidean distance + between the start and end points of a trajectory to the total path length. + Values range from 0 to 1, where 1 indicates a perfectly straight path, + and values closer to 0 indicate more tortuous paths. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + start : float, optional + The start time of the trajectory. If None (default), + the minimum time coordinate in the data is used. + stop : float, optional + The end time of the trajectory. If None (default), + the maximum time coordinate in the data is used. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed straightness index, + with dimensions matching those of the input data, + except ``time`` and ``space`` are removed. + + Notes + ----- + The straightness index (SI) is calculated as: + + SI = Euclidean distance / Path length + + where the Euclidean distance is the straight-line distance between the + start and end points, and the path length is the total distance traveled + along the trajectory. + + References + ---------- + .. [1] Batschelet, E. (1981). Circular statistics in biology. + London: Academic Press. + + """ + validate_dims_coords(data, {"time": [], "space": []}) + + # Determine start and stop points + if start is None: + start = data.time.min().item() + if stop is None: + stop = data.time.max().item() + + # Extract start and end positions + start_pos = data.sel(time=start, method="nearest") + end_pos = data.sel(time=stop, method="nearest") + + # Calculate Euclidean distance between start and end points + euclidean_distance = compute_norm(end_pos - start_pos) + + # Calculate path length + path_length = compute_path_length(data, start=start, stop=stop) + + # Compute straightness index + straightness_index = euclidean_distance / path_length + + return straightness_index + + +@log_to_attrs +def compute_sinuosity( + data: xr.DataArray, + window_size: int = 10, + stride: int = 1, +) -> xr.DataArray: + """Compute the sinuosity of a trajectory using a sliding window. + + Sinuosity is computed as the ratio of the path length to the Euclidean distance + between the start and end points, within each window. This provides a + local measure of trajectory complexity that varies along the path. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + window_size : int, optional + The size of the sliding window in number of time points. + Default is 10. + stride : int, optional + The step size for the sliding window. Default is 1. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed sinuosity at each time point, + with dimensions matching those of the input data, + except ``space`` is removed. + + Notes + ----- + The sinuosity is essentially the inverse of the straightness index. + Values range from 1 to infinity, where 1 indicates a perfectly straight path, + and higher values indicate more tortuous paths. + + References + ---------- + .. [1] Benhamou, S. (2004). How to reliably estimate the tortuosity of + an animal's path: straightness, sinuosity, or fractal dimension? + Journal of Theoretical Biology, 229(2), 209-220. + + """ + validate_dims_coords(data, {"time": [], "space": []}) + + # Validate window_size + if window_size < 2: + raise log_error( + ValueError, + "window_size must be at least 2 to compute sinuosity.", + ) + + # Get number of time points + n_time = data.sizes["time"] + + # Initialize result array with NaNs + result = xr.full_like(data.isel(space=0), fill_value=np.nan) + + # Calculate sinuosity for each window + for i in range(0, n_time - window_size + 1, stride): + # Extract window data + window_data = data.isel(time=slice(i, i + window_size)) + + # Extract start and end positions + start_pos = window_data.isel(time=0) + end_pos = window_data.isel(time=-1) + + # Calculate Euclidean distance between start and end points + euclidean_distance = compute_norm(end_pos - start_pos) + + # Calculate path length within window + displacements = compute_displacement(window_data).isel( + time=slice(1, None) + ) + path_length = compute_norm(displacements).sum(dim="time") + + # Compute sinuosity (inverse of straightness) + sinuosity = path_length / euclidean_distance + + # Assign to middle point of window + mid_idx = i + window_size // 2 + result.isel(time=mid_idx).data = sinuosity.data + + return result + + +@log_to_attrs +def compute_angular_velocity( + data: xr.DataArray, + in_degrees: bool = False, +) -> xr.DataArray: + """Compute the angular velocity of a trajectory. + + Angular velocity measures the rate of change of the angle of movement. + It is computed as the angle between consecutive displacement vectors. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + in_degrees : bool, optional + Whether to return the result in degrees (True) or radians (False). + Default is False. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed angular velocity, + with dimensions matching those of the input data, + except ``space`` is removed. + + Notes + ----- + Angular velocity is defined as the angle between consecutive displacement + vectors divided by the time interval. High angular velocities indicate + sharp turns or changes in direction. + + """ + validate_dims_coords(data, {"time": [], "space": []}) + + # Compute displacement vectors + displacement = compute_displacement(data) + + # Skip first time point (displacement is 0) + displacement = displacement.isel(time=slice(1, None)) + + # Compute unit vectors (normalize displacement) + unit_displacement = displacement / compute_norm(displacement).fillna(1) + + # Compute dot products between consecutive unit vectors + dot_products = ( + unit_displacement.isel(time=slice(1, None)) + * unit_displacement.isel(time=slice(0, -1)) + ).sum(dim="space") + + # Clip dot products to [-1, 1] to handle numerical errors + dot_products = xr.where(dot_products > 1, 1, dot_products) + dot_products = xr.where(dot_products < -1, -1, dot_products) + + # Compute angles in radians + angles = np.arccos(dot_products) + + # Convert to degrees if requested + if in_degrees: + angles = np.rad2deg(angles) + + # Create result array with same dimensions as input but with NaNs at endpoints + result = xr.full_like(data.isel(space=0), fill_value=np.nan) + + # Assign computed angles to result (offset by 1 to account for displacement calculation) + result.isel(time=slice(2, None)).data = angles.data + + return result + + +@log_to_attrs +def compute_tortuosity( + data: xr.DataArray, + start: float | None = None, + stop: float | None = None, + method: Literal["fractal", "angular_variance"] = "angular_variance", + window_size: int = 10, +) -> xr.DataArray: + """Compute the tortuosity of a trajectory. + + Tortuosity is a measure of the degree of winding or twisting of a path. + This function provides multiple methods to compute tortuosity. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + start : float, optional + The start time of the trajectory. If None (default), + the minimum time coordinate in the data is used. + stop : float, optional + The end time of the trajectory. If None (default), + the maximum time coordinate in the data is used. + method : Literal["fractal", "angular_variance"], optional + The method to use for computing tortuosity. + "fractal" uses box-counting fractal dimension. + "angular_variance" uses the circular variance of turning angles. + Default is "angular_variance". + window_size : int, optional + The size of the window used for the fractal method. + Default is 10. Only used if method="fractal". + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed tortuosity, + with dimensions matching those of the input data, + except ``time`` and ``space`` are removed. + + Notes + ----- + The "fractal" method estimates the fractal dimension of the path using + the box-counting method. It ranges from 1 (straight line) to 2 (completely + space-filling curve). + + The "angular_variance" method computes the circular variance of turning + angles. It ranges from 0 (straight line) to 1 (highly tortuous path with + uniformly distributed turning angles). + + References + ---------- + .. [1] Nams, V. O. (1996). The VFractal: a new estimator for fractal + dimension of animal movement paths. Landscape Ecology, 11(5), 289-297. + .. [2] Benhamou, S. (2004). How to reliably estimate the tortuosity of + an animal's path: straightness, sinuosity, or fractal dimension? + Journal of Theoretical Biology, 229(2), 209-220. + + """ + validate_dims_coords(data, {"time": [], "space": []}) + + # Determine start and stop points + if start is None: + start = data.time.min().item() + if stop is None: + stop = data.time.max().item() + + # Filter data to desired time range + data_filtered = data.sel(time=slice(start, stop)) + + if method == "angular_variance": + # Compute displacement vectors + displacement = compute_displacement(data_filtered) + + # Skip first time point (displacement is 0) + displacement = displacement.isel(time=slice(1, None)) + + # Compute unit vectors (normalize displacement) + unit_displacement = displacement / compute_norm(displacement).fillna(1) + + # Compute dot products between consecutive unit vectors + dot_products = ( + unit_displacement.isel(time=slice(1, None)) + * unit_displacement.isel(time=slice(0, -1)) + ).sum(dim="space") + + # Clip dot products to [-1, 1] to handle numerical errors + dot_products = xr.where(dot_products > 1, 1, dot_products) + dot_products = xr.where(dot_products < -1, -1, dot_products) + + # Compute angles in radians + angles = np.arccos(dot_products) + + # Compute circular mean of cosines and sines of angles + mean_cos = np.cos(angles).mean(dim="time") + mean_sin = np.sin(angles).mean(dim="time") + + # Compute circular variance (R = 1 - mean resultant length) + R = 1 - np.sqrt(mean_cos**2 + mean_sin**2) + + return R + + elif method == "fractal": + # Implementing a simplified box-counting fractal dimension + if len(data_filtered.space) != 2: + raise log_error( + ValueError, + "The fractal dimension method only works with 2D data.", + ) + + # Get x and y coordinates + x = data_filtered.sel(space="x").values + y = data_filtered.sel(space="y").values + + # Normalize to [0, 1] range + x_norm = (x - np.min(x)) / (np.max(x) - np.min(x)) + y_norm = (y - np.min(y)) / (np.max(y) - np.min(y)) + + # Initialize arrays for box counts + scales = [] + counts = [] + + # Calculate box counts at different scales + for scale in range(2, min(64, len(x_norm) // 4)): + # Create grid + grid_size = 1.0 / scale + occupied_boxes = set() + + # Count occupied boxes + for i in range(len(x_norm)): + box_x = int(x_norm[i] / grid_size) + box_y = int(y_norm[i] / grid_size) + occupied_boxes.add((box_x, box_y)) + + scales.append(scale) + counts.append(len(occupied_boxes)) + + if len(scales) < 2: + log_warning( + "Not enough data points to compute fractal dimension. Returning NaN." + ) + result = xr.full_like( + data_filtered.isel(time=0, space=0, drop=True), + fill_value=np.nan, + ) + return result + + # Compute fractal dimension as the slope of log(count) vs log(scale) + log_scales = np.log(scales) + log_counts = np.log(counts) + + # Linear regression: log(count) = D * log(scale) + b + D = np.polyfit(log_scales, log_counts, 1)[0] + + # Create result array + # Remove time and space dimensions + dims = [ + dim for dim in data_filtered.dims if dim not in ["time", "space"] + ] + coords = {dim: data_filtered[dim] for dim in dims} + + result = xr.DataArray( + D, + dims=dims, + coords=coords, + ) + + return result + else: + raise log_error( + ValueError, + f"Unknown method: {method}. Use 'fractal' or 'angular_variance'.", + ) + + +@log_to_attrs +def compute_directional_change( + data: xr.DataArray, + window_size: int = 10, + in_degrees: bool = False, +) -> xr.DataArray: + """Compute the directional change along a trajectory. + + Directional change measures the total amount of turning within a window, + calculated as the sum of absolute angular changes between consecutive + displacement vectors. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + window_size : int, optional + The size of the sliding window in number of time points. + Default is 10. + in_degrees : bool, optional + Whether to return the result in degrees (True) or radians (False). + Default is False. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed directional change at each time point, + with dimensions matching those of the input data, + except ``space`` is removed. + + Notes + ----- + Directional change is calculated as the sum of absolute angular changes + within a sliding window. Higher values indicate more turning or meandering + behavior, while lower values indicate more directed movement. + + """ + validate_dims_coords(data, {"time": [], "space": []}) + + # Compute angular velocity (in radians) + angular_velocity = compute_angular_velocity(data, in_degrees=False) + + # Get number of time points + n_time = data.sizes["time"] + + # Initialize result array with NaNs + result = xr.full_like(data.isel(space=0), fill_value=np.nan) + + # Calculate directional change for each window + for i in range(0, n_time - window_size + 1): + # Extract window data + window_data = angular_velocity.isel(time=slice(i, i + window_size)) + + # Sum absolute angular changes + directional_change = np.nansum(np.abs(window_data.values)) + + # Convert to degrees if requested + if in_degrees: + directional_change = np.rad2deg(directional_change) + + # Assign to middle point of window + mid_idx = i + window_size // 2 + result.isel(time=mid_idx).data = directional_change + + return result diff --git a/tests/test_unit/test_trajectory_complexity.py b/tests/test_unit/test_trajectory_complexity.py new file mode 100644 index 000000000..6fb62c8a9 --- /dev/null +++ b/tests/test_unit/test_trajectory_complexity.py @@ -0,0 +1,204 @@ +"""Unit tests for the trajectory complexity module.""" + +import numpy as np +import pytest +import xarray as xr + +from movement.trajectory_complexity import ( + compute_angular_velocity, + compute_directional_change, + compute_sinuosity, + compute_straightness_index, + compute_tortuosity, +) + + +@pytest.fixture +def straight_trajectory(): + """Create a straight line trajectory.""" + position = np.zeros((20, 2, 1, 1)) + # x-coordinate increases linearly from 0 to 19 + position[:, 0, 0, 0] = np.arange(20) + # y-coordinate stays at 0 + position[:, 1, 0, 0] = 0 + + return xr.DataArray( + position, + dims=["time", "space", "keypoints", "individual"], + coords={ + "time": np.arange(20) / 10, # time in seconds + "space": ["x", "y"], + "keypoints": ["centroid"], + "individual": ["test_subject"], + }, + ) + + +@pytest.fixture +def zigzag_trajectory(): + """Create a zigzag trajectory.""" + position = np.zeros((20, 2, 1, 1)) + # x-coordinate increases linearly from 0 to 19 + position[:, 0, 0, 0] = np.arange(20) + # y-coordinate alternates between -1 and 1 + position[:, 1, 0, 0] = np.sin(np.arange(20) * np.pi / 2) + + return xr.DataArray( + position, + dims=["time", "space", "keypoints", "individual"], + coords={ + "time": np.arange(20) / 10, # time in seconds + "space": ["x", "y"], + "keypoints": ["centroid"], + "individual": ["test_subject"], + }, + ) + + +@pytest.fixture +def circular_trajectory(): + """Create a circular trajectory.""" + position = np.zeros((20, 2, 1, 1)) + # x and y follow a circular path + t = np.linspace(0, 2 * np.pi, 20) + position[:, 0, 0, 0] = 5 * np.cos(t) + 10 # center at x=10 + position[:, 1, 0, 0] = 5 * np.sin(t) + 10 # center at y=10 + + return xr.DataArray( + position, + dims=["time", "space", "keypoints", "individual"], + coords={ + "time": np.arange(20) / 10, # time in seconds + "space": ["x", "y"], + "keypoints": ["centroid"], + "individual": ["test_subject"], + }, + ) + + +def test_straightness_index_straight_line(straight_trajectory): + """Test that a straight line has straightness index close to 1.""" + result = compute_straightness_index(straight_trajectory) + # Should be very close to 1 for a straight line + assert ( + result.sel(keypoints="centroid", individual="test_subject").item() + > 0.99 + ) + + +def test_straightness_index_zigzag(zigzag_trajectory): + """Test that a zigzag path has straightness index less than 1.""" + result = compute_straightness_index(zigzag_trajectory) + # Should be less than 1 for a zigzag path + assert ( + result.sel(keypoints="centroid", individual="test_subject").item() + < 0.9 + ) + + +def test_straightness_index_circle(circular_trajectory): + """Test that a circular path that returns to start has low straightness.""" + result = compute_straightness_index(circular_trajectory) + # Should be very low for a circle that nearly returns to starting point + assert ( + result.sel(keypoints="centroid", individual="test_subject").item() + < 0.2 + ) + + +def test_sinuosity_straight_line(straight_trajectory): + """Test that a straight line has sinuosity close to 1.""" + result = compute_sinuosity(straight_trajectory, window_size=5) + # Take the mean sinuosity over valid time points + mean_sinuosity = result.sel( + keypoints="centroid", individual="test_subject" + ).mean(skipna=True) + # Should be very close to 1 for a straight line + assert mean_sinuosity.item() < 1.1 + + +def test_sinuosity_zigzag(zigzag_trajectory): + """Test that a zigzag path has sinuosity greater than 1.""" + result = compute_sinuosity(zigzag_trajectory, window_size=5) + # Take the mean sinuosity over valid time points + mean_sinuosity = result.sel( + keypoints="centroid", individual="test_subject" + ).mean(skipna=True) + # Should be greater than 1 for a zigzag path + assert mean_sinuosity.item() > 1.1 + + +def test_angular_velocity_straight_line(straight_trajectory): + """Test that a straight line has angular velocity close to 0.""" + result = compute_angular_velocity(straight_trajectory) + # All angular velocities should be close to zero for a straight line + max_ang_vel = np.nanmax( + result.sel(keypoints="centroid", individual="test_subject").values + ) + assert max_ang_vel < 1e-10 + + +def test_angular_velocity_zigzag(zigzag_trajectory): + """Test that a zigzag path has non-zero angular velocity.""" + result = compute_angular_velocity(zigzag_trajectory) + # Should have some non-zero angular velocities + max_ang_vel = np.nanmax( + result.sel(keypoints="centroid", individual="test_subject").values + ) + assert max_ang_vel > 0.1 + + +def test_tortuosity_angular_variance_straight(straight_trajectory): + """Test that a straight line has low angular variance tortuosity.""" + result = compute_tortuosity(straight_trajectory, method="angular_variance") + # Angular variance should be very close to 0 for a straight line + assert ( + result.sel(keypoints="centroid", individual="test_subject").item() + < 0.1 + ) + + +def test_tortuosity_angular_variance_zigzag(zigzag_trajectory): + """Test that a zigzag path has higher angular variance tortuosity.""" + result = compute_tortuosity(zigzag_trajectory, method="angular_variance") + # Angular variance should be higher for a zigzag path + assert ( + result.sel(keypoints="centroid", individual="test_subject").item() + > 0.1 + ) + + +def test_tortuosity_fractal_straight(straight_trajectory): + """Test that a straight line has fractal dimension close to 1.""" + result = compute_tortuosity(straight_trajectory, method="fractal") + # Fractal dimension should be close to 1 for a straight line + tort = result.sel(keypoints="centroid", individual="test_subject").item() + assert 0.9 < tort < 1.1 + + +def test_tortuosity_fractal_zigzag(zigzag_trajectory): + """Test that a zigzag path has fractal dimension greater than 1.""" + result = compute_tortuosity(zigzag_trajectory, method="fractal") + # Fractal dimension should be greater than 1 for a zigzag path + tort = result.sel(keypoints="centroid", individual="test_subject").item() + assert tort > 1.1 + + +def test_directional_change_straight(straight_trajectory): + """Test that a straight line has directional change close to 0.""" + result = compute_directional_change(straight_trajectory, window_size=5) + # Sum of angular changes should be close to 0 for a straight line + mean_dir_change = result.sel( + keypoints="centroid", individual="test_subject" + ).mean(skipna=True) + assert mean_dir_change.item() < 1e-10 + + +def test_directional_change_zigzag(zigzag_trajectory): + """Test that a zigzag path has non-zero directional change.""" + result = compute_directional_change(zigzag_trajectory, window_size=5) + # Sum of angular changes should be non-zero for a zigzag path + mean_dir_change = result.sel( + keypoints="centroid", individual="test_subject" + ).mean(skipna=True) + assert mean_dir_change.item() > 0.1