Skip to content

Commit 5080c8e

Browse files
authored
Add parallel processing (#65)
Runner: * Add parallelization. Use multiple processes to test files and chunks of molecules in parallel. To restore the previous single-threaded execution, set `max_workers=0` in the API or `--max-workers=0` in the CLI. Output format: * Leave the molecule name empty if no molecule name is provided. * Add the molecule index (position in the SDF file) to the output.
1 parent 4cd64bd commit 5080c8e

7 files changed

Lines changed: 255 additions & 148 deletions

File tree

posebusters/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@
3232
"check_volume_overlap",
3333
]
3434

35-
__version__ = "0.4.1"
35+
__version__ = "0.4.2"

posebusters/cli.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from yaml import safe_load
1515

1616
from . import __version__
17-
from .posebusters import PoseBusters, _dataframe_from_output
17+
from .posebusters import PoseBusters
1818
from .tools.formatting import create_long_output, create_short_output
1919

2020
logger = logging.getLogger(__name__)
@@ -40,6 +40,8 @@ def bust( # noqa: PLR0913
4040
no_header: bool = False,
4141
full_report: bool = False,
4242
top_n: int | None = None,
43+
max_workers: bool = False,
44+
chunk_size: int | None = None,
4345
):
4446
"""PoseBusters: Plausibility checks for generated molecule poses."""
4547
if table is None and len(mol_pred) == 0:
@@ -49,23 +51,23 @@ def bust( # noqa: PLR0913
4951
# run on table
5052
file_paths = pd.read_csv(table, index_col=None)
5153
mode = _select_mode(config, file_paths.columns.tolist())
52-
posebusters = PoseBusters(mode, top_n=top_n)
54+
posebusters = PoseBusters(mode, top_n=top_n, max_workers=max_workers, chunk_size=chunk_size)
5355
posebusters.file_paths = file_paths
5456
posebusters_results = posebusters._run()
5557
else:
5658
# run on single input
5759
d = {k for k, v in dict(mol_pred=mol_pred, mol_true=mol_true, mol_cond=mol_cond).items() if v}
5860
mode = _select_mode(config, d)
59-
posebusters = PoseBusters(mode, top_n=top_n)
61+
posebusters = PoseBusters(mode, top_n=top_n, max_workers=max_workers, chunk_size=chunk_size)
6062
cols = ["mol_pred", "mol_true", "mol_cond"]
6163
posebusters.file_paths = pd.DataFrame([[mol_pred, mol_true, mol_cond] for mol_pred in mol_pred], columns=cols)
6264
posebusters_results = posebusters._run()
6365

6466
if isinstance(output, Path):
6567
output = open(Path(output), "w", encoding="utf-8")
6668

67-
for i, results_dict in enumerate(posebusters_results):
68-
results = _dataframe_from_output(results_dict, posebusters.config, full_report)
69+
for i, (k, v) in enumerate(posebusters_results):
70+
results = posebusters._make_table({k: v}, posebusters.config, full_report=full_report)
6971
output.write(_format_results(results, outfmt, no_header, i))
7072

7173

@@ -99,6 +101,14 @@ def _parse_args(args):
99101
cfg_group.add_argument(
100102
"--top-n", type=int, default=None, help="run on TOP_N results in MOL_PRED only (default: all)"
101103
)
104+
cfg_group.add_argument(
105+
"--max-workers",
106+
type=int,
107+
help="number workers for parallel processing. (0: single thread, default: use all available cores)",
108+
)
109+
cfg_group.add_argument(
110+
"--chunk-size", type=int, help="chunk size for parallel processing of SDF files (default: 100)", default=100
111+
)
102112

103113
# other
104114
inf_group.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}")
@@ -124,7 +134,7 @@ def _format_results(df: pd.DataFrame, outfmt: str = "short", no_header: bool = F
124134

125135
if outfmt == "csv":
126136
header = (not no_header) and (index == 0)
127-
df.index.names = ["file", "molecule"]
137+
df.index.names = ["file", "molecule", "position"]
128138
df.columns = [c.lower().replace(" ", "_") for c in df.columns]
129139
return df.to_csv(index=True, header=header)
130140

posebusters/posebusters.py

Lines changed: 173 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
import inspect
66
import logging
7-
from collections import defaultdict
87
from collections.abc import Generator, Iterable
8+
from concurrent.futures import ProcessPoolExecutor, as_completed
9+
from concurrent.futures.process import BrokenProcessPool
910
from functools import partial
11+
from math import ceil
1012
from pathlib import Path
1113
from typing import Any, Callable
1214

@@ -29,7 +31,7 @@
2931
)
3032
from .modules.sucos import check_sucos
3133
from .modules.volume_overlap import check_volume_overlap
32-
from .tools.loading import safe_load_mol, safe_supply_mols
34+
from .tools.loading import get_num_mols, safe_load_mol, safe_supply_mols
3335

