Skip to content

Commit 7ad359f

Browse files
authored
Merge pull request #3 from dschaub95/main
Add spatial stability and more clustering algos
2 parents d873342 + c8c72a0 commit 7ad359f

File tree

9 files changed

+353
-90
lines changed

9 files changed

+353
-90
lines changed

poetry.lock

Lines changed: 74 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ dependencies = [
2121
"jupyter",
2222
"leidenalg",
2323
"igraph",
24-
"harmonypy (>=0.0.10,<0.0.11)"
24+
"harmonypy (>=0.0.10,<0.0.11)",
25+
"louvain (>=0.8.2,<0.9.0)"
2526
]
2627

2728

scale/clustering.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
homogeneity_score,
1212
completeness_score,
1313
)
14+
from sklearn.cluster import KMeans
15+
from scipy.sparse import issparse
1416

1517
from scale.config import Config
1618

@@ -21,11 +23,13 @@ def calc_clusterings(
2123
n_jobs=20,
2224
ensure_unique=True,
2325
emb_prefix="X_gnn",
26+
method="leiden",
2427
**kwargs,
2528
):
2629
resolutions = np.arange(
2730
cfg.resolution_set.start, cfg.resolution_set.stop, cfg.resolution_set.step
2831
).round(4)
32+
n_repeats = 1 if cfg.stability_spatial else cfg.n_repeats
2933

3034
all_clusterings = pd.DataFrame(index=adata.obs_names)
3135

