diff --git a/docs/api.rst b/docs/api.rst index 104d6fe3a..5f934baef 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -57,6 +57,18 @@ Plotting pl.extract pl.var_by_distance +Preprocessing +~~~~~~~~~~~~~ + +.. module:: squidpy.pp +.. currentmodule:: squidpy + +.. autosummary:: + :toctree: api + + pp.filter_cells + + Reading ~~~~~~~ diff --git a/docs/notebooks b/docs/notebooks index 0b092a258..296295a16 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 0b092a2580c823e296c30c9321a5a411f9fa91da +Subproject commit 296295a1682ad0f06757fe532a631803fad05c87 diff --git a/src/squidpy/__init__.py b/src/squidpy/__init__.py index 5fb2b848a..d52a64cbd 100644 --- a/src/squidpy/__init__.py +++ b/src/squidpy/__init__.py @@ -3,7 +3,7 @@ from importlib import metadata from importlib.metadata import PackageMetadata -from squidpy import datasets, gr, im, pl, read, tl +from squidpy import datasets, gr, im, pl, pp, read, tl try: md: PackageMetadata = metadata.metadata(__name__) diff --git a/src/squidpy/pp/__init__.py b/src/squidpy/pp/__init__.py new file mode 100644 index 000000000..26edb04a8 --- /dev/null +++ b/src/squidpy/pp/__init__.py @@ -0,0 +1,5 @@ +"""Basic pre-processing functions adapted from scanpy.""" + +from __future__ import annotations + +from squidpy.pp._simple import filter_cells diff --git a/src/squidpy/pp/_simple.py b/src/squidpy/pp/_simple.py new file mode 100644 index 000000000..90f1d2b09 --- /dev/null +++ b/src/squidpy/pp/_simple.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import geopandas as gpd +import numpy as np +import scanpy as sc +import spatialdata as sd +from dask.dataframe import DataFrame as DaskDataFrame +from spatialdata import SpatialData, subset_sdata_by_table_mask +from spatialdata._logging import logger as logg +from spatialdata.models import ( + get_table_keys, + points_dask_dataframe_to_geopandas, + points_geopandas_to_dask_dataframe, +) + + +def filter_cells( + data: sd.SpatialData, + tables: list[str] | str | None = None, + min_counts: int | None = None, + min_genes: int | None = None, + max_counts: int | None = None, + max_genes: int | None = None, + inplace: bool = True, + filter_labels: bool = True, +) -> sd.SpatialData | None: + """\ + Squidpy's implementation of :func:`scanpy.pp.filter_cells` for :class:`anndata.AnnData` and :class:`spatialdata.SpatialData` objects. + For :class:`spatialdata.SpatialData` objects, this function filters the following elements: + + + - labels: filtered based on the values of the images which are assumed to be the instance_id. + - points: filtered based on the index which is assumed to be the instance_id. + - shapes: filtered based on the instance_id column. + + + See :func:`scanpy.pp.filter_cells` for more details regarding the filtering + behavior. + + Parameters + ---------- + data + :class:`spatialdata.SpatialData` object. + tables + If :class:`spatialdata.SpatialData` object, the tables to filter. If `None`, all tables are filtered. + min_counts + Minimum number of counts required for a cell to pass filtering. + min_genes + Minimum number of genes expressed required for a cell to pass filtering. + max_counts + Maximum number of counts required for a cell to pass filtering. + max_genes + Maximum number of genes expressed required for a cell to pass filtering. + inplace + Perform computation inplace or return result. + filter_labels + Whether to filter labels. If `True`, then labels are filtered based on the instance_id column. + + Returns + ------- + If `inplace` then returns `None`, otherwise returns the filtered :class:`spatialdata.SpatialData` object. + """ + if not isinstance(data, sd.SpatialData): + raise ValueError( + f"Expected `SpatialData`, found `{type(data)}` instead. Perhaps you want to use `scanpy.pp.filter_cells` instead." + ) + + return _filter_cells_spatialdata(data, tables, min_counts, min_genes, max_counts, max_genes, inplace, filter_labels) + + +def _get_only_annotated_shape(sdata: sd.SpatialData, table_name: str) -> str | None: + table = sdata.tables[table_name] + + # only one shape needs to be annotated to filter points within it + # other annotations can't be points + + regions, _, _ = get_table_keys(table) + if len(regions) == 0: + return None + + if isinstance(regions, str): + regions = [regions] + + res = None + for r in regions: + if r in sdata.points: + return None + if r in sdata.shapes: + if res is not None: + return None + res = r + + return res + + +def _annotated_points_by_shape_membership( + sdata: SpatialData, + point_key: str, + shape_key: str, +) -> DaskDataFrame: + """Annotate points by shape membership. + + Parameters + ---------- + sdata + The SpatialData object to annotate. + point_key + The key of the points to annotate. + shape_key + The key of the shapes to annotate. + + Returns + ------- + The annotated points. + """ + points = sdata.points[point_key] + shapes = sdata.shapes[shape_key] + points_gdf = points_dask_dataframe_to_geopandas(points) + res = points_gdf.sjoin(shapes, how="left", predicate="within") + return points_geopandas_to_dask_dataframe(res) + + +def _filter_cells_spatialdata( + data: sd.SpatialData, + tables: list[str] | str | None = None, + min_counts: int | None = None, + min_genes: int | None = None, + max_counts: int | None = None, + max_genes: int | None = None, + inplace: bool = True, + filter_labels: bool = True, +) -> sd.SpatialData | None: + if isinstance(tables, str): + tables = [tables] + elif tables is None: + tables = list(data.tables.keys()) + + if len(tables) == 0: + raise ValueError("Expected at least one table to be filtered, found `0`") + + if not all(t in data.tables for t in tables): + raise ValueError(f"Expected all tables to be in `{data.tables.keys()}`.") + + for t in tables: + if "spatialdata_attrs" not in data.tables[t].uns: + raise ValueError(f"Table `{t}` does not have 'spatialdata_attrs' to indicate what it annotates.") + + if not inplace: + logg.warning( + "Creating a deepcopy of the SpatialData object, depending on the size of the object this can take a while." + ) + data_out = sd.deepcopy(data) + else: + data_out = data + + for t in tables: + table_old = data_out.tables[t] + mask_filtered, _ = sc.pp.filter_cells( + table_old, + min_counts=min_counts, + min_genes=min_genes, + max_counts=max_counts, + max_genes=max_genes, + inplace=False, + ) + if mask_filtered.sum() == 0: + raise ValueError(f"Filter results in empty table when filtering table `{t}`.") + sdata_filtered = subset_sdata_by_table_mask(sdata=data_out, table_name=t, mask=mask_filtered) + data_out.tables[t] = sdata_filtered.tables[t] + for k in list(sdata_filtered.points.keys()): + data_out.points[k] = sdata_filtered.points[k] + for k in list(sdata_filtered.shapes.keys()): + data_out.shapes[k] = sdata_filtered.shapes[k] + if filter_labels: + for k in list(sdata_filtered.labels.keys()): + data_out.labels[k] = sdata_filtered.labels[k] + shape_name = _get_only_annotated_shape(data_out, t) + if shape_name is not None: + for p in data_out.points: + _, _, instance_key = get_table_keys(table_old) + shape_index_name = data_out.shapes[shape_name].index.name + new_points = _annotated_points_by_shape_membership( + sdata=data_out, + shape_key=shape_name, + point_key=p, + ) + shape_index_name += "_right" + removed_instance_ids = list(np.unique(table_old.obs[instance_key][~mask_filtered])) + # drop points that are not in any shape + new_points = new_points.dropna() + # drop points that are in the removed_instance_ids + new_points = new_points[~new_points[shape_index_name].isin(removed_instance_ids)] + data_out.points[p] = new_points + + if inplace: + return None + return data_out diff --git a/tests/preprocessing/test_simple.py b/tests/preprocessing/test_simple.py new file mode 100644 index 000000000..82daaf9bf --- /dev/null +++ b/tests/preprocessing/test_simple.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import anndata as ad +import numpy as np +import pytest +import scanpy as sc +from spatialdata.datasets import blobs_annotating_element + +import squidpy as sq + + +def _make_sdata(name: str, num_counts: int, count_value: int): + assert num_counts <= 5, "num_counts must be less than 5" + sdata_temp = blobs_annotating_element(name) + m, _ = sdata_temp.tables["table"].shape + n = m + X = np.zeros((m, n)) + # random choice of row + row_indices = np.random.choice(m, num_counts, replace=False) + col_indices = np.random.choice(n, num_counts, replace=False) + X[row_indices, col_indices] = count_value + + sdata_temp.tables["table"] = ad.AnnData( + X=X, + obs=sdata_temp.tables["table"].obs, + var={"gene": ["gene" for _ in range(n)]}, + uns=sdata_temp.tables["table"].uns, + ) + return sdata_temp + + +@pytest.mark.parametrize("name", ["blobs_labels", "blobs_circles", "blobs_points", "blobs_multiscale_labels"]) +def test_filter_cells(name: str): + filtered_cells = 3 + sdata = _make_sdata(name, num_counts=filtered_cells, count_value=100) + num_cells = sdata.tables["table"].shape[0] + adata_copy = sdata.tables["table"].copy() + sc.pp.filter_cells(adata_copy, max_counts=50, inplace=True) + sq.pp.filter_cells(sdata, max_counts=50, inplace=True, filter_labels=True) + + assert np.all(sdata.tables["table"].X == adata_copy.X), "Filtered cells are not the same as scanpy" + assert np.all(sdata.tables["table"].obs["instance_id"] == adata_copy.obs["instance_id"]), ( + "Filtered cells are not the same as scanpy" + ) + assert sdata.tables["table"].shape[0] == (num_cells - filtered_cells), ( + f"Expected {num_cells - filtered_cells} cells, got {sdata.tables['table'].shape[0]}" + ) + + if name == "blobs_labels": + unique_labels = np.unique(adata_copy.obs["instance_id"]) + unique_labels_sdata = np.unique(sdata.labels["blobs_labels"].data.compute()) + assert set(unique_labels) == set(unique_labels_sdata).difference([0]), ( + f"Filtered labels {unique_labels} are not the same as scanpy {unique_labels_sdata}" + ) + + +def test_filter_cells_empty_fail(): + sdata = _make_sdata("blobs_labels", num_counts=5, count_value=200) + with pytest.raises(ValueError, match="Filter results in empty table when filtering table `table`."): + sq.pp.filter_cells(sdata, max_counts=100, inplace=True) diff --git a/tests/utils/test_parallelize.py b/tests/utils/test_parallelize.py index 88922b526..1b30b63ef 100644 --- a/tests/utils/test_parallelize.py +++ b/tests/utils/test_parallelize.py @@ -67,7 +67,7 @@ def func(request) -> Callable: # in case of failure. -@pytest.mark.timeout(30) +@pytest.mark.timeout(50) @pytest.mark.parametrize( "backend", [