Review notes (free text before the first "diff --git" header; patch tools skip it).

This patch was recovered from a copy whose newlines and indentation had been collapsed.
Line breaks and hunk line counts were re-derived from the +/- markers; the indentation of
context lines is inferred from the Python structure and should be confirmed against the tree.

Changes made during review:
  1. _euclidean_pairwise_mean_within averages over the full n x n matrix (self-pairs included),
     so it agrees with pairwise_distances(X, X).mean() used by the non-euclidean branch, by the
     code this patch replaces, and by Edistance.from_precomputed.
  2. pairwise_distance_mean takes the numba path only for numpy arrays; any other input is
     handed to pairwise_distances as before the patch.
  3. The _m_step override in _mixscape.py forwards xp only when the caller supplied it.

diff --git a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py
index 805a02ef..948a4a59 100644
--- a/pertpy/tools/_distances/_distances.py
+++ b/pertpy/tools/_distances/_distances.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 
+import warnings
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Literal, NamedTuple
 
 import jax
 import numpy as np
 import pandas as pd
-from numba import jit
+from numba import jit, prange
 from ott.geometry.geometry import Geometry
 from ott.geometry.pointcloud import PointCloud
 from ott.problems.linear.linear_problem import LinearProblem
@@ -29,6 +30,85 @@
     from anndata import AnnData
 
 
+@jit(nopython=True, cache=True)
+def _euclidean_distance(x: np.ndarray, y: np.ndarray) -> float:
+    """Compute euclidean distance between two vectors."""
+    dist_sq = 0.0
+    for k in range(x.shape[0]):
+        diff = x[k] - y[k]
+        dist_sq += diff * diff
+    return np.sqrt(dist_sq)
+
+
+@jit(nopython=True, parallel=True, cache=True, fastmath=True)
+def _euclidean_pairwise_mean_within(X: np.ndarray) -> float:
+    """Compute the mean of the full euclidean distance matrix of a group to itself (X to X, self-pairs included)."""
+    n_samples = X.shape[0]
+    if n_samples < 2:
+        return 0.0
+
+    total_distance = 0.0
+    n_entries = n_samples * n_samples  # size of the full n x n matrix, zero diagonal included
+
+    for i in prange(n_samples):
+        for j in range(i + 1, n_samples):
+            total_distance += _euclidean_distance(X[i], X[j])
+
+    return 2.0 * total_distance / n_entries  # only i < j was summed; each such pair occurs twice in the matrix
+
+
+@jit(nopython=True, parallel=True, cache=True, fastmath=True)
+def _euclidean_pairwise_mean_between(X: np.ndarray, Y: np.ndarray) -> float:
+    """Compute mean pairwise euclidean distance between two groups (X to Y)."""
+    n_samples_X = X.shape[0]
+    n_samples_Y = Y.shape[0]
+
+    if n_samples_X == 0 or n_samples_Y == 0:
+        return 0.0
+
+    total_distance = 0.0
+    n_pairs = n_samples_X * n_samples_Y
+
+    for i in prange(n_samples_X):
+        for j in range(n_samples_Y):
+            total_distance += _euclidean_distance(X[i], Y[j])
+
+    return total_distance / n_pairs
+
+
+def pairwise_distance_mean(X: np.ndarray, Y: np.ndarray | None = None, metric: str = "euclidean", **kwargs) -> float:
+    """Compute the mean pairwise distance; euclidean input given as numpy arrays avoids building the matrix.
+
+    If Y is None, the mean is taken over the full X-to-X matrix (self-pairs included) for every metric.
+
+    Args:
+        X: First array of shape (n_samples_X, n_features).
+        Y: Second array of shape (n_samples_Y, n_features). If None, computes within-group distances.
+        metric: Distance metric. Only "euclidean" on numpy arrays uses the numba kernels; all else is delegated.
+        kwargs: Passed on to pairwise_distances; ignored (with a UserWarning) on the numba path.
+
+    Returns:
+        Mean pairwise distance.
+    """
+    if metric == "euclidean" and isinstance(X, np.ndarray) and (Y is None or isinstance(Y, np.ndarray)):
+        if len(kwargs) > 0:
+            warnings.warn(
+                "kwargs are not used for euclidean distance.",
+                UserWarning,
+                stacklevel=2,
+            )
+        if Y is None:
+            # Within-group distance (X to X)
+            return _euclidean_pairwise_mean_within(X)
+        else:
+            # Between-group distance (X to Y)
+            return _euclidean_pairwise_mean_between(X, Y)
+    elif Y is None:
+        return pairwise_distances(X, X, metric=metric, **kwargs).mean()
+    else:
+        return pairwise_distances(X, Y, metric=metric, **kwargs).mean()
+
+
 class MeanVar(NamedTuple):
     mean: float
     variance: float
