Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions tests/representation/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
from iohub import open_ome_zarr
from pytest import TempPathFactory

from viscy.representation.evaluation import (
load_annotation_anndata,
)
from viscy.representation.evaluation.annotation import convert
from viscy.representation.evaluation.annotation import convert, load_annotation_anndata


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -252,7 +249,7 @@ def test_convert_xarray_annotation_to_anndata(xr_embeddings_dataset, tmp_path):


def test_load_annotation_anndata(tracks_hcs_dataset, anndata_embeddings, tmp_path):
"""Test that load_annotation_anndata correctly loads annotations from an AnnData object."""
"""Test that load_annotation_anndata correctly loads annotations into an AnnData object."""
# Load the AnnData object
adata = ad.read_zarr(anndata_embeddings)

Expand All @@ -271,16 +268,28 @@ def test_load_annotation_anndata(tracks_hcs_dataset, anndata_embeddings, tmp_pat
A11_annotations_df.to_csv(annotations_path, index=False)

# Test the function with the new CSV file
result = load_annotation_anndata(adata, str(annotations_path), "infection_state")
result_adata = load_annotation_anndata(
adata, str(annotations_path), "infection_state"
)

# Check that the function returns an AnnData object
assert isinstance(result_adata, ad.AnnData)

# Check that the annotation column was added
assert "infection_state" in result_adata.obs.columns

assert len(result) == 2 # Only 2 observations from A/1/1 have annotations
# Check that annotations were added for A/1/1 FOV
a11_mask = result_adata.obs["fov_name"] == "A/1/1"
assert a11_mask.sum() == 2 # Only 2 observations from A/1/1

# Check that the annotation values match
expected_values = A11_annotations_df["infection_state"].values
actual_values = result.values
actual_values = result_adata.obs.loc[a11_mask, "infection_state"].values
np.testing.assert_array_equal(actual_values, expected_values)

assert result.index.names == ["fov_name", "id"]
assert all(result.index.get_level_values("fov_name") == "A/1/1")
# Check that other FOVs have NaN annotations
other_fovs_mask = result_adata.obs["fov_name"] != "A/1/1"
assert result_adata.obs.loc[other_fovs_mask, "infection_state"].isna().all()


def test_cli_convert_to_anndata(xr_embeddings_dataset, tmp_path):
Expand Down
39 changes: 3 additions & 36 deletions viscy/representation/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
https://github.com/mehta-lab/dynacontrast/blob/master/analysis/gmm.py
"""

import anndata as ad
import pandas as pd

from viscy.data.triplet import TripletDataModule
from viscy.representation.evaluation.annotation import load_annotation_anndata

__all__ = ["load_annotation", "load_annotation_anndata", "dataset_of_tracks"]


def load_annotation(da, path, name, categories: dict | None = None):
Expand Down Expand Up @@ -63,41 +65,6 @@ def load_annotation(da, path, name, categories: dict | None = None):
return selected


def load_annotation_anndata(
    adata: ad.AnnData, path: str, name: str, categories: dict | None = None
) -> pd.Series:
    """
    Load annotations from a CSV file and select those matching the AnnData object.

    CSV rows are matched to observations on the ``(fov_name, id)`` pair.
    Leading and trailing ``/`` in ``fov_name`` are ignored on both the CSV
    side and the ``adata.obs`` side. ``adata`` itself is not modified.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object to map the annotations to. ``adata.obs`` must
        contain ``fov_name`` and ``id`` columns.
    path : str
        Path to the CSV file containing annotations. The file must contain
        ``fov_name`` and ``id`` columns in addition to ``name``.
    name : str
        The column name in the CSV file to be used as annotations.
    categories : dict, optional
        A dictionary to rename categories in the annotation column. Default is None.

    Returns
    -------
    pandas.Series
        Annotation values indexed by a ``(fov_name, id)`` MultiIndex, in the
        order of ``adata.obs``. Observations without an annotation are dropped.
    """
    annotation = pd.read_csv(path)
    annotation["fov_name"] = annotation["fov_name"].str.strip("/")

    # Pick the requested column before reindexing so only that column is
    # realigned instead of the whole table.
    values = annotation.set_index(["fov_name", "id"])[name]

    # Normalize the lookup key the same way as the CSV side; otherwise an
    # observation stored as "/A/1/1" could never match an annotation row.
    obs_fov_name = adata.obs["fov_name"].astype(str).str.strip("/")
    mi = pd.MultiIndex.from_arrays(
        [obs_fov_name, adata.obs["id"]], names=["fov_name", "id"]
    )

    # reindex yields NaN for observations that have no annotation;
    # those entries are then dropped from the result.
    selected = values.reindex(mi).dropna()

    if categories:
        selected = selected.astype("category").cat.rename_categories(categories)
    return selected


def dataset_of_tracks(
data_path,
tracks_path,
Expand Down
52 changes: 49 additions & 3 deletions viscy/representation/evaluation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ def convert(
available_cols = get_available_index_columns(embeddings_ds)
tracking_df = pd.DataFrame(
{
col: embeddings_ds.coords[col].data
if col != "fov_name"
else embeddings_ds.coords[col].to_pandas().str.strip("/")
col: (
embeddings_ds.coords[col].data
if col != "fov_name"
else embeddings_ds.coords[col].to_pandas().str.strip("/")
)
for col in available_cols
}
)
Expand Down Expand Up @@ -89,3 +91,47 @@ def convert(
adata.write_zarr(output_path)
if return_anndata:
return adata


def load_annotation_anndata(
    adata: ad.AnnData, path: str, name: str, categories: dict | None = None
) -> ad.AnnData:
    """
    Load annotations from a CSV file and map them to the AnnData object.

    CSV rows are matched to observations on the ``(fov_name, id)`` pair.
    Leading and trailing ``/`` in ``fov_name`` are ignored on both the CSV
    side and the ``adata.obs`` side. Observations without a matching CSV row
    receive a missing value.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object to map the annotations to. ``adata.obs`` must
        contain ``fov_name`` and ``id`` columns.
    path : str
        Path to the CSV file containing annotations. The file must contain
        ``fov_name`` and ``id`` columns in addition to ``name``.
    name : str
        The column name in the CSV file to be used as annotations.
    categories : dict, optional
        A dictionary to rename categories in the annotation column. Default is None.

    Returns
    -------
    anndata.AnnData
        The same ``adata`` object, modified in place, with annotations added
        to ``adata.obs[name]``.
    """
    annotation = pd.read_csv(path)
    annotation["fov_name"] = annotation["fov_name"].str.strip("/")

    # Pick the requested column before reindexing so only that column is
    # realigned instead of the whole table.
    values = annotation.set_index(["fov_name", "id"])[name]

    # Normalize the lookup key the same way as the CSV side; otherwise an
    # observation stored as "/A/1/1" could never match an annotation row.
    # Only the key is normalized: adata.obs["fov_name"] is left untouched.
    obs_fov_name = adata.obs["fov_name"].astype(str).str.strip("/")
    mi = pd.MultiIndex.from_arrays(
        [obs_fov_name, adata.obs["id"]], names=["fov_name", "id"]
    )

    # reindex keeps one entry per observation, in obs order, and yields NaN
    # for observations that have no annotation.
    selected = values.reindex(mi)

    if categories:
        selected = selected.astype("category").cat.rename_categories(categories)

    # Swap the (fov_name, id) index for the obs index so the column
    # assignment aligns row by row.
    selected.index = adata.obs.index
    adata.obs[name] = selected

    return adata