Skip to content

Commit c91dd24

Browse files
authored
Move and fix annotations for AnnData (#347)
1 parent b7fbf75 commit c91dd24

File tree

3 files changed

+71
-49
lines changed

3 files changed

+71
-49
lines changed

tests/representation/test_annotations.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
from iohub import open_ome_zarr
1212
from pytest import TempPathFactory
1313

14-
from viscy.representation.evaluation import (
15-
load_annotation_anndata,
16-
)
17-
from viscy.representation.evaluation.annotation import convert
14+
from viscy.representation.evaluation.annotation import convert, load_annotation_anndata
1815

1916

2017
@pytest.fixture(scope="function")
@@ -252,7 +249,7 @@ def test_convert_xarray_annotation_to_anndata(xr_embeddings_dataset, tmp_path):
252249

253250

254251
def test_load_annotation_anndata(tracks_hcs_dataset, anndata_embeddings, tmp_path):
255-
"""Test that load_annotation_anndata correctly loads annotations from an AnnData object."""
252+
"""Test that load_annotation_anndata correctly loads annotations into an AnnData object."""
256253
# Load the AnnData object
257254
adata = ad.read_zarr(anndata_embeddings)
258255

@@ -271,16 +268,28 @@ def test_load_annotation_anndata(tracks_hcs_dataset, anndata_embeddings, tmp_pat
271268
A11_annotations_df.to_csv(annotations_path, index=False)
272269

273270
# Test the function with the new CSV file
274-
result = load_annotation_anndata(adata, str(annotations_path), "infection_state")
271+
result_adata = load_annotation_anndata(
272+
adata, str(annotations_path), "infection_state"
273+
)
274+
275+
# Check that the function returns an AnnData object
276+
assert isinstance(result_adata, ad.AnnData)
277+
278+
# Check that the annotation column was added
279+
assert "infection_state" in result_adata.obs.columns
275280

276-
assert len(result) == 2 # Only 2 observations from A/1/1 have annotations
281+
# Check that annotations were added for A/1/1 FOV
282+
a11_mask = result_adata.obs["fov_name"] == "A/1/1"
283+
assert a11_mask.sum() == 2 # Only 2 observations from A/1/1
277284

285+
# Check that the annotation values match
278286
expected_values = A11_annotations_df["infection_state"].values
279-
actual_values = result.values
287+
actual_values = result_adata.obs.loc[a11_mask, "infection_state"].values
280288
np.testing.assert_array_equal(actual_values, expected_values)
281289

282-
assert result.index.names == ["fov_name", "id"]
283-
assert all(result.index.get_level_values("fov_name") == "A/1/1")
290+
# Check that other FOVs have NaN annotations
291+
other_fovs_mask = result_adata.obs["fov_name"] != "A/1/1"
292+
assert result_adata.obs.loc[other_fovs_mask, "infection_state"].isna().all()
284293

285294

286295
def test_cli_convert_to_anndata(xr_embeddings_dataset, tmp_path):

viscy/representation/evaluation/__init__.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
https://github.com/mehta-lab/dynacontrast/blob/master/analysis/gmm.py
1515
"""
1616

17-
import anndata as ad
1817
import pandas as pd
1918

2019
from viscy.data.triplet import TripletDataModule
20+
from viscy.representation.evaluation.annotation import load_annotation_anndata
21+
22+
__all__ = ["load_annotation", "load_annotation_anndata", "dataset_of_tracks"]
2123

2224

2325
def load_annotation(da, path, name, categories: dict | None = None):
@@ -63,41 +65,6 @@ def load_annotation(da, path, name, categories: dict | None = None):
6365
return selected
6466

6567

66-
def load_annotation_anndata(
67-
adata: ad.AnnData, path: str, name: str, categories: dict | None = None
68-
):
69-
"""
70-
Load annotations from a CSV file and map them to the AnnData object.
71-
72-
Parameters
73-
----------
74-
adata : anndata.AnnData
75-
The AnnData object to map the annotations to.
76-
path : str
77-
Path to the CSV file containing annotations.
78-
name : str
79-
The column name in the CSV file to be used as annotations.
80-
categories : dict, optional
81-
A dictionary to rename categories in the annotation column. Default is None.
82-
"""
83-
annotation = pd.read_csv(path)
84-
annotation["fov_name"] = annotation["fov_name"].str.strip("/")
85-
86-
annotation = annotation.set_index(["fov_name", "id"])
87-
88-
mi = pd.MultiIndex.from_arrays(
89-
[adata.obs["fov_name"], adata.obs["id"]], names=["fov_name", "id"]
90-
)
91-
92-
# Use reindex to handle missing annotations gracefully
93-
# This will return NaN for observations that don't have annotations, then just drop'em
94-
selected = annotation.reindex(mi)[name].dropna()
95-
96-
if categories:
97-
selected = selected.astype("category").cat.rename_categories(categories)
98-
return selected
99-
100-
10168
def dataset_of_tracks(
10269
data_path,
10370
tracks_path,

viscy/representation/evaluation/annotation.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,11 @@ def convert(
5959
available_cols = get_available_index_columns(embeddings_ds)
6060
tracking_df = pd.DataFrame(
6161
{
62-
col: embeddings_ds.coords[col].data
63-
if col != "fov_name"
64-
else embeddings_ds.coords[col].to_pandas().str.strip("/")
62+
col: (
63+
embeddings_ds.coords[col].data
64+
if col != "fov_name"
65+
else embeddings_ds.coords[col].to_pandas().str.strip("/")
66+
)
6567
for col in available_cols
6668
}
6769
)
@@ -89,3 +91,47 @@ def convert(
8991
adata.write_zarr(output_path)
9092
if return_anndata:
9193
return adata
94+
95+
96+
def load_annotation_anndata(
    adata: ad.AnnData, path: str, name: str, categories: dict | None = None
) -> ad.AnnData:
    """
    Load annotations from a CSV file and map them to the AnnData object.

    Rows of the CSV file are matched to observations on the
    ``(fov_name, id)`` pair. Observations without a matching row receive
    NaN in the new column. ``adata`` is modified in place and also returned.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object to map the annotations to.
        ``adata.obs`` must contain ``fov_name`` and ``id`` columns.
    path : str
        Path to the CSV file containing annotations.
        The file must contain ``fov_name`` and ``id`` columns.
    name : str
        The column name in the CSV file to be used as annotations.
    categories : dict, optional
        A dictionary to rename categories in the annotation column. Default is None.

    Returns
    -------
    anndata.AnnData
        The same AnnData object with annotations added to adata.obs[name].

    Raises
    ------
    KeyError
        If the CSV file lacks the ``fov_name``, ``id`` or ``name`` column.
    ValueError
        If the CSV file contains more than one row for a ``(fov_name, id)`` pair.
    """
    annotation = pd.read_csv(path)

    # Fail early with a readable message instead of a bare pandas KeyError.
    missing = [
        col for col in ("fov_name", "id", name) if col not in annotation.columns
    ]
    if missing:
        raise KeyError(f"Annotation file {path!r} is missing column(s): {missing}")

    # Leading/trailing slashes are stripped so that FOV names compare equal
    # regardless of how the CSV spells them (e.g. "/A/1/1" vs "A/1/1").
    annotation["fov_name"] = annotation["fov_name"].str.strip("/")

    annotation = annotation.set_index(["fov_name", "id"])

    # Reindexing requires unique keys; report the offending keys explicitly
    # rather than letting pandas raise a generic reindexing error.
    if not annotation.index.is_unique:
        duplicated = (
            annotation.index[annotation.index.duplicated()].unique()[:5].tolist()
        )
        raise ValueError(
            f"Annotation file {path!r} has duplicate (fov_name, id) entries, "
            f"e.g. {duplicated}"
        )

    mi = pd.MultiIndex.from_arrays(
        [adata.obs["fov_name"], adata.obs["id"]], names=["fov_name", "id"]
    )

    # Use reindex to handle missing annotations gracefully
    # This will return NaN for observations that don't have annotations
    selected = annotation.reindex(mi)[name]

    if categories:
        selected = selected.astype("category").cat.rename_categories(categories)

    # Swap the (fov_name, id) index for the obs index so the assignment below
    # aligns row-by-row with the observations.
    selected.index = adata.obs.index
    adata.obs[name] = selected

    return adata

0 commit comments

Comments
 (0)