Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class GuideAssignment:
"""Offers simple guide assigment based on count thresholds."""
"""Offers simple guide assignment based on count thresholds."""

def assign_by_threshold(
self,
Expand Down
30 changes: 15 additions & 15 deletions pertpy/tools/_dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,27 @@ def _get_pseudobulks(

return pseudobulk

def _pseudobulk_pca(self, adata: AnnData, groupby: str, n_components: int = 50) -> pd.DataFrame:
"""Return cell-averaged PCA components.
def _pseudobulk_feature(
self, adata: AnnData, groupby: str, n_components: int = 50, feature_key: str = "X_pca"
) -> pd.DataFrame:
"""Return cell-averaged components from a custom feature space.

TODO: consider merging with `get_pseudobulks`
TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.

Args:
groupby: The key to groupby for pseudobulks
n_components: The number of PCA components
groupby: The key to groupby for pseudobulks.
n_components: The number of components to use.
feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").

Returns:
A pseudobulk of PCA components.
A pseudobulk DataFrame of the averaged components.
"""
aggr = {}

for category in adata.obs.loc[:, groupby].cat.categories:
temp = adata.obs.loc[:, groupby] == category
aggr[category] = adata[temp].obsm["X_pca"][:, :n_components].mean(axis=0)

aggr[category] = adata[temp].obsm[feature_key][:, :n_components].mean(axis=0)
aggr = pd.DataFrame(aggr)

return aggr

def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
Expand Down Expand Up @@ -558,7 +558,7 @@ def _load(
self,
adata: AnnData,
ct_order: list[str],
agg_pca: bool = True,
agg_feature: bool = True,
normalize: bool = True,
) -> tuple[list, dict]:
"""Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
Expand All @@ -568,14 +568,14 @@ def _load(
Args:
adata: AnnData object to generate celltype objects for.
ct_order: The order of cell types
agg_pca: Whether to aggregate pseudobulks with PCA or not.
agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
normalize: Whether to mimic DIALOGUE behavior or not.

Returns:
A celltype_label:array dictionary.
"""
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
fn = self._pseudobulk_feature if agg_feature else self._get_pseudobulks
ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore

# TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
Expand All @@ -593,7 +593,7 @@ def calculate_multifactor_PMD(
adata: AnnData,
penalties: list[int] = None,
ct_order: list[str] = None,
agg_pca: bool = True,
agg_feature: bool = True,
solver: Literal["lp", "bs"] = "bs",
normalize: bool = True,
) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
Expand All @@ -606,7 +606,7 @@ def calculate_multifactor_PMD(
sample_id: Key to use for pseudobulk determination.
penalties: PMD penalties.
ct_order: The order of cell types.
agg_pca: Whether to calculate cell-averaged PCA components.
agg_feature: Whether to calculate cell-averaged principal components.
solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
For differences between these two, please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
normalize: Whether to mimic DIALOGUE as closely as possible
Expand All @@ -631,7 +631,7 @@ def calculate_multifactor_PMD(
else:
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories

mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)

n_samples = mcca_in[0].shape[1]
if penalties is None:
Expand Down
Loading