@@ -47,11 +51,12 @@ def calc_clusterings(
4751
sparam_str = "dist" if "dist" in emb_key else "knn"
4852
sparam = emb_key.split(f"{sparam_str}_")[-1].split("_lam")[0]
4953

50-
for i in tqdm(range(cfg.n_repeats), desc="Calculating clusterings"):
51-
parallel_leiden(
54+
for i in tqdm(range(n_repeats), desc="Calculating clusterings"):
55+
parallel_clustering(
5256
ad_tmp,
5357
resolutions,
5458
key_added=f"leiden_rep_{i}_{sparam_str}_{sparam}",
59+
method=method,
5560
n_jobs=n_jobs,
5661
verbose=kwargs.get("verbose", False),
5762
random_state=i,
@@ -124,6 +129,71 @@ def loop(r, adata):
124129
return adata
125130

126131

132+
def parallel_clustering(
    adata,
    resolutions,
    method="leiden",
    key_added="scale",
    n_jobs=10,
    verbose=True,
    random_state=0,
    **kwargs,
):
    """Run a clustering algorithm at several resolutions in parallel.

    One joblib job is spawned per entry of ``resolutions``; each job computes
    labels under the column name ``"{key_added}_res_{r}"`` and the resulting
    label Series are copied back into the parent ``adata.obs`` afterwards.

    Parameters
    ----------
    adata
        AnnData object. "leiden"/"louvain" use its neighbors graph; "kmeans"
        uses ``adata.obsm["X_pca"]`` when present, otherwise ``adata.X``.
    resolutions
        Iterable of resolution values. For "kmeans" each value is cast to
        ``int`` and used as the number of clusters.
    method
        One of ``"leiden"``, ``"louvain"`` or ``"kmeans"``.
    key_added
        Prefix for the ``adata.obs`` columns that receive the labels.
    n_jobs
        Number of parallel joblib workers.
    verbose
        If True, print a short message after each resolution finishes.
    random_state
        Seed forwarded to the underlying clustering routine.
    **kwargs
        Extra keyword arguments; forwarded to ``sc.tl.leiden`` only
        (ignored by the "louvain" and "kmeans" branches).

    Returns
    -------
    The input ``adata`` with one new categorical ``obs`` column per resolution.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported algorithms.
    """
    # Fail fast before any workers are spawned; previously an invalid method
    # raised the same ValueError once per job inside the joblib pool.
    if method not in ("leiden", "louvain", "kmeans"):
        raise ValueError(f"Invalid method: {method}")

    def to_key(r):
        return key_added + "_res_" + str(r)

    def loop(adata, r=None, **kwargs):
        if method == "leiden":
            flavor = kwargs.pop("flavor", "igraph")
            n_iterations = kwargs.pop("n_iterations", 2)
            key = to_key(r)
            sc.tl.leiden(
                adata,
                resolution=r,
                key_added=key,
                random_state=random_state,
                flavor=flavor,
                n_iterations=n_iterations,
                **kwargs,
            )
            if verbose:
                print(f"Resolution = {r} Done!")
            return adata.obs[key]
        elif method == "louvain":
            key = to_key(r)
            sc.tl.louvain(
                adata,
                resolution=r,
                key_added=key,
                random_state=random_state,
                flavor="vtraag",
            )
            if verbose:
                print(f"Resolution = {r} Done!")
            return adata.obs[key]
        else:  # method == "kmeans" — validated above
            k = int(r)
            # Prefer the PCA embedding when available; fall back to the raw matrix.
            X = adata.obsm["X_pca"] if "X_pca" in adata.obsm else adata.X
            # Use .toarray() instead of the deprecated .A shorthand, which is
            # absent on modern scipy sparse arrays.
            X = X.toarray() if issparse(X) else X
            km = KMeans(n_clusters=k, random_state=random_state)
            labels = km.fit_predict(X)
            # NOTE(review): the key uses int(r) ("…_res_5", not "…_res_5.0"),
            # which differs from the leiden/louvain key format — confirm
            # downstream consumers expect this for k-means.
            key = to_key(k)
            adata.obs[key] = pd.Categorical(labels.astype(str))
            if verbose:
                print(f"K value = {k} Done!")
            return adata.obs[key]

    # Each worker receives (a pickled copy of) adata and returns only its
    # label Series; copy the labels back so the parent adata holds them all.
    clusterings = Parallel(n_jobs=n_jobs)(
        delayed(loop)(adata, r, **kwargs) for r in resolutions
    )
    for clustering in clusterings:
        adata.obs[clustering.name] = clustering

    return adata
195+
196+
127197
def calc_cluster_metrics(
128198
labels_true,
129199
labels_pred,

scale/config.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass, field
22
from typing import Literal
33

4+
45
class BaseConfig(dict):
56
"""Dict that also supports attribute access (recursively)."""
67

@@ -66,14 +67,18 @@ class Config(BaseConfig):
6667
n_heads: int = 5
6768
max_epoch: int = 500
6869
lr: float = 0.01
69-
n_sample: int = None # number of maximum edges in case of distance graph (randomyl selected)
70+
n_sample: int = (
71+
None  # maximum number of edges in case of distance graph (randomly selected)
72+
)
7073
sample_key: str = None
7174
preprocess: bool = False
7275
device: str | None = None
7376
distance_set: dict | list = field(
7477
default_factory=lambda: {"start": 15, "stop": 60, "step": 5}
7578
)
76-
knn_set: dict | list = field(default_factory=lambda: {"start": 5, "stop": 40, "step": 5})
79+
knn_set: dict | list = field(
80+
default_factory=lambda: {"start": 5, "stop": 40, "step": 5}
81+
)
7782
lambda_set: list = field(
7883
default_factory=lambda: [
7984
1e-6,
@@ -100,6 +105,8 @@ class Config(BaseConfig):
100105
spatial_graph_method: Literal["distance", "knn"] = "distance"
101106
repeated_negative_sampling: bool = False
102107
y_aggregated: bool = False
108+
stability_spatial: bool = False
109+
stability_delta: float = 1.0
103110

104111

105112
def load_config(**kwargs):

scale/scale.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,14 @@ def run_scale(
107107
calc_clusterings(
108108
ad_tmp,
109109
cfg=cfg,
110+
method=kwargs.get("method", "leiden"),
110111
flavor=kwargs.get("flavor", "igraph"),
111112
n_iterations=kwargs.get("n_iterations", 2),
112113
)
113114

114115
calc_stability(
115116
ad_tmp,
117+
cfg=cfg,
116118
verbose=kwargs.get("verbose", True),
117119
n_repeat=kwargs.get("n_repeat", 4),
118120
min_dist=kwargs.get("min_dist", 15),
@@ -122,7 +124,7 @@ def run_scale(
122124
min_res=kwargs.get("min_res", None),
123125
max_res=kwargs.get("max_res", None),
124126
)
125-
results = calc_entropy(
127+
calc_entropy(
126128
ad_tmp,
127129
n_levels=kwargs.get("n_levels", 2),
128130
top_n=kwargs.get("top_n", 0.15),

0 commit comments

Comments
 (0)