Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ Plotting
pl.extract
pl.var_by_distance

Preprocessing
~~~~~~~~~~~~~

.. module:: squidpy.pp
.. currentmodule:: squidpy

.. autosummary::
:toctree: api

pp.filter_cells


Reading
~~~~~~~

Expand Down
2 changes: 1 addition & 1 deletion src/squidpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from importlib import metadata
from importlib.metadata import PackageMetadata

from squidpy import datasets, gr, im, pl, read, tl
from squidpy import datasets, gr, im, pl, pp, read, tl

try:
md: PackageMetadata = metadata.metadata(__name__)
Expand Down
5 changes: 5 additions & 0 deletions src/squidpy/pp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Basic pre-processing functions adapted from scanpy."""

from __future__ import annotations

from squidpy.pp._simple import filter_cells
197 changes: 197 additions & 0 deletions src/squidpy/pp/_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from __future__ import annotations

import geopandas as gpd
import numpy as np
import scanpy as sc
import spatialdata as sd
from dask.dataframe import DataFrame as DaskDataFrame
from spatialdata import SpatialData, subset_sdata_by_table_mask
from spatialdata._logging import logger as logg
from spatialdata.models import (
get_table_keys,
points_dask_dataframe_to_geopandas,
points_geopandas_to_dask_dataframe,
)


def filter_cells(
    data: sd.SpatialData,
    tables: list[str] | str | None = None,
    min_counts: int | None = None,
    min_genes: int | None = None,
    max_counts: int | None = None,
    max_genes: int | None = None,
    inplace: bool = True,
    filter_labels: bool = True,
) -> sd.SpatialData | None:
    """Filter the cells of a :class:`spatialdata.SpatialData` object.

    Squidpy's counterpart of :func:`scanpy.pp.filter_cells`. The thresholds are
    evaluated per table with :func:`scanpy.pp.filter_cells`, and the elements
    annotated by each table are subset accordingly:

    - labels: filtered based on the values of the images, which are assumed
      to be the instance_id.
    - points: filtered based on the index, which is assumed to be the
      instance_id.
    - shapes: filtered based on the instance_id column.

    See :func:`scanpy.pp.filter_cells` for more details regarding the
    filtering behavior.

    Parameters
    ----------
    data
        The :class:`spatialdata.SpatialData` object to filter. Any other type
        is rejected.
    tables
        Name, or list of names, of the tables to filter. If `None`, all tables
        are filtered.
    min_counts
        Minimum number of counts a cell needs in order to pass filtering.
    min_genes
        Minimum number of expressed genes a cell needs in order to pass
        filtering.
    max_counts
        Maximum number of counts a cell may have in order to pass filtering.
    max_genes
        Maximum number of expressed genes a cell may have in order to pass
        filtering.
    inplace
        If `True`, modify ``data`` and return `None`; otherwise filter and
        return a deep copy.
    filter_labels
        Whether the labels elements are filtered as well.

    Returns
    -------
    `None` if ``inplace`` is `True`, otherwise the filtered
    :class:`spatialdata.SpatialData` object.

    Raises
    ------
    ValueError
        If ``data`` is not a :class:`spatialdata.SpatialData` object, if a
        requested table is missing or lacks ``spatialdata_attrs``, or if the
        filter would leave a table empty.
    """
    if isinstance(data, sd.SpatialData):
        return _filter_cells_spatialdata(
            data,
            tables=tables,
            min_counts=min_counts,
            min_genes=min_genes,
            max_counts=max_counts,
            max_genes=max_genes,
            inplace=inplace,
            filter_labels=filter_labels,
        )

    raise ValueError(
        f"Expected `SpatialData`, found `{type(data)}` instead. Perhaps you want to use `scanpy.pp.filter_cells` instead."
    )


def _get_only_annotated_shape(sdata: sd.SpatialData, table_name: str) -> str | None:
    """Return the single shapes element annotated by a table, if there is one.

    Parameters
    ----------
    sdata
        The SpatialData object holding the table and its elements.
    table_name
        The key of the table whose annotated regions are inspected.

    Returns
    -------
    The key of the annotated shapes element when the table annotates exactly
    one shapes element and no points element; `None` in every other case
    (no regions, a points region, or more than one shapes region).
    """
    region_keys, _, _ = get_table_keys(sdata.tables[table_name])
    if len(region_keys) == 0:
        return None

    # A single region may be stored as a bare string; normalise to a list.
    if isinstance(region_keys, str):
        region_keys = [region_keys]

    # Points are only filtered by shape membership when the table annotates
    # shapes exclusively, so any annotated points element rules this out.
    if any(key in sdata.points for key in region_keys):
        return None

    annotated_shapes = [key for key in region_keys if key in sdata.shapes]
    return annotated_shapes[0] if len(annotated_shapes) == 1 else None


def _annotated_points_by_shape_membership(
    sdata: SpatialData,
    point_key: str,
    shape_key: str,
) -> DaskDataFrame:
    """Spatially join a points element with the shapes that contain its points.

    The points are converted to a GeoDataFrame, left-joined against the shapes
    with the ``"within"`` predicate, and converted back to a dask dataframe.
    Because the join is a left join, every point is kept in the result.

    Parameters
    ----------
    sdata
        The SpatialData object holding both elements.
    point_key
        The key of the points element to annotate.
    shape_key
        The key of the shapes element the points are tested against.

    Returns
    -------
    The points together with the columns contributed by the spatial join.
    """
    points_element, shapes_element = sdata.points[point_key], sdata.shapes[shape_key]
    points_as_gdf = points_dask_dataframe_to_geopandas(points_element)
    joined = points_as_gdf.sjoin(shapes_element, how="left", predicate="within")
    return points_geopandas_to_dask_dataframe(joined)


def _sjoin_right_index_column(columns, index_name: str | None) -> str:
    """Return the joined column that holds the shapes index after ``sjoin``.

    NOTE(review): the label geopandas gives to the right-hand index appears to
    depend on the geopandas version, on whether the index is named, and on
    whether that name collides with a column of the points -- confirm against
    the geopandas documentation. The candidates are therefore probed in order
    of specificity instead of assuming ``"<index name>_right"``.
    """
    if index_name is None:
        candidates = ["index_right"]
    else:
        candidates = [f"{index_name}_right", "index_right", index_name]
    for candidate in candidates:
        if candidate in columns:
            return candidate
    raise KeyError(f"Could not find the shape index column among the joined point columns `{list(columns)}`.")


def _filter_cells_spatialdata(
    data: sd.SpatialData,
    tables: list[str] | str | None = None,
    min_counts: int | None = None,
    min_genes: int | None = None,
    max_counts: int | None = None,
    max_genes: int | None = None,
    inplace: bool = True,
    filter_labels: bool = True,
) -> sd.SpatialData | None:
    """Filter the tables of a SpatialData object and the elements they annotate.

    For every requested table the cell mask is computed with
    :func:`scanpy.pp.filter_cells`, the object is subset with
    ``subset_sdata_by_table_mask`` and the filtered table, points, shapes and
    (optionally) labels are written back. If the table annotates exactly one
    shapes element and no points element, every points element is additionally
    reduced to the points lying within one of the remaining shapes.

    Parameters
    ----------
    data
        The SpatialData object to filter.
    tables
        Name, or list of names, of the tables to filter. If `None`, all tables
        are filtered.
    min_counts, min_genes, max_counts, max_genes
        Thresholds forwarded to :func:`scanpy.pp.filter_cells`.
    inplace
        If `True`, modify ``data`` and return `None`; otherwise work on a deep
        copy and return it.
    filter_labels
        Whether the filtered labels elements are written back as well.

    Returns
    -------
    `None` if ``inplace`` is `True`, otherwise the filtered copy.

    Raises
    ------
    ValueError
        If no table is given, a table is unknown, a table has no
        ``spatialdata_attrs``, or the filter would remove every cell of a table.
    """
    if isinstance(tables, str):
        tables = [tables]
    elif tables is None:
        tables = list(data.tables.keys())

    if len(tables) == 0:
        raise ValueError("Expected at least one table to be filtered, found `0`")

    if not all(t in data.tables for t in tables):
        raise ValueError(f"Expected all tables to be in `{data.tables.keys()}`.")

    for t in tables:
        if "spatialdata_attrs" not in data.tables[t].uns:
            raise ValueError(f"Table `{t}` does not have 'spatialdata_attrs' to indicate what it annotates.")

    if not inplace:
        logg.warning(
            "Creating a deepcopy of the SpatialData object, depending on the size of the object this can take a while."
        )
        data_out = sd.deepcopy(data)
    else:
        data_out = data

    for t in tables:
        table_old = data_out.tables[t]
        mask_filtered, _ = sc.pp.filter_cells(
            table_old,
            min_counts=min_counts,
            min_genes=min_genes,
            max_counts=max_counts,
            max_genes=max_genes,
            inplace=False,
        )
        if mask_filtered.sum() == 0:
            raise ValueError(f"Filter results in empty table when filtering table `{t}`.")

        sdata_filtered = subset_sdata_by_table_mask(sdata=data_out, table_name=t, mask=mask_filtered)
        data_out.tables[t] = sdata_filtered.tables[t]
        for k in list(sdata_filtered.points.keys()):
            data_out.points[k] = sdata_filtered.points[k]
        for k in list(sdata_filtered.shapes.keys()):
            data_out.shapes[k] = sdata_filtered.shapes[k]
        if filter_labels:
            for k in list(sdata_filtered.labels.keys()):
                data_out.labels[k] = sdata_filtered.labels[k]

        shape_name = _get_only_annotated_shape(data_out, t)
        if shape_name is None:
            continue

        # None of these depend on the points element, so compute them once per
        # table instead of once per points element.
        _, _, instance_key = get_table_keys(table_old)
        removed_instance_ids = list(np.unique(table_old.obs[instance_key][~mask_filtered]))
        shape_index_name = data_out.shapes[shape_name].index.name

        # Iterate over a snapshot of the keys: the loop reassigns entries of
        # `data_out.points` while walking over them.
        for p in list(data_out.points.keys()):
            new_points = _annotated_points_by_shape_membership(
                sdata=data_out,
                shape_key=shape_name,
                point_key=p,
            )
            shape_index_column = _sjoin_right_index_column(new_points.columns, shape_index_name)
            # Drop points that are not in any shape. Only the joined shape
            # index is inspected, so points carrying missing values in
            # unrelated columns are kept.
            new_points = new_points.dropna(subset=[shape_index_column])
            # Drop points that lie in a shape whose cell was filtered out.
            new_points = new_points[~new_points[shape_index_column].isin(removed_instance_ids)]
            data_out.points[p] = new_points

    if inplace:
        return None
    return data_out
60 changes: 60 additions & 0 deletions tests/preprocessing/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import anndata as ad
import numpy as np
import pytest
import scanpy as sc
from spatialdata.datasets import blobs_annotating_element

import squidpy as sq


def _make_sdata(name: str, num_counts: int, count_value: int):
    """Build a blobs SpatialData whose table has a known number of non-zero cells.

    The table of ``blobs_annotating_element(name)`` is replaced by a square
    count matrix that is zero everywhere except for ``num_counts`` entries set
    to ``count_value``, each in a distinct row and a distinct column. The
    original ``obs`` and ``uns`` are reused, so the table keeps its annotation
    metadata.

    Parameters
    ----------
    name
        The element name passed to ``blobs_annotating_element``.
    num_counts
        How many cells receive a non-zero count; at most 5.
    count_value
        The value written into each selected entry.

    Returns
    -------
    The SpatialData object with the replaced table.
    """
    # NOTE(review): the bound of 5 presumably reflects the number of cells in
    # the smallest blobs table -- confirm.
    assert num_counts <= 5, "num_counts must be at most 5"
    sdata_temp = blobs_annotating_element(name)
    table = sdata_temp.tables["table"]
    n_obs, _ = table.shape
    n_vars = n_obs
    X = np.zeros((n_obs, n_vars))
    # Sampling without replacement gives distinct rows (cells) and distinct
    # columns (genes), so exactly `num_counts` cells have a non-zero total.
    row_indices = np.random.choice(n_obs, num_counts, replace=False)
    col_indices = np.random.choice(n_vars, num_counts, replace=False)
    X[row_indices, col_indices] = count_value

    sdata_temp.tables["table"] = ad.AnnData(
        X=X,
        obs=table.obs,
        var={"gene": ["gene" for _ in range(n_vars)]},
        uns=table.uns,
    )
    return sdata_temp


@pytest.mark.parametrize("name", ["blobs_labels", "blobs_circles", "blobs_points", "blobs_multiscale_labels"])
def test_filter_cells(name: str):
    """Check that squidpy filters a SpatialData table exactly like scanpy does.

    Three cells receive a count of 100 and are expected to be removed by
    ``max_counts=50``. The table filtered through squidpy must match a copy
    filtered directly with scanpy; for ``blobs_labels`` the remaining label
    values (background 0 aside) must match the remaining instance ids.
    """
    filtered_cells = 3
    sdata = _make_sdata(name, num_counts=filtered_cells, count_value=100)
    num_cells = sdata.tables["table"].shape[0]
    expected_cells = num_cells - filtered_cells

    # Reference result: the same table filtered with plain scanpy.
    adata_copy = sdata.tables["table"].copy()
    sc.pp.filter_cells(adata_copy, max_counts=50, inplace=True)
    sq.pp.filter_cells(sdata, max_counts=50, inplace=True, filter_labels=True)

    filtered_table = sdata.tables["table"]
    assert np.all(filtered_table.X == adata_copy.X), "Filtered cells are not the same as scanpy"
    assert np.all(filtered_table.obs["instance_id"] == adata_copy.obs["instance_id"]), (
        "Filtered cells are not the same as scanpy"
    )
    assert filtered_table.shape[0] == expected_cells, (
        f"Expected {expected_cells} cells, got {filtered_table.shape[0]}"
    )

    if name != "blobs_labels":
        return

    unique_labels = np.unique(adata_copy.obs["instance_id"])
    unique_labels_sdata = np.unique(sdata.labels["blobs_labels"].data.compute())
    # Label 0 is excluded from the comparison.
    assert set(unique_labels) == set(unique_labels_sdata).difference([0]), (
        f"Filtered labels {unique_labels} are not the same as scanpy {unique_labels_sdata}"
    )


def test_filter_cells_empty_fail():
    """Expect a ``ValueError`` when the filter leaves the table empty."""
    expected_message = "Filter results in empty table when filtering table `table`."
    sdata = _make_sdata("blobs_labels", num_counts=5, count_value=200)
    with pytest.raises(ValueError, match=expected_message):
        sq.pp.filter_cells(sdata, max_counts=100, inplace=True)
2 changes: 1 addition & 1 deletion tests/utils/test_parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def func(request) -> Callable:
# in case of failure.


@pytest.mark.timeout(30)
@pytest.mark.timeout(50)
@pytest.mark.parametrize(
"backend",
[
Expand Down
Loading