@@ -327,12 +407,49 @@ def pairwise(
         df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
         fct = track if show_progressbar else lambda iterable: iterable
 
-        # Some metrics are able to handle precomputed distances. This means that
-        # the pairwise distances between all cells are computed once and then
-        # passed to the metric function. This is much faster than computing the
-        # pairwise distances for each group separately. Other metrics are not
-        # able to handle precomputed distances such as the PseudobulkDistance.
-        if self.metric_fct.accepts_precomputed:
+        # Check if metric supports value caching (within/between distances) - more efficient than precomputed matrix
+        # This mode is incompatible with bootstrap since cached values would be invalid
+        use_value_cache = self.metric_fct.supports_value_cache() and not bootstrap
+
+        if use_value_cache:
+            # Value caching mode: precompute within distances per group and between distances per pair
+            embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key]
+
+            # Precompute within distances for each group
+            df_within = pd.Series(index=groups, dtype=float)
+            for group in fct(groups):
+                idx_group = grouping == group
+                cells_group = embedding[np.asarray(idx_group)]
+                df_within[group] = self.metric_fct.compute_within_distance(cells_group, **kwargs)
+
+            # Precompute between distances for each pair
+            df_between = pd.DataFrame(index=groups, columns=groups, dtype=float)
+            for index_x, group_x in enumerate(fct(groups)):
+                idx_x = grouping == group_x
+                cells_x = embedding[np.asarray(idx_x)]
+                for group_y in groups[index_x:]:  # type: ignore
+                    if group_x == group_y:
+                        df_between.loc[group_x, group_y] = 0.0
+                    else:
+                        idx_y = grouping == group_y
+                        cells_y = embedding[np.asarray(idx_y)]
+                        between = self.metric_fct.compute_between_distance(cells_x, cells_y, **kwargs)
+                        df_between.loc[group_x, group_y] = between
+                        df_between.loc[group_y, group_x] = between
+
+            # Compute distances from cached values
+            for group_x in groups:
+                for group_y in groups:
+                    if group_x == group_y:
+                        df.loc[group_x, group_y] = 0.0
+                    else:
+                        dist = self.metric_fct.from_cached_values(
+                            df_within[group_x], df_within[group_y], df_between.loc[group_x, group_y], **kwargs
+                        )
+                        df.loc[group_x, group_y] = dist
+
+        elif self.metric_fct.accepts_precomputed:
+            # Precomputed pairwise distance matrix mode
             # Precompute the pairwise distances if needed
             if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp:
                 self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
