Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from statsmodels.stats.multitest import fdrcorrection

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy.tools.core import _is_raw_counts

if TYPE_CHECKING:
from matplotlib.axes import Axes
Expand Down Expand Up @@ -87,6 +88,7 @@ def load(
self,
input: AnnData | pd.DataFrame,
*,
layer: str | None = None,
meta: pd.DataFrame | None = None,
label_col: str = "label_col",
cell_type_col: str = "cell_type_col",
Expand All @@ -98,6 +100,7 @@ def load(
Args:
input: Anndata or matrix containing gene expression values (genes in rows, cells in columns)
and optionally meta data about each cell.
layer: Layer in AnnData to use for expression data. If None, uses .X
meta: Optional Pandas DataFrame containing meta data about each cell.
label_col: column of the meta DataFrame or the Anndata or matrix containing the condition labels for each cell
in the cell-by-gene expression matrix
Expand All @@ -114,11 +117,11 @@ def load(
>>> import pertpy as pt
>>> adata = pt.dt.sc_sim_augur()
>>> ag_rfc = pt.tl.Augur("random_forest_classifier")
>>> loaded_data = ag_rfc.load(adata)
>>> augur_adata = ag_rfc.load(adata)
"""
if isinstance(input, AnnData):
input.obs = input.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})
adata = input
obs_renamed = adata.obs.rename(columns={cell_type_col: "cell_type", label_col: "label"})

elif isinstance(input, pd.DataFrame):
if meta is None:
Expand All @@ -130,27 +133,47 @@ def load(

label = input[label_col] if meta is None else meta[label_col]
cell_type = input[cell_type_col] if meta is None else meta[cell_type_col]
x = input.drop([label_col, cell_type_col], axis=1) if meta is None else input
adata = AnnData(X=x, obs=pd.DataFrame({"cell_type": cell_type, "label": label}))
X = input.drop([label_col, cell_type_col], axis=1) if meta is None else input
adata = AnnData(X=X, obs=pd.DataFrame({"cell_type": cell_type, "label": label}))
obs_renamed = adata.obs

if len(adata.obs["label"].unique()) < 2:
if len(obs_renamed["label"].unique()) < 2:
raise ValueError("Less than two unique labels in dataset. At least two are needed for the analysis.")

if isinstance(input, AnnData):
final_adata = AnnData(X=adata.X, obs=obs_renamed, var=adata.var, layers=adata.layers)
else:
final_adata = adata

# dummy variables for categorical data
if adata.obs["label"].dtype.name == "category":
# filter samples according to label
if final_adata.obs["label"].dtype.name == "category":
label_encoder = LabelEncoder()
final_adata.obs["y_"] = label_encoder.fit_transform(final_adata.obs["label"])

if condition_label is not None and treatment_label is not None:
logger.info(f"Filtering samples with {condition_label} and {treatment_label} labels.")
adata = ad.concat(
[adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
final_adata = ad.concat(
[
final_adata[final_adata.obs["label"] == condition_label],
final_adata[final_adata.obs["label"] == treatment_label],
]
)
label_encoder = LabelEncoder()
adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
else:
y = adata.obs["label"].to_frame()
y = final_adata.obs["label"].to_frame()
y = y.rename(columns={"label": "y_"})
adata.obs = pd.concat([adata.obs, y], axis=1)
final_adata.obs = pd.concat([final_adata.obs, y], axis=1)

return adata
if layer is not None:
if layer not in final_adata.layers:
raise ValueError(f"Layer '{layer}' not found in AnnData object")
X = final_adata.layers[layer]
else:
X = final_adata.X

if not _is_raw_counts(X):
logger.warning("Data does not appear to be raw counts. Augur developers recommend using raw counts.")

return final_adata

def create_estimator(
self,
Expand Down
18 changes: 18 additions & 0 deletions pertpy/tools/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
from scipy import sparse


def _is_raw_counts(X: np.ndarray | sparse.spmatrix) -> bool:
"""Check if data appears to be raw counts."""
if sparse.issparse(X):
sample = X[:1000, :1000] if X.shape[0] > 1000 else X
data = sample.data
else:
sample = X[:1000, :1000] if X.shape[0] > 1000 else X
data = sample.ravel()

non_zero_data = data[data > 0]
if len(non_zero_data) == 0:
return True

return np.all(data >= 0) and np.all(data == np.round(data))
2 changes: 1 addition & 1 deletion tests/tools/test_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_differential_prioritization():
adata = pt.dt.sc_sim_augur()
adata = sc.pp.subsample(adata, n_obs=500, copy=True, random_state=10)
ag = pt.tl.Augur("logistic_regression_classifier", random_state=42)
ag.load(adata)
adata = ag.load(adata)

adata, results1 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=2)
adata, results2 = ag.predict(adata, n_threads=4, n_subsamples=3, random_state=42)
Expand Down
58 changes: 58 additions & 0 deletions tests/tools/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
import pytest
from pertpy.tools.core import _is_raw_counts
from scipy import sparse


# (input array, whether it should be recognised as raw counts)
_DENSE_CASES = [
    # Count-like inputs.
    (np.array([[1, 2, 3], [4, 0, 5]]), True),  # integers mixed with zeros
    (np.array([[0, 0], [0, 0]]), True),  # nothing but zeros
    (np.array([[1, 2], [3, 4]]), True),  # strictly positive integers
    (np.array([[100, 200], [300, 400]]), True),  # larger magnitudes
    # Inputs that must be rejected.
    (np.array([[1.5, 2.0], [3.0, 4.5]]), False),  # fractional values
    (np.array([[1, 2.1], [3, 4]]), False),  # a single fractional entry
    (np.array([[-1, 2], [3, 4]]), False),  # a negative entry
    (np.log1p(np.array([[1, 2], [3, 4]])), False),  # log1p-transformed counts
    # Degenerate shapes and dtypes.
    (np.array([[0]]), True),  # lone zero
    (np.array([[1]]), True),  # lone positive integer
    (np.array([[1.0]]), True),  # float dtype holding a whole number
]


@pytest.mark.parametrize(("data", "expected"), _DENSE_CASES)
def test_dense_arrays(data, expected):
    """Dense inputs are classified as raw counts exactly when expected."""
    verdict = _is_raw_counts(data)
    assert verdict == expected


@pytest.mark.parametrize(
    "sparse_type",
    [sparse.csr_matrix, sparse.csc_matrix, sparse.coo_matrix],
)
def test_sparse_arrays_positive(sparse_type):
    """Integer-valued sparse matrices are recognised as raw counts."""
    counts = np.array(
        [
            [1, 0, 3],
            [0, 5, 0],
            [2, 0, 4],
        ]
    )
    matrix = sparse_type(counts)
    assert _is_raw_counts(matrix)


@pytest.mark.parametrize(
    "sparse_type",
    [sparse.csr_matrix, sparse.csc_matrix, sparse.coo_matrix],
)
def test_sparse_arrays_negative(sparse_type):
    """Sparse matrices holding fractional values are not raw counts."""
    fractional = np.array(
        [
            [1.5, 0, 3.2],
            [0, 5.7, 0],
        ]
    )
    matrix = sparse_type(fractional)
    assert not _is_raw_counts(matrix)


def test_large_array_sampling():
    """A dense integer matrix larger than the sampling window is raw counts."""
    # Seed the generator so the fixture is identical on every run and any
    # failure can be reproduced exactly.
    rng = np.random.default_rng(0)
    large_data = rng.integers(0, 100, size=(2000, 2000))
    assert _is_raw_counts(large_data)


def test_large_sparse_array_sampling():
    """A sparse integer matrix larger than the sampling window is raw counts."""
    # Seed the generator so the fixture is identical on every run and any
    # failure can be reproduced exactly.
    rng = np.random.default_rng(0)
    dense_data = rng.integers(0, 10, size=(2000, 2000))
    # Zero out every value below 7 so the matrix is genuinely sparse.
    dense_data[dense_data < 7] = 0
    sparse_data = sparse.csr_matrix(dense_data)
    assert _is_raw_counts(sparse_data)


def test_empty_sparse_matrix():
    """A sparse matrix with no stored entries counts as raw counts."""
    shape = (100, 100)
    no_entries = sparse.csr_matrix(shape)
    assert _is_raw_counts(no_entries)
Loading