3436
logger = logging.getLogger(__name__)
3537

@@ -51,6 +53,11 @@
5153
}
5254
molecule_args = {"mol_cond", "mol_true", "mol_pred"}
5355

56+
ResultKey = tuple[str, str, int]
57+
ResultList = list[tuple[str, str, Any]]
58+
ResultTuple = tuple[ResultKey, ResultList]
59+
ResultDict = dict[ResultKey, ResultList]
60+
5461

5562
class PoseBusters:
5663
"""Class to run all tests on a set of molecules."""
@@ -61,8 +68,23 @@ class PoseBusters:
6168
module_args: list
6269
fname: list
6370

64-
def __init__(self, config: str | dict[str, Any] = "redock", top_n: int | None = None):
65-
"""Initialize PoseBusters object."""
71+
def __init__(
72+
self,
73+
config: str | dict[str, Any] = "redock",
74+
top_n: int | None = None,
75+
max_workers: int | None = None,
76+
chunk_size: int | None = 100,
77+
) -> None:
78+
"""Initialize PoseBusters object.
79+
80+
Args:
81+
config: Configuration file or dictionary. If a string, it should be one of "dock", "redock", "mol", "gen".
82+
top_n: Number of poses to process. If None, all poses are processed.
83+
max_workers: Maximum number of workers for parallelization. If None, all available cores are used. If 0 or
84+
negative, no parallelization is used.
85+
chunk_size: Number of poses to process per process if parallelization is used. If None, parallelization over
86+
files only.
87+
"""
6688
self.module_func: list # dict[str, Callable]
6789
self.module_args: list # dict[str, set[str]]
6890