@@ -364,6 +481,7 @@ def pairwise(
                         df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
                         df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
         else:
+            # Standard mode: compute distances directly
             embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy()
             for index_x, group_x in enumerate(fct(groups)):
                 cells_x = embedding[np.asarray(grouping == group_x)].copy()
@@ -461,12 +579,39 @@ def onesided_distances(
         df_var = pd.Series(index=groups, dtype=float)
         fct = track if show_progressbar else lambda iterable: iterable
 
-        # Some metrics are able to handle precomputed distances. This means that
-        # the pairwise distances between all cells are computed once and then
-        # passed to the metric function. This is much faster than computing the
-        # pairwise distances for each group separately. Other metrics are not
-        # able to handle precomputed distances such as the PseudobulkDistance.
-        if self.metric_fct.accepts_precomputed:
+        # Check if metric supports value caching (within/between distances) - more efficient than precomputed matrix
+        # This mode is incompatible with bootstrap since cached values would be invalid
+        use_value_cache = self.metric_fct.supports_value_cache() and not bootstrap
+
+        if use_value_cache:
+            # Value caching mode: precompute within distances per group and between distances per pair
+            embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key]
+
+            # Precompute within distance for selected_group (only need it once)
+            idx_selected = grouping == selected_group
+            cells_selected = embedding[np.asarray(idx_selected)]
+            within_selected = self.metric_fct.compute_within_distance(cells_selected, **kwargs)
+
+            # Precompute within distances for each group and between distances to selected_group
+            for group_x in fct(groups):
+                if group_x == selected_group:
+                    df.loc[group_x] = 0.0  # by distance axiom
+                else:
+                    idx_x = grouping == group_x
+                    cells_x = embedding[np.asarray(idx_x)]
+
+                    # Compute within distance for this group
+                    within_x = self.metric_fct.compute_within_distance(cells_x, **kwargs)
+
+                    # Compute between distance to selected_group
+                    between = self.metric_fct.compute_between_distance(cells_x, cells_selected, **kwargs)
+
+                    # Compute distance from cached values
+                    dist = self.metric_fct.from_cached_values(within_x, within_selected, between, **kwargs)
+                    df.loc[group_x] = dist
+
+        elif self.metric_fct.accepts_precomputed:
+            # Precomputed pairwise distance matrix mode
             # Precompute the pairwise distances if needed
             if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp:
                 self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
@@ -495,6 +640,7 @@ def onesided_distances(
                         df.loc[group_x] = bootstrap_output.mean
                         df_var.loc[group_x] = bootstrap_output.variance
         else:
+            # Standard mode: compute distances directly
             embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy()
             for group_x in fct(groups):
                 cells_x = embedding[np.asarray(grouping == group_x)].copy()
@@ -655,6 +801,61 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         """
         raise NotImplementedError("Metric class is abstract.")
 
+    def supports_value_cache(self) -> bool:
+        """Whether this metric supports value-level caching (within/between distances).
+
+        Returns:
+            bool: True if value caching is supported, False otherwise.
+        """
+        return False
+
+    def compute_within_distance(self, X: np.ndarray, **kwargs) -> float:
+        """Compute within-group distance statistic for caching.
+
+        Only called if supports_value_cache() returns True.
+        This represents the mean pairwise distance within a single group.
+
+        Args:
+            X: Array of shape (n_samples, n_features) for a single group.
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            float: Cached within-group distance statistic.
+        """
+        raise NotImplementedError("Metric does not support value caching.")
+
+    def compute_between_distance(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        """Compute between-group distance statistic for caching.
+
+        Only called if supports_value_cache() returns True.
+        This represents the mean pairwise distance between two groups.
+
+        Args:
+            X: First array of shape (n_samples, n_features).
+            Y: Second array of shape (n_samples, n_features).
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            float: Cached between-group distance statistic.
+        """
+        raise NotImplementedError("Metric does not support value caching.")
+
+    def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
+        """Compute distance using precomputed cached values.
+
+        Only called if supports_value_cache() returns True and values have been cached.
+
+        Args:
+            within_X: Precomputed within-group distance for group X.
+            within_Y: Precomputed within-group distance for group Y.
+            between: Precomputed between-group distance for pair (X, Y).
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            float: Distance between X and Y.
+        """
+        raise NotImplementedError("Metric does not support value caching.")
+
 
 class Edistance(AbstractDistance):
     """Edistance metric."""
@@ -665,16 +866,32 @@ def __init__(self) -> None:
         self.cell_wise_metric = "euclidean"
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        sigma_X = pairwise_distances(X, X, metric=self.cell_wise_metric, **kwargs).mean()
-        sigma_Y = pairwise_distances(Y, Y, metric=self.cell_wise_metric, **kwargs).mean()
-        delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean()
-        return 2 * delta - sigma_X - sigma_Y
+        within_X = pairwise_distance_mean(X, metric=self.cell_wise_metric, **kwargs)
+        within_Y = pairwise_distance_mean(Y, metric=self.cell_wise_metric, **kwargs)
+        between = pairwise_distance_mean(X, Y, metric=self.cell_wise_metric, **kwargs)
+        return 2 * between - within_X - within_Y
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
-        sigma_X = P[idx, :][:, idx].mean()
-        sigma_Y = P[~idx, :][:, ~idx].mean()
-        delta = P[idx, :][:, ~idx].mean()
-        return 2 * delta - sigma_X - sigma_Y
+        within_X = P[idx, :][:, idx].mean()
+        within_Y = P[~idx, :][:, ~idx].mean()
+        between = P[idx, :][:, ~idx].mean()
+        return 2 * between - within_X - within_Y
+
+    def supports_value_cache(self) -> bool:
+        """Edistance benefits from caching within and between distances."""
+        return True
+
+    def compute_within_distance(self, X: np.ndarray, **kwargs) -> float:
+        """Compute within-group distance (mean pairwise distance within group)."""
+        return pairwise_distance_mean(X, metric=self.cell_wise_metric, **kwargs)
+
+    def compute_between_distance(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        """Compute between-group distance (mean pairwise distance between groups)."""
+        return pairwise_distance_mean(X, Y, metric=self.cell_wise_metric, **kwargs)
+
+    def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
+        """Compute edistance using cached within and between distances."""
+        return 2 * between - within_X - within_Y
 
 
 class MMD(AbstractDistance):
@@ -706,6 +923,40 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0,
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MMD cannot be called on a pairwise distance matrix.")
 
+    def supports_value_cache(self) -> bool:
+        """MMD benefits from caching within and between kernel means."""
+        return True
+
+    def compute_within_distance(self, X: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs) -> float:
+        """Compute within-group kernel mean (mean of kernel matrix within group)."""
+        if kernel == "linear":
+            XX = np.dot(X, X.T)
+        elif kernel == "rbf":
+            XX = rbf_kernel(X, X, gamma=gamma)
+        elif kernel == "poly":
+            XX = polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=0)
+        else:
+            raise ValueError(f"Kernel {kernel} not recognized.")
+        return XX.mean()
+
+    def compute_between_distance(
+        self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs
+    ) -> float:
+        """Compute between-group kernel mean (mean of kernel matrix between groups)."""
+        if kernel == "linear":
+            XY = np.dot(X, Y.T)
+        elif kernel == "rbf":
+            XY = rbf_kernel(X, Y, gamma=gamma)
+        elif kernel == "poly":
+            XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=0)
+        else:
+            raise ValueError(f"Kernel {kernel} not recognized.")
+        return XY.mean()
+
+    def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
+        """Compute MMD using cached within and between kernel means."""
+        return within_X + within_Y - 2 * between
+
 
 class WassersteinDistance(AbstractDistance):
     def __init__(self) -> None:
@@ -810,7 +1061,7 @@ def __init__(self) -> None:
         self.accepts_precomputed = True
 
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        return pairwise_distances(X, Y, **kwargs).mean()
+        return pairwise_distance_mean(X, Y, **kwargs)
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         return P[idx, :][:, ~idx].mean()
diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py
index e26c177a..d781c37b 100644
--- a/pertpy/tools/_mixscape.py
+++ b/pertpy/tools/_mixscape.py
@@ -3,7 +3,7 @@
 import copy
 import warnings
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -1220,9 +1220,17 @@ def __init__(
         if self.fixed_cov_indices:
             self.fixed_cov_values = np.array([fixed_covariances[i] for i in self.fixed_cov_indices])
 
-    def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
-        """Modified M-step to respect fixed means and covariances."""
-        super()._m_step(X, log_resp)
+    def _m_step(self, X: np.ndarray, log_resp: np.ndarray, xp: Any | None = None):
+        """Modified M-step to respect fixed means and covariances.
+
+        xp is the array API namespace that newer scikit-learn releases pass in; it is None when the caller omits it.
+        """
+        # Forward xp only when it was supplied, so parent implementations whose _m_step
+        # takes no xp argument are still called with a signature they accept.
+        if xp is None:
+            super()._m_step(X, log_resp)
+        else:
+            super()._m_step(X, log_resp, xp=xp)
 
         if self.fixed_mean_indices:
             self.means_[self.fixed_mean_indices] = self.fixed_mean_values