44
55import inspect
66import logging
7- from collections import defaultdict
87from collections .abc import Generator , Iterable
8+ from concurrent .futures import ProcessPoolExecutor , as_completed
9+ from concurrent .futures .process import BrokenProcessPool
910from functools import partial
11+ from math import ceil
1012from pathlib import Path
1113from typing import Any , Callable
1214
2931)
3032from .modules .sucos import check_sucos
3133from .modules .volume_overlap import check_volume_overlap
32- from .tools .loading import safe_load_mol , safe_supply_mols
34+ from .tools .loading import get_num_mols , safe_load_mol , safe_supply_mols
3335
3436logger = logging .getLogger (__name__ )
3537
5153}
5254molecule_args = {"mol_cond" , "mol_true" , "mol_pred" }
5355
56+ ResultKey = tuple [str , str , int ]
57+ ResultList = list [tuple [str , str , Any ]]
58+ ResultTuple = tuple [ResultKey , ResultList ]
59+ ResultDict = dict [ResultKey , ResultList ]
60+
5461
5562class PoseBusters :
5663 """Class to run all tests on a set of molecules."""
@@ -61,8 +68,23 @@ class PoseBusters:
6168 module_args : list
6269 fname : list
6370
64- def __init__ (self , config : str | dict [str , Any ] = "redock" , top_n : int | None = None ):
65- """Initialize PoseBusters object."""
71+ def __init__ (
72+ self ,
73+ config : str | dict [str , Any ] = "redock" ,
74+ top_n : int | None = None ,
75+ max_workers : int | None = None ,
76+ chunk_size : int | None = 100 ,
77+ ) -> None :
78+ """Initialize PoseBusters object.
79+
80+ Args:
81+ config: Configuration file or dictionary. If a string, it should be one of "dock", "redock", "mol", "gen".
82+ top_n: Number of poses to process. If None, all poses are processed.
83+ max_workers: Maximum number of workers for parallelization. If None, all available cores are used. If 0 or
84+ negative, no parallelization is used.
85+ chunk_size: Number of poses to process per process if parallelization is used. If None, parallelization over
86+ files only.
87+ """
6688 self .module_func : list # dict[str, Callable]
6789 self .module_args : list # dict[str, set[str]]
6890
@@ -78,8 +100,8 @@ def __init__(self, config: str | dict[str, Any] = "redock", top_n: int | None =
78100 assert len (set (self .config .get ("tests" , {}).keys ()) - set (module_dict .keys ())) == 0
79101
80102 self .config ["top_n" ] = self .config .get ("top_n" , top_n )
81-
82- self .results : dict [ tuple [ str , str ], list [ tuple [ str , str , Any ]]] = defaultdict ( list )
103+ self . config [ "max_workers" ] = self . config . get ( "max_workers" , max_workers )
104+ self .config [ "chunk_size" ] = self . config . get ( "chunk_size" , chunk_size )
83105
84106 def bust (
85107 self ,
@@ -106,14 +128,9 @@ def bust(
106128
107129 columns = ["mol_pred" , "mol_true" , "mol_cond" ]
108130 self .file_paths = pd .DataFrame ([[mol_pred , mol_true , mol_cond ] for mol_pred in mol_pred_list ], columns = columns )
109-
110- results_gen = self ._run ()
111-
112- df = pd .concat ([_dataframe_from_output (d , self .config , full_report = full_report ) for d in results_gen ])
113- df .index .names = ["file" , "molecule" ]
114- df .columns = [c .lower ().replace (" " , "_" ) for c in df .columns ]
115-
116- return df
131+ generator = self ._run ()
132+ results = self ._collect_in_table (generator , full_report = full_report )
133+ return results
117134
118135 def bust_table (self , mol_table : pd .DataFrame , full_report : bool = False ) -> pd .DataFrame :
119136 """Run tests on molecules provided in pandas dataframe as paths or rdkit molecule objects.
@@ -126,59 +143,129 @@ def bust_table(self, mol_table: pd.DataFrame, full_report: bool = False) -> pd.D
126143 Pandas dataframe with results.
127144 """
128145 self .file_paths = mol_table
146+ generator = self ._run ()
147+ results = self ._collect_in_table (generator , full_report = full_report )
148+ return results
129149
130- results_gen = self ._run ()
131-
132- df = pd .concat ([_dataframe_from_output (d , self .config , full_report = full_report ) for d in results_gen ])
133- df .index .names = ["file" , "molecule" ]
134- df .columns = [c .lower ().replace (" " , "_" ) for c in df .columns ]
135-
136- return df
137-
138- def _run (self ) -> Generator [dict , None , None ]:
150+ def _run (self ) -> Generator [ResultTuple , None , None ]:
139151 """Run all tests on molecules provided in file paths.
140152
141153 Yields:
142154 Generator of result dictionaries.
143155 """
144156 self ._initialize_modules ()
157+ max_workers = self .config .get ("max_workers" , None )
158+ chunk_size = self .config .get ("chunk_size" , 100 )
159+ if max_workers is not None and max_workers <= 0 :
160+ yield from self ._run_single_thread ()
161+ elif chunk_size is None :
162+ yield from self ._run_parallel_over_files (max_workers = max_workers )
163+ else :
164+ yield from self ._run_parallel_over_poses (max_workers = max_workers , chunk_size = chunk_size )
145165
166+ def _run_single_thread (self ) -> Generator [ResultTuple , None , None ]:
146167 for _ , paths in self .file_paths .iterrows ():
147- mol_args = {}
148- if "mol_cond" in paths and paths ["mol_cond" ] is not None :
149- mol_cond_load_params = self .config .get ("loading" , {}).get ("mol_cond" , {})
150- mol_args ["mol_cond" ] = safe_load_mol (path = paths ["mol_cond" ], ** mol_cond_load_params )
151- if "mol_true" in paths and paths ["mol_true" ] is not None :
152- mol_true_load_params = self .config .get ("loading" , {}).get ("mol_true" , {})
153- mol_args ["mol_true" ] = safe_load_mol (path = paths ["mol_true" ], ** mol_true_load_params )
154-
155- mol_pred_load_params = self .config .get ("loading" , {}).get ("mol_pred" , {})
156- for i , mol_pred in enumerate (safe_supply_mols (paths ["mol_pred" ], ** mol_pred_load_params )):
157- if self .config ["top_n" ] is not None and i >= self .config ["top_n" ]:
158- break
159-
160- mol_args ["mol_pred" ] = mol_pred
161-
162- results_key = (str (paths ["mol_pred" ]), self ._get_name (mol_pred , i ))
163-
164- for name , fname , func , args in zip (self .module_name , self .fname , self .module_func , self .module_args ):
165- # pick needed arguments for module
166- args_needed = {k : v for k , v in mol_args .items () if k in args }
167- # loading takes all inputs
168- if fname == "loading" :
169- args_needed = {k : args_needed .get (k , None ) for k in args_needed }
170- # run module when all needed input molecules are valid Mol objects
171- if fname != "loading" and not all (args_needed .get (m , None ) for m in args_needed ):
172- module_output : dict [str , Any ] = {"results" : {}}
173- else :
174- module_output = func (** args_needed )
175-
176- # save to object
177- self .results [results_key ].extend ([(name , k , v ) for k , v in module_output ["results" ].items ()])
178- # self.results[results_key]["details"].append(module_output["details"])
179-
180- # return results for this entry
181- yield {results_key : self .results [results_key ]}
168+ yield from self ._run_multiple_poses (paths )
169+
170+ def _run_parallel_over_files (
171+ self , timeout : int | None = None , max_workers : int | None = None
172+ ) -> Generator [ResultTuple , None , None ]:
173+ with ProcessPoolExecutor (max_workers = max_workers ) as executor :
174+ futures = [executor .submit (self ._run_and_combine , paths ) for _ , paths in self .file_paths .iterrows ()]
175+ for future in as_completed (futures , timeout = None ):
176+ try :
177+ results = future .result (timeout = timeout )
178+ except BrokenProcessPool as exception :
179+ # logger.critical("BrokenProcessPool: %s", exception)
180+ raise exception
181+ except Exception as exception :
182+ # logger.critical("Error in process: %s", exception)
183+ raise exception
184+
185+ yield from results
186+
187+ def _run_parallel_over_poses (
188+ self , timeout : int | None = None , max_workers : int | None = None , chunk_size : int = 100
189+ ) -> Generator [ResultTuple , None , None ]:
190+ with ProcessPoolExecutor (max_workers = max_workers ) as executor :
191+ futures = []
192+ for _ , paths in self .file_paths .iterrows ():
193+ num_mols_pred = get_num_mols (paths ["mol_pred" ])
194+ for chunk in range (ceil (num_mols_pred / chunk_size )):
195+ indices = range (chunk * chunk_size , min ((chunk + 1 ) * chunk_size , num_mols_pred ))
196+ future = executor .submit (self ._run_and_combine , paths = paths , indices = indices )
197+ futures .append (future )
198+
199+ for future in as_completed (futures , timeout = None ):
200+ try :
201+ results = future .result (timeout = timeout )
202+ except BrokenProcessPool as exception :
203+ # logger.critical("BrokenProcessPool: %s", exception)
204+ raise exception
205+ except Exception as exception :
206+ # logger.critical("Error in process: %s", exception)
207+ raise exception
208+
209+ yield from results
210+
211+ def _run_and_combine (self , paths : pd .Series , indices : Iterable [int ] | None = None ) -> list [ResultTuple ]:
212+ """Run and collect all tests for all poses in the prediction file."""
213+ return list (self ._run_multiple_poses (paths , indices = indices ))
214+
215+ def _run_multiple_poses (
216+ self , paths : pd .Series , indices : Iterable [int ] | None = None
217+ ) -> Generator [ResultTuple , None , None ]:
218+ """Run all tests on indexed poses in the prediction file.
219+
220+ Args:
221+ paths: Pandas series with keys "mol_pred", "mol_true", "mol_cond" containing paths to molecules.
222+ indices: Indices of poses to process. If None, all poses are processed.
223+
224+ Yields:
225+ Generator of result dictionaries.
226+ """
227+
228+ mol_args = {}
229+ if "mol_cond" in paths and paths ["mol_cond" ] is not None :
230+ mol_cond_load_params = self .config .get ("loading" , {}).get ("mol_cond" , {})
231+ mol_args ["mol_cond" ] = safe_load_mol (path = paths ["mol_cond" ], ** mol_cond_load_params )
232+ if "mol_true" in paths and paths ["mol_true" ] is not None :
233+ mol_true_load_params = self .config .get ("loading" , {}).get ("mol_true" , {})
234+ mol_args ["mol_true" ] = safe_load_mol (path = paths ["mol_true" ], ** mol_true_load_params )
235+
236+ mol_pred_load_params = self .config .get ("loading" , {}).get ("mol_pred" , {})
237+ for i , mol_pred in enumerate (safe_supply_mols (paths ["mol_pred" ], indices = indices , ** mol_pred_load_params )):
238+ if self .config ["top_n" ] is not None and i >= self .config ["top_n" ]:
239+ break
240+ mol_args ["mol_pred" ] = mol_pred
241+
242+ key : ResultKey = (str (paths ["mol_pred" ]), self ._get_name (mol_pred ), i )
243+ results : ResultList = self ._run_one_pose (mol_args )
244+
245+ yield key , results
246+
247+ def _run_one_pose (self , molecules : dict [str , Any ]) -> ResultList :
248+ """Run all tests on a single pose."""
249+ results = []
250+ for name , fname , func , args in zip (self .module_name , self .fname , self .module_func , self .module_args ):
251+ # pick needed arguments for module
252+ args_needed = {k : v for k , v in molecules .items () if k in args }
253+
254+ # loading takes all inputs
255+ if fname == "loading" :
256+ args_needed = {k : args_needed .get (k , None ) for k in args_needed }
257+
258+ # run module when all needed input molecules are valid Mol objects
259+ if fname != "loading" and not all (args_needed .get (m , None ) for m in args_needed ):
260+ module_output : dict [str , Any ] = {"results" : {}}
261+ else :
262+ module_output = func (** args_needed )
263+
264+ # save to object
265+ results .extend ([(name , k , v ) for k , v in module_output ["results" ].items ()])
266+ # self.results[results_key]["details"].append(module_output["details"])
267+
268+ return results
182269
183270 def _initialize_modules (self ) -> None :
184271 self .module_name = []
@@ -196,31 +283,39 @@ def _initialize_modules(self) -> None:
196283 self .module_args .append (module_args )
197284
198285 @staticmethod
199- def _get_name (mol : Mol , i : int ) -> str :
200- if mol is None :
201- return f"invalid_mol_at_pos_{ i } "
286+ def _get_name (mol : Mol ) -> str :
287+ """Get the name of a molecule from the RDKit molecule object. Returns empty string if no name found."""
288+ if mol is None or not mol .HasProp ("_Name" ):
289+ return ""
290+ return mol .GetProp ("_Name" )
202291
203- if not mol . HasProp ( "_Name" ) or mol . GetProp ( "_Name" ) == "" :
204- return f"mol_at_pos_ { i } "
292+ def _collect_in_table ( self , results_gen , full_report ) -> pd . DataFrame :
293+ """Collect generator results in a pandas dataframe."" "
205294
206- return mol .GetProp ("_Name" )
295+ df = pd .concat ([self ._make_table ({k : v }, self .config , full_report = full_report ) for k , v in results_gen ])
296+ df .index .names = ["file" , "molecule" , "position" ]
297+ df .columns = [c .lower ().replace (" " , "_" ) for c in df .columns ]
207298
299+ return df
300+
301+ @staticmethod
302+ def _make_table (results_dict : ResultDict , config , full_report : bool = False ) -> pd .DataFrame :
303+ """Generate a table from the output of the tests."""
208304
209- def _dataframe_from_output (results_dict , config , full_report : bool = False ) -> pd .DataFrame :
210- d = {id : {(module , output ): value for module , output , value in results } for id , results in results_dict .items ()}
211- df = pd .DataFrame .from_dict (d , orient = "index" )
305+ d = {id : {(module , output ): value for module , output , value in results } for id , results in results_dict .items ()}
306+ df = pd .DataFrame .from_dict (d , orient = "index" )
212307
213- test_columns = [(c ["name" ], n ) for c in config ["modules" ] for n in c .get ("chosen_binary_test_output" , [])]
214- names_lookup = {(c ["name" ], k ): v for c in config ["modules" ] for k , v in c .get ("rename_outputs" , {}).items ()}
215- suffix_lookup = {c ["name" ]: c ["rename_suffix" ] for c in config ["modules" ] if "rename_suffix" in c }
308+ test_columns = [(c ["name" ], n ) for c in config ["modules" ] for n in c .get ("chosen_binary_test_output" , [])]
309+ names_lookup = {(c ["name" ], k ): v for c in config ["modules" ] for k , v in c .get ("rename_outputs" , {}).items ()}
310+ suffix_lookup = {c ["name" ]: c ["rename_suffix" ] for c in config ["modules" ] if "rename_suffix" in c }
216311
217- available_columns = df .columns .tolist ()
218- missing_columns = [c for c in test_columns if c not in available_columns ]
219- extra_columns = [c for c in available_columns if c not in test_columns ]
220- columns = test_columns + extra_columns if full_report else test_columns
312+ available_columns = df .columns .tolist ()
313+ missing_columns = [c for c in test_columns if c not in available_columns ]
314+ extra_columns = [c for c in available_columns if c not in test_columns ]
315+ columns = test_columns + extra_columns if full_report else test_columns
221316
222- df [missing_columns ] = pd .NA
223- df = df [columns ]
224- df .columns = [names_lookup .get (c , c [- 1 ] + suffix_lookup .get (c [0 ], "" )) for c in df .columns ]
317+ df [missing_columns ] = pd .NA
318+ df = df [columns ]
319+ df .columns = [names_lookup .get (c , c [- 1 ] + suffix_lookup .get (c [0 ], "" )) for c in df .columns ]
225320
226- return df
321+ return df
0 commit comments