tests/conftest.py (2 changes: 1 addition & 1 deletion)
@@ -3,5 +3,5 @@


 @pytest.fixture
-def rng():
+def rng() -> np.random.Generator:
     return np.random.default_rng()
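The return annotation presumably relies on numpy already being imported in the unexpanded lines above this hunk. For illustration, a consuming test now gets static checking on the injected generator; this sketch is hypothetical and not part of the PR:

import numpy as np

def test_rng_fixture_shape(rng: np.random.Generator) -> None:
    # pytest injects the conftest fixture by name; the annotation
    # lets mypy verify Generator methods such as rng.normal()
    assert rng.normal(size=(2, 3)).shape == (2, 3)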
tests/tools/_distances/test_distance_tests.py (10 changes: 6 additions & 4 deletions)
@@ -1,10 +1,12 @@
 import pytest
 import scanpy as sc
+from anndata import AnnData
 from pandas import DataFrame
 
 import pertpy as pt
+from pertpy.tools._distances._distances import Metric
 
-distances = [
+distances: tuple[Metric, ...] = (
     "edistance",
     "euclidean",
     "mse",
@@ -25,21 +27,21 @@
# "nbll",
"mahalanobis",
"mean_var_distribution",
]
)

count_distances = ["nb_ll"]


@pytest.fixture
def adata():
def adata() -> AnnData:
adata = pt.dt.distance_example()
adata = sc.pp.subsample(adata, 0.1, copy=True)

return adata


@pytest.mark.parametrize("distance", distances)
def test_distancetest(adata, distance):
def test_distancetest(adata: AnnData, distance: Metric) -> None:
etest = pt.tl.DistanceTest(distance, n_perms=10, obsm_key="X_pca", alpha=0.05, correction="holm-sidak")
tab = etest(adata, groupby="perturbation", contrast="control")

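The distances collection becomes an immutable tuple annotated with Metric, imported from the private pertpy.tools._distances._distances module. Judging from its use here, Metric is presumably a typing.Literal alias over the supported metric names, roughly along these lines (a sketch of the idea, not pertpy's actual definition):

from typing import Literal

# Hypothetical stand-in for pertpy.tools._distances._distances.Metric;
# the real alias enumerates every metric pertpy supports.
Metric = Literal["edistance", "euclidean", "mse", "mahalanobis"]

distances: tuple[Metric, ...] = ("edistance", "euclidean", "mse")

With a Literal alias, a misspelled metric name in the parametrize list becomes a type-checker error instead of a runtime failure, which a plain list[str] cannot catch.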
tests/tools/_distances/test_distances.py (109 changes: 58 additions & 51 deletions)
@@ -1,12 +1,14 @@
 import numpy as np
 import pytest
 import scanpy as sc
+from anndata import AnnData
 from pandas import DataFrame, Series
 from pytest import fixture, mark
 
 import pertpy as pt
+from pertpy.tools._distances._distances import Distance, Metric
 
-actual_distances = [
+actual_distances: tuple[Metric, ...] = (
     # Euclidean distances and related
     "euclidean",
     "mean_absolute_error",
@@ -22,31 +24,34 @@
"t_test",
"wasserstein",
"mahalanobis",
]
semi_distances = ["r2_distance", "sym_kldiv", "ks_test"]
non_distances = ["classifier_proba"]
onesided_only = ["classifier_cp"]
pseudo_counts_distances = ["nb_ll"]
lognorm_counts_distances = ["mean_var_distribution"]
all_distances = (
actual_distances + semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances
) # + onesided_only
)
semi_distances: tuple[Metric, ...] = ("r2_distance", "sym_kldiv", "ks_test")
non_distances: tuple[Metric, ...] = ("classifier_proba",)
onesided_only: tuple[Metric, ...] = ("classifier_cp",)
pseudo_counts_distances: tuple[Metric, ...] = ("nb_ll",)
lognorm_counts_distances: tuple[Metric, ...] = ("mean_var_distribution",)
all_distances: tuple[Metric, ...] = (
*actual_distances,
*semi_distances,
*non_distances,
*lognorm_counts_distances,
*pseudo_counts_distances,
# *onesided_only,
)


@fixture
def adata(request):
low_subsample_distances = [
def adata(distance: Metric, rng: np.random.Generator) -> AnnData:
low_subsample_distances = {
"sym_kldiv",
"t_test",
"ks_test",
"classifier_proba",
"classifier_cp",
"mahalanobis",
"mean_var_distribution",
]
no_subsample_distances = ["mahalanobis"] # mahalanobis only works on the full data without subsampling

distance = request.node.callspec.params["distance"]
}
no_subsample_distances = {"mahalanobis"} # mahalanobis only works on the full data without subsampling

adata = pt.dt.distance_example()
if distance not in no_subsample_distances:
@@ -55,7 +60,7 @@ def adata(request):
     else:
         adata = sc.pp.subsample(adata, 0.001, copy=True)
 
-    adata = adata[:, np.random.default_rng().choice(adata.n_vars, 100, replace=False)].copy()
+    adata = adata[:, rng.choice(adata.n_vars, 100, replace=False)].copy()
 
     adata.layers["lognorm"] = adata.X.copy()
     adata.layers["counts"] = np.round(adata.X.toarray()).astype(int)
@@ -70,25 +75,23 @@ def adata(request):


 @fixture
-def distance_obj(request):
-    distance = request.node.callspec.params["distance"]
+def distance_obj(distance: Metric) -> pt.tl.Distance:
     if distance in lognorm_counts_distances:
-        Distance = pt.tl.Distance(distance, layer_key="lognorm")
-    elif distance in pseudo_counts_distances:
-        Distance = pt.tl.Distance(distance, layer_key="counts")
-    else:
-        Distance = pt.tl.Distance(distance, obsm_key="X_pca")
-    return Distance
+        return pt.tl.Distance(distance, layer_key="lognorm")
+    if distance in pseudo_counts_distances:
+        return pt.tl.Distance(distance, layer_key="counts")
+    return pt.tl.Distance(distance, obsm_key="X_pca")
 
 
 @fixture
-@mark.parametrize("distance", all_distances)
-def pairwise_distance(adata, distance_obj, distance):
+def pairwise_distance(adata: AnnData, distance_obj: pt.tl.Distance) -> DataFrame:
     return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=True)
 
 
 @mark.parametrize("distance", actual_distances + semi_distances)
-def test_distance_axioms(pairwise_distance, distance):
+def test_distance_axioms(pairwise_distance: DataFrame, distance: Metric) -> None:
+    del distance
+
     # This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality.
     # (M1) Definiteness
     assert all(np.diag(pairwise_distance.values) == 0)  # distance to self is 0
@@ -102,12 +105,12 @@ def test_distance_axioms(pairwise_distance, distance):


@mark.parametrize("distance", actual_distances)
def test_triangle_inequality(pairwise_distance, distance, rng):
# Test if distances are well-defined in accordance with metric axioms
# (M4) Triangle inequality (we just probe this for a few random triplets)
# Some tests are not well defined for the triangle inequality. We skip those.
def test_triangle_inequality(pairwise_distance: DataFrame, distance: Metric, rng: np.random.Generator) -> None:
"""Test if distances are well-defined in accordance with metric axioms
(M4) Triangle inequality (we just probe this for a few random triplets)
"""
if distance in {"mahalanobis", "wasserstein"}:
return
pytest.skip("Some tests not well defined for triangle inequality")

for _ in range(5):
triplet = rng.choice(pairwise_distance.index, size=3, replace=False)
@@ -118,30 +121,33 @@ def test_triangle_inequality(pairwise_distance, distance, rng):


@mark.parametrize("distance", all_distances)
def test_distance_layers(pairwise_distance, distance):
def test_distance_layers(pairwise_distance: DataFrame, distance: Metric) -> None:
del distance

assert isinstance(pairwise_distance, DataFrame)
assert pairwise_distance.columns.equals(pairwise_distance.index)
assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 # symmetry


@mark.parametrize("distance", actual_distances + pseudo_counts_distances)
def test_distance_counts(adata, distance):
if distance != "mahalanobis": # skip, doesn't work because covariance matrix is a singular matrix, not invertible
distance = pt.tl.Distance(distance, layer_key="counts")
df = distance.pairwise(adata, groupby="perturbation")
assert isinstance(df, DataFrame)
assert df.columns.equals(df.index)
assert np.sum(df.values - df.values.T) == 0
def test_distance_counts(adata: AnnData, distance: Metric) -> None:
if distance == "mahalanobis":
pytest.skip("covariance matrix is a singular matrix, not invertible")
distance_obj = pt.tl.Distance(distance, layer_key="counts")
df = distance_obj.pairwise(adata, groupby="perturbation")
assert isinstance(df, DataFrame)
assert df.columns.equals(df.index)
assert np.sum(df.values - df.values.T) == 0


@mark.parametrize("distance", all_distances)
def test_mutually_exclusive_keys(distance):
def test_mutually_exclusive_keys(distance: Metric) -> None:
with pytest.raises(ValueError):
_ = pt.tl.Distance(distance, layer_key="counts", obsm_key="X_pca")


@mark.parametrize("distance", actual_distances + semi_distances + non_distances)
def test_distance_output_type(distance, rng):
def test_distance_output_type(distance: Metric, rng: np.random.Generator) -> None:
# Test if distances are outputting floats
Distance = pt.tl.Distance(distance)
X = rng.normal(size=(50, 10))
@@ -151,15 +157,16 @@ def test_distance_output_type(distance, rng):


@mark.parametrize("distance", all_distances + onesided_only)
def test_distance_onesided(adata, distance_obj, distance):
def test_distance_onesided(adata: AnnData, distance_obj: Distance, distance: Metric) -> None:
del distance
# Test consistency of one-sided distance results
selected_group = adata.obs.perturbation.unique()[0]
selected_group = adata.obs["perturbation"].unique()[0]
df = distance_obj.onesided_distances(adata, groupby="perturbation", selected_group=selected_group)
assert isinstance(df, Series)
assert df.loc[selected_group] == 0 # distance to self is 0


def test_bootstrap_distance_output_type(rng):
def test_bootstrap_distance_output_type(rng: np.random.Generator) -> None:
# Test if distances are outputting floats
Distance = pt.tl.Distance(metric="edistance")
X = rng.normal(size=(50, 10))
@@ -170,7 +177,7 @@ def test_bootstrap_distance_output_type(rng):


@mark.parametrize("distance", ["edistance"])
def test_bootstrap_distance_pairwise(adata, distance):
def test_bootstrap_distance_pairwise(adata: AnnData, distance: Metric) -> None:
# Test consistency of pairwise distance results
Distance = pt.tl.Distance(distance, obsm_key="X_pca")
bootstrap_output = Distance.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=3)
@@ -186,9 +193,9 @@ def test_bootstrap_distance_pairwise(adata, distance):


@mark.parametrize("distance", ["edistance"])
def test_bootstrap_distance_onesided(adata, distance):
def test_bootstrap_distance_onesided(adata: AnnData, distance: Metric) -> None:
# Test consistency of one-sided distance results
selected_group = adata.obs.perturbation.unique()[0]
selected_group = adata.obs["perturbation"].unique()[0]
Distance = pt.tl.Distance(distance, obsm_key="X_pca")
bootstrap_output = Distance.onesided_distances(
adata,
@@ -201,7 +208,7 @@ def test_bootstrap_distance_onesided(adata, distance):
     assert isinstance(bootstrap_output, tuple)
 
 
-def test_compare_distance(rng):
+def test_compare_distance(rng: np.random.Generator) -> None:
     X = rng.normal(size=(50, 10))
     Y = rng.normal(size=(50, 10))
     C = rng.normal(size=(50, 10))
@@ -211,4 +218,4 @@ def test_compare_distance(rng):
     res_scaled = Distance.compare_distance(X, Y, C, mode="scaled")
     assert isinstance(res_scaled, float)
     with pytest.raises(ValueError):
-        Distance.compare_distance(X, Y, C, mode="new_mode")
+        Distance.compare_distance(X, Y, C, mode="new_mode")  # type: ignore[arg-type]
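The recurring refactor in this file swaps request.node.callspec.params["distance"] lookups for a direct distance argument on the fixtures: pytest resolves a fixture's arguments by name, and a name parametrized on the consuming test is injected into the fixture directly. The tests appear to keep distance in their own signatures so the name resolves during collection, and the new del distance lines silence unused-argument linting when the body never reads it. A minimal standalone sketch of the pattern (names are illustrative, not from this PR):

import pytest

@pytest.fixture
def config(mode: str) -> dict[str, str]:
    # "mode" comes straight from the test's parametrization;
    # no request.node.callspec.params["mode"] lookup needed
    return {"mode": mode}

@pytest.mark.parametrize("mode", ["fast", "exact"])
def test_config_mode(config: dict[str, str], mode: str) -> None:
    assert config["mode"] == mode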