Skip to content

Commit

Permalink
Daniel strobl hvg conservation fix (#785)
Browse files Browse the repository at this point in the history
* hvg conservation metric fix

* pre-commit

* Allow for uppercase repo owner

* Fix sklearn req

* bash not sh

* bugfix use index

* add to api

* pre-commit

* list instead of index

* check number of genes

* pre-commit

* addressing comments

* pre-commit

* shorten line

* addressing comments

* pre-commit

* fix checks

* remove magic numbers

* pre-commit

* int -> numbers.Integral

* Fix typo

* fix dataset size assumption and duck-type hvg_unint

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Scott Gigante <[email protected]>
Co-authored-by: Scott Gigante <[email protected]>
Former-commit-id: 814fedc
  • Loading branch information
4 people authored Feb 2, 2023
1 parent c5e6f56 commit 49c83bf
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 6 deletions.
23 changes: 22 additions & 1 deletion openproblems/tasks/_batch_integration/_common/api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from ....data.sample import load_sample_data
from ....tools.decorators import dataset
from .utils import filter_celltypes
from .utils import precompute_hvg

import numbers
import numpy as np

# Default minimum number of cells a cell type needs in order to be kept by
# utils.filter_celltypes.
MIN_CELLS_PER_CELLTYPE = 50
# Default number of highly variable genes selected by utils.precompute_hvg;
# also the upper bound used by the dataset and method gene-count checks.
N_HVG_UNINT = 2000


def check_neighbors(adata, neighbors_key, connectivities_key, distances_key):
Expand All @@ -15,7 +18,12 @@ def check_neighbors(adata, neighbors_key, connectivities_key, distances_key):
assert distances_key in adata.obsp


def check_dataset(adata, do_check_pca=False, do_check_neighbors=False):
def check_dataset(
adata,
do_check_pca=False,
do_check_neighbors=False,
do_check_hvg=False,
):
"""Check that dataset output fits expected API."""

assert "batch" in adata.obs
Expand All @@ -28,12 +36,21 @@ def check_dataset(adata, do_check_pca=False, do_check_neighbors=False):
assert adata.var_names.is_unique
assert adata.obs_names.is_unique

assert "n_genes_pre" in adata.uns
assert isinstance(adata.uns["n_genes_pre"], numbers.Integral)
assert adata.uns["n_genes_pre"] == adata.n_vars

assert "organism" in adata.uns
assert adata.uns["organism"] in ["mouse", "human"]

if do_check_pca:
assert "X_uni_pca" in adata.obsm

if do_check_hvg:
assert "hvg_unint" in adata.uns
assert len(adata.uns["hvg_unint"]) == min(N_HVG_UNINT, adata.n_vars)
assert np.all(np.isin(adata.uns["hvg_unint"], adata.var.index))

if do_check_neighbors:
check_neighbors(adata, "uni", "uni_connectivities", "uni_distances")

Expand All @@ -58,6 +75,10 @@ def sample_dataset(run_pca: bool = False, run_neighbors: bool = False):
adata.obs["batch"] = np.random.choice(2, adata.shape[0], replace=True).astype(str)
adata.obs["labels"] = np.random.choice(3, adata.shape[0], replace=True).astype(str)
adata = filter_celltypes(adata)

adata.uns["hvg_unint"] = precompute_hvg(adata)
adata.uns["n_genes_pre"] = adata.n_vars

if run_pca:
adata.obsm["X_uni_pca"] = sc.pp.pca(adata.X)
if run_neighbors:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .....data.immune_cells import load_immune
from .....tools.decorators import dataset
from ..utils import filter_celltypes
from ..utils import precompute_hvg
from typing import Optional


Expand All @@ -13,7 +14,11 @@
"Smart-seq2).",
image="openproblems",
)
def immune_batch(test: bool = False, min_celltype_count: Optional[int] = None):
def immune_batch(
test: bool = False,
min_celltype_count: Optional[int] = None,
n_hvg: Optional[int] = None,
):
import scanpy as sc

adata = load_immune(test)
Expand All @@ -38,4 +43,7 @@ def immune_batch(test: bool = False, min_celltype_count: Optional[int] = None):

sc.pp.neighbors(adata, use_rep="X_uni_pca", key_added="uni")
adata.var_names_make_unique()

adata.uns["hvg_unint"] = precompute_hvg(adata, n_genes=n_hvg)
adata.uns["n_genes_pre"] = adata.n_vars
return adata
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .....data.pancreas import load_pancreas
from .....tools.decorators import dataset
from ..utils import filter_celltypes
from ..utils import precompute_hvg
from typing import Optional


Expand All @@ -13,7 +14,11 @@
"and SMARTER-seq).",
image="openproblems",
)
def pancreas_batch(test: bool = False, min_celltype_count: Optional[int] = None):
def pancreas_batch(
test: bool = False,
min_celltype_count: Optional[int] = None,
n_hvg: Optional[int] = None,
):
import scanpy as sc