@@ -78,8 +100,8 @@ def __init__(self, config: str | dict[str, Any] = "redock", top_n: int | None =
78100
assert len(set(self.config.get("tests", {}).keys()) - set(module_dict.keys())) == 0
79101

80102
self.config["top_n"] = self.config.get("top_n", top_n)
81-
82-
self.results: dict[tuple[str, str], list[tuple[str, str, Any]]] = defaultdict(list)
103+
self.config["max_workers"] = self.config.get("max_workers", max_workers)
104+
self.config["chunk_size"] = self.config.get("chunk_size", chunk_size)
83105

84106
def bust(
85107
self,
@@ -106,14 +128,9 @@ def bust(
106128

107129
columns = ["mol_pred", "mol_true", "mol_cond"]
108130
self.file_paths = pd.DataFrame([[mol_pred, mol_true, mol_cond] for mol_pred in mol_pred_list], columns=columns)
109-
110-
results_gen = self._run()
111-
112-
df = pd.concat([_dataframe_from_output(d, self.config, full_report=full_report) for d in results_gen])
113-
df.index.names = ["file", "molecule"]
114-
df.columns = [c.lower().replace(" ", "_") for c in df.columns]
115-
116-
return df
131+
generator = self._run()
132+
results = self._collect_in_table(generator, full_report=full_report)
133+
return results
117134

118135
def bust_table(self, mol_table: pd.DataFrame, full_report: bool = False) -> pd.DataFrame:
119136
"""Run tests on molecules provided in pandas dataframe as paths or rdkit molecule objects.
@@ -126,59 +143,129 @@ def bust_table(self, mol_table: pd.DataFrame, full_report: bool = False) -> pd.D
126143
Pandas dataframe with results.
127144
"""
128145
self.file_paths = mol_table
146+
generator = self._run()
147+
results = self._collect_in_table(generator, full_report=full_report)
148+
return results
129149

130-
results_gen = self._run()
131-
132-
df = pd.concat([_dataframe_from_output(d, self.config, full_report=full_report) for d in results_gen])
133-
df.index.names = ["file", "molecule"]
134-
df.columns = [c.lower().replace(" ", "_") for c in df.columns]
135-
136-
return df
137-
138-
def _run(self) -> Generator[dict, None, None]:
150+
def _run(self) -> Generator[ResultTuple, None, None]:
139151
"""Run all tests on molecules provided in file paths.
140152
141153
Yields:
142154
Generator of result dictionaries.
143155
"""
144156
self._initialize_modules()
157+
max_workers = self.config.get("max_workers", None)
158+
chunk_size = self.config.get("chunk_size", 100)
159+
if max_workers is not None and max_workers <= 0:
160+
yield from self._run_single_thread()
161+
elif chunk_size is None:
162+
yield from self._run_parallel_over_files(max_workers=max_workers)
163+
else:
164+
yield from self._run_parallel_over_poses(max_workers=max_workers, chunk_size=chunk_size)
145165

166+
def _run_single_thread(self) -> Generator[ResultTuple, None, None]:
146167
for _, paths in self.file_paths.iterrows():
147-
mol_args = {}
148-
if "mol_cond" in paths and paths["mol_cond"] is not None:
149-
mol_cond_load_params = self.config.get("loading", {}).get("mol_cond", {})
150-
mol_args["mol_cond"] = safe_load_mol(path=paths["mol_cond"], **mol_cond_load_params)
151-
if "mol_true" in paths and paths["mol_true"] is not None:
152-
mol_true_load_params = self.config.get("loading", {}).get("mol_true", {})
153-
mol_args["mol_true"] = safe_load_mol(path=paths["mol_true"], **mol_true_load_params)
154-
155-
mol_pred_load_params = self.config.get("loading", {}).get("mol_pred", {})
156-
for i, mol_pred in enumerate(safe_supply_mols(paths["mol_pred"], **mol_pred_load_params)):
157-
if self.config["top_n"] is not None and i >= self.config["top_n"]:
158-
break
159-
160-
mol_args["mol_pred"] = mol_pred
161-
162-
results_key = (str(paths["mol_pred"]), self._get_name(mol_pred, i))
163-
164-
for name, fname, func, args in zip(self.module_name, self.fname, self.module_func, self.module_args):
165-
# pick needed arguments for module
166-
args_needed = {k: v for k, v in mol_args.items() if k in args}
167-
# loading takes all inputs
168-
if fname == "loading":
169-
args_needed = {k: args_needed.get(k, None) for k in args_needed}
170-
# run module when all needed input molecules are valid Mol objects
171-
if fname != "loading" and not all(args_needed.get(m, None) for m in args_needed):
172-
module_output: dict[str, Any] = {"results": {}}
173-
else:
174-
module_output = func(**args_needed)
175-
176-
# save to object
177-
self.results[results_key].extend([(name, k, v) for k, v in module_output["results"].items()])
178-
# self.results[results_key]["details"].append(module_output["details"])
179-
180-
# return results for this entry
181-
yield {results_key: self.results[results_key]}
168+
yield from self._run_multiple_poses(paths)
169+
170+
def _run_parallel_over_files(
171+
self, timeout: int | None = None, max_workers: int | None = None
172+
) -> Generator[ResultTuple, None, None]:
173+
with ProcessPoolExecutor(max_workers=max_workers) as executor:
174+
futures = [executor.submit(self._run_and_combine, paths) for _, paths in self.file_paths.iterrows()]
175+
for future in as_completed(futures, timeout=None):
176+
try:
177+
results = future.result(timeout=timeout)
178+
except BrokenProcessPool as exception:
179+
# logger.critical("BrokenProcessPool: %s", exception)
180+
raise exception
181+
except Exception as exception:
182+
# logger.critical("Error in process: %s", exception)
183+
raise exception
184+
185+
yield from results
186+
187+
def _run_parallel_over_poses(
188+
self, timeout: int | None = None, max_workers: int | None = None, chunk_size: int = 100
189+
) -> Generator[ResultTuple, None, None]:
190+
with ProcessPoolExecutor(max_workers=max_workers) as executor:
191+
futures = []
192+
for _, paths in self.file_paths.iterrows():
193+
num_mols_pred = get_num_mols(paths["mol_pred"])
194+
for chunk in range(ceil(num_mols_pred / chunk_size)):
195+
indices = range(chunk * chunk_size, min((chunk + 1) * chunk_size, num_mols_pred))
196+
future = executor.submit(self._run_and_combine, paths=paths, indices=indices)
197+
futures.append(future)
198+
199+
for future in as_completed(futures, timeout=None):
200+
try:
201+
results = future.result(timeout=timeout)
202+
except BrokenProcessPool as exception:
203+
# logger.critical("BrokenProcessPool: %s", exception)
204+
raise exception
205+
except Exception as exception:
206+
# logger.critical("Error in process: %s", exception)
207+
raise exception
208+
209+
yield from results
210+
211+
def _run_and_combine(self, paths: pd.Series, indices: Iterable[int] | None = None) -> list[ResultTuple]:
212+
"""Run and collect all tests for all poses in the prediction file."""
213+
return list(self._run_multiple_poses(paths, indices=indices))
214+
215+
def _run_multiple_poses(
    self, paths: pd.Series, indices: Iterable[int] | None = None
) -> Generator[ResultTuple, None, None]:
    """Run all tests on the selected poses of one prediction file.

    Args:
        paths: Pandas series with keys "mol_pred", "mol_true", "mol_cond" containing paths to molecules.
        indices: Positions of the poses to process, in ascending order. If None, all poses are processed.

    Yields:
        ``(key, results)`` tuples, where ``key`` is ``(prediction path, molecule name, position)``
        and ``results`` is the list returned by ``_run_one_pose``. ``position`` is the pose's
        position in the prediction file, also when only a subset of ``indices`` is processed.
    """
    loading_params = self.config.get("loading", {})

    # Conditioning and reference molecules are loaded once and shared by all poses.
    mol_args = {}
    for role in ("mol_cond", "mol_true"):
        if role in paths and paths[role] is not None:
            mol_args[role] = safe_load_mol(path=paths[role], **loading_params.get(role, {}))

    # Materialise once: the indices are needed both for loading and for numbering.
    positions = None if indices is None else list(indices)
    supply = safe_supply_mols(paths["mol_pred"], indices=positions, **loading_params.get("mol_pred", {}))

    # Number each pose by its position in the file rather than by a counter that
    # restarts for every chunk; otherwise keys collide across chunks and top_n is
    # compared against the wrong value.
    # NOTE(review): assumes safe_supply_mols yields one entry per requested index,
    # in the order given — confirm against tools.loading.
    numbered = enumerate(supply) if positions is None else zip(positions, supply)

    top_n = self.config["top_n"]
    for position, mol_pred in numbered:
        if top_n is not None and position >= top_n:
            break
        mol_args["mol_pred"] = mol_pred

        key: ResultKey = (str(paths["mol_pred"]), self._get_name(mol_pred), position)
        results: ResultList = self._run_one_pose(mol_args)

        yield key, results
246+
247+
def _run_one_pose(self, molecules: dict[str, Any]) -> ResultList:
248+
"""Run all tests on a single pose."""
249+
results = []
250+
for name, fname, func, args in zip(self.module_name, self.fname, self.module_func, self.module_args):
251+
# pick needed arguments for module
252+
args_needed = {k: v for k, v in molecules.items() if k in args}
253+
254+
# loading takes all inputs
255+
if fname == "loading":
256+
args_needed = {k: args_needed.get(k, None) for k in args_needed}
257+
258+
# run module when all needed input molecules are valid Mol objects
259+
if fname != "loading" and not all(args_needed.get(m, None) for m in args_needed):
260+
module_output: dict[str, Any] = {"results": {}}
261+
else:
262+
module_output = func(**args_needed)
263+
264+
# save to object
265+
results.extend([(name, k, v) for k, v in module_output["results"].items()])
266+
# self.results[results_key]["details"].append(module_output["details"])
267+
268+
return results
182269

183270
def _initialize_modules(self) -> None:
184271
self.module_name = []
@@ -196,31 +283,39 @@ def _initialize_modules(self) -> None:
196283
self.module_args.append(module_args)
197284

198285
@staticmethod
199-
def _get_name(mol: Mol, i: int) -> str:
200-
if mol is None:
201-
return f"invalid_mol_at_pos_{i}"
286+
def _get_name(mol: Mol) -> str:
287+
"""Get the name of a molecule from the RDKit molecule object. Returns empty string if no name found."""
288+
if mol is None or not mol.HasProp("_Name"):
289+
return ""
290+
return mol.GetProp("_Name")
202291

203-
if not mol.HasProp("_Name") or mol.GetProp("_Name") == "":
204-
return f"mol_at_pos_{i}"
292+
def _collect_in_table(self, results_gen, full_report) -> pd.DataFrame:
293+
"""Collect generator results in a pandas dataframe."""
205294

206-
return mol.GetProp("_Name")
295+
df = pd.concat([self._make_table({k: v}, self.config, full_report=full_report) for k, v in results_gen])
296+
df.index.names = ["file", "molecule", "position"]
297+
df.columns = [c.lower().replace(" ", "_") for c in df.columns]
207298

299+
return df
300+
301+
@staticmethod
302+
def _make_table(results_dict: ResultDict, config, full_report: bool = False) -> pd.DataFrame:
303+
"""Generate a table from the output of the tests."""
208304

209-
def _dataframe_from_output(results_dict, config, full_report: bool = False) -> pd.DataFrame:
210-
d = {id: {(module, output): value for module, output, value in results} for id, results in results_dict.items()}
211-
df = pd.DataFrame.from_dict(d, orient="index")
305+
d = {id: {(module, output): value for module, output, value in results} for id, results in results_dict.items()}
306+
df = pd.DataFrame.from_dict(d, orient="index")
212307

213-
test_columns = [(c["name"], n) for c in config["modules"] for n in c.get("chosen_binary_test_output", [])]
214-
names_lookup = {(c["name"], k): v for c in config["modules"] for k, v in c.get("rename_outputs", {}).items()}
215-
suffix_lookup = {c["name"]: c["rename_suffix"] for c in config["modules"] if "rename_suffix" in c}
308+
test_columns = [(c["name"], n) for c in config["modules"] for n in c.get("chosen_binary_test_output", [])]
309+
names_lookup = {(c["name"], k): v for c in config["modules"] for k, v in c.get("rename_outputs", {}).items()}
310+
suffix_lookup = {c["name"]: c["rename_suffix"] for c in config["modules"] if "rename_suffix" in c}
216311

217-
available_columns = df.columns.tolist()
218-
missing_columns = [c for c in test_columns if c not in available_columns]
219-
extra_columns = [c for c in available_columns if c not in test_columns]
220-
columns = test_columns + extra_columns if full_report else test_columns
312+
available_columns = df.columns.tolist()
313+
missing_columns = [c for c in test_columns if c not in available_columns]
314+
extra_columns = [c for c in available_columns if c not in test_columns]
315+
columns = test_columns + extra_columns if full_report else test_columns
221316

222-
df[missing_columns] = pd.NA
223-
df = df[columns]
224-
df.columns = [names_lookup.get(c, c[-1] + suffix_lookup.get(c[0], "")) for c in df.columns]
317+
df[missing_columns] = pd.NA
318+
df = df[columns]
319+
df.columns = [names_lookup.get(c, c[-1] + suffix_lookup.get(c[0], "")) for c in df.columns]
225320

226-
return df
321+
return df

0 commit comments

Comments
 (0)