adata = load_pancreas(test)
Expand All @@ -38,4 +43,7 @@ def pancreas_batch(test: bool = False, min_celltype_count: Optional[int] = None)
sc.pp.neighbors(adata, use_rep="X_uni_pca", key_added="uni")

adata.var_names_make_unique()

adata.uns["hvg_unint"] = precompute_hvg(adata, n_genes=n_hvg)
adata.uns["n_genes_pre"] = adata.n_vars
return adata
18 changes: 17 additions & 1 deletion openproblems/tasks/_batch_integration/_common/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
from . import api
from scanpy.pp import highly_variable_genes
from typing import Optional


def filter_celltypes(adata, min_celltype_count: Optional[int] = None):
    """Drop cells that belong to cell types with too few cells.

    Parameters
    ----------
    adata
        Annotated data object with cell type labels in ``adata.obs["labels"]``.
    min_celltype_count
        Minimum number of cells a cell type needs in order to be kept. Falls
        back to ``api.MIN_CELLS_PER_CELLTYPE`` when falsy (``None`` or ``0``).

    Returns
    -------
    A copy of ``adata`` restricted to the cells of the retained cell types.
    """
    # The default comes from the shared constant only; a hard-coded literal
    # fallback here would shadow it and make the constant dead code.
    min_celltype_count = min_celltype_count or api.MIN_CELLS_PER_CELLTYPE

    celltype_counts = adata.obs["labels"].value_counts()
    keep_celltypes = celltype_counts[celltype_counts >= min_celltype_count].index
    keep_cells = adata.obs["labels"].isin(keep_celltypes)
    return adata[keep_cells].copy()


def precompute_hvg(adata, n_genes: Optional[int] = None):
    """Return the names of the highly variable genes of ``adata`` as a list.

    Genes are ranked with ``highly_variable_genes`` using the ``cell_ranger``
    flavor on the ``log_normalized`` layer, with ``batch`` as the batch key.
    ``n_genes`` is the number of top genes requested; a falsy value (``None``
    or ``0``) selects ``api.N_HVG_UNINT``.
    """
    if not n_genes:
        n_genes = api.N_HVG_UNINT

    # inplace=False: work with the returned per-gene table.
    hvg_table = highly_variable_genes(
        adata,
        layer="log_normalized",
        batch_key="batch",
        flavor="cell_ranger",
        n_top_genes=n_genes,
        inplace=False,
    )

    # Keep only the rows flagged as highly variable; their index labels are
    # the result.
    is_hvg = hvg_table.highly_variable
    selected = hvg_table[is_hvg]
    return list(selected.index)
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ Datasets should contain the following attributes:
* `adata.layers['counts']` with raw, integer UMI count data,
* `adata.layers['log_normalized']` with log-normalized data and
* `adata.X` with log-normalized data
* `adata.uns['n_genes_pre']` with the number of genes present before integration
* `adata.uns['hvg_unint']` with a list of 2000 highly variable genes
prior to integration (for the hvg conservation metric)

Methods should store their batch-corrected gene expression matrix in `adata.X`.
The output should contain at least 2000 features.

The `openproblems-python-batch-integration` docker container is used for the methods
that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@

import functools

# Dataset check for this task: the common API check with the PCA embedding
# and the precomputed highly-variable-gene list both required.
check_dataset = functools.partial(
    api.check_dataset, do_check_hvg=True, do_check_pca=True
)


def check_method(adata, is_baseline=False):
    """Assert that a method's output conforms to the expected API.

    Returns ``True`` when every assertion holds. ``is_baseline`` switches off
    the check that ``adata.X`` is a different object from the
    ``log_normalized`` layer.
    """
    assert "log_normalized" in adata.layers
    # Both the precomputed HVG list and the pre-integration gene count must
    # still be present in the output.
    for required_key in ("hvg_unint", "n_genes_pre"):
        assert required_key in adata.uns
    # The output must have at least N_HVG_UNINT features, or as many as the
    # dataset had before integration if that was fewer.
    assert adata.n_vars >= min(
        api.N_HVG_UNINT,
        adata.uns["n_genes_pre"],
    )
    # Only non-baseline methods are required to return an X that is not the
    # very same object as the log-normalized layer.
    assert is_baseline or adata.layers["log_normalized"] is not adata.X
    return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ def hvg_conservation(adata):

adata_unint = adata.copy()
adata_unint.X = adata_unint.layers["log_normalized"]
hvg_both = list(set(adata.uns["hvg_unint"]).intersection(adata.var_names))

return hvg_overlap(adata_unint, adata, "batch")
return hvg_overlap(adata_unint, adata[:, hvg_both], "batch")

0 comments on commit 49c83bf

Please sign in to comment.