diff --git a/README.md b/README.md index 196ee40..9281cb2 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,9 @@ or if you have a `uv.lock` file: uv sync --extra dev --editable ``` +**GPU-Accelerated Conformer Generation** +GPU acceleration for conformer generation is optionally available via `nvMolKit`. Installation instructions can be found [here](https://github.com/NVIDIA-Digital-Bio/nvMolKit). **Note:** GPU acceleration has not been applied to the MMFF94 geometry optimisation because it is not currently compatible with the way we have defined our convergence criteria. + ## The Protocol The protocol used in StrainRelief is designed to be simple, fast and model agnostic - all that is needed to apply a new force field is to write an ASE calculator wrapper. Additionally you can use any MACE model, such as these from the [MACE-OFF23](https://github.com/ACEsuit/mace-off/tree/main/mace_off23) repository. @@ -76,7 +79,7 @@ StrainRelief runs are configured using hydra configs. ``` from strain_relief import compute_strain -strains = compute_strain(poses: list[RDKit.Mol], config: DictConfig) +computed = compute_strain(poses: list[RDKit.Mol], config: DictConfig) for i, r in computed.iterrows(): print(f"Pose {r['id']} has a strain of {r['ligand_strain']:.2f} kcal/mol") @@ -99,7 +102,7 @@ More examples are given [here](./examples/examples.sh), including the command us **RDKit kwargs** The following dictionaries are passed directly to the function of that name. -- `conformers` (`EmbedMultipleConfs`) +- `conformers.EmbedMultipleConfs` - `minimisation.MMFFGetMoleculeProperties` - `minimisation.MMFFGetMoleculeForceField` - `energy_eval.MMFFGetMoleculeProperties` @@ -108,8 +111,9 @@ The following dictionaries are passed directly to the function of that name. The hydra config is set up to allow additional kwargs to be passed to these functions e.g. `+minimisation.MMFFGetMoleculeProperties.mmffVerbosity=1`. 
**Common kwargs** -- `threshold` (set by default to 16.1 kcal/mol - calibrated using [LigBoundConf 2.0](https://huggingface.co/datasets/erwallace/LigBoundConf2.0)) -- `conformers.numConfs` +- `threshold`: set by default to 16.1 kcal/mol - calibrated using [LigBoundConf 2.0](https://huggingface.co/datasets/erwallace/LigBoundConf2.0) +- `conformers.EmbedMultipleConfs.numConfs` +- `conformers.device`: determines whether to use `RDKit` or GPU-accelerated `nvMolKit` - `global_min.maxIters`/`local_min.maxIters` - `global_min.fmax`/`local_min.maxIters` - `io.input.include_charged` diff --git a/src/strain_relief/conformers/_rdkit_generation.py b/src/strain_relief/conformers/_rdkit_generation.py index da21b4f..703d786 100644 --- a/src/strain_relief/conformers/_rdkit_generation.py +++ b/src/strain_relief/conformers/_rdkit_generation.py @@ -1,5 +1,6 @@ from collections import Counter from timeit import default_timer as timer +from typing import Literal import numpy as np from loguru import logger as logging @@ -11,13 +12,8 @@ def generate_conformers( mols: MolsDict, - randomSeed: int = -1, - numConfs: int = 10, - maxAttempts: int = 200, - pruneRmsThresh: float = 0.1, - clearConfs: bool = False, - numThreads: int = 0, - **kwargs, + EmbedMultipleConfs: dict, + device: Literal["cpu", "cuda"], ) -> MolsDict: """Generate conformers for a molecule. The 0th conformer is the original molecule. @@ -28,18 +24,11 @@ def generate_conformers( ---------- mols : MolsDict Nested dictionary of molecules for which to generate conformers. - randomSeed : int, optional - The random seed to use. The default is -1. - numConfs : int, optional - The number of conformers to generate. The default is 100. - maxAttempts : int, optional - The maximum number of attempts to try embedding. The default is 1000. - pruneRmsThresh : float, optional - The RMS threshold to prune conformers. The default is 0.1. - numThreads : int, optional - The number of threads to use while embedding. 
This only has an effect if the - RDKit was built with multi-thread support. If set to zero, the max supported - by the system will be used. The default is 0. + EmbedMultipleConfs : dict + Additional keyword arguments to pass to the EmbedMultipleConfs function. + For example: `numConfs`, `maxAttempts`, `pruneRmsThresh` and `randomSeed`. + device : Literal["cpu", "cuda"] + Device to run the conformer generation on (determines whether to use RDKit or nvMolKit). Returns ------- @@ -57,27 +46,26 @@ def generate_conformers( logging.info("Generating conformers...") + # Add bonds if missing for id, mol_properties in mols.items(): mol = mol_properties[MOL_KEY] charge = mol_properties[CHARGE_KEY] if mol.GetNumBonds() == 0: logging.debug(f"Adding bonds to {id}") rdDetermineBonds.DetermineBonds(mol, charge=charge) - AllChem.EmbedMultipleConfs( - mol, - randomSeed=randomSeed, - numConfs=numConfs, - maxAttempts=maxAttempts, - pruneRmsThresh=pruneRmsThresh, - clearConfs=clearConfs, - numThreads=numThreads, - **kwargs, - ) - logging.debug(f"{mol.GetNumConformers()} conformers generated for {id}") + + # Generate conformers + if device == "cuda": + _generate_conformers_cuda(mols, **EmbedMultipleConfs) + elif device == "cpu": + _generate_conformers_cpu(mols, **EmbedMultipleConfs) + else: + raise ValueError(f"Unknown device: {device}") n_conformers = np.array( [mol_properties[MOL_KEY].GetNumConformers() for mol_properties in mols.values()] ) + numConfs = EmbedMultipleConfs["numConfs"] if "numConfs" in EmbedMultipleConfs else 10 logging.info( f"{np.sum(n_conformers == numConfs + 1)} molecules with {numConfs + 1} conformers each" ) @@ -90,3 +78,36 @@ def generate_conformers( logging.info(f"Conformer generation took {end - start:.2f} seconds. 
\n") return mols + + +def _generate_conformers_cuda(mols, **kwargs): + """nvMolKit based conformer generation on GPU.""" + logging.info("Generating conformers with GPU enabled nvMolKit...") + try: + from nvmolkit.embedMolecules import EmbedMolecules as nvMolKitEmbed + except ImportError: + raise ImportError( + "nvMolKit is required for GPU based conformer generation. " + "Install from https://github.com/NVIDIA-Digital-Bio/nvMolKit " + "or set cfg.conformers.device = 'cpu' to use RDKit conformer generation." + ) + + mol_list = [mol_properties[MOL_KEY] for mol_properties in mols.values()] + nvMolKitEmbed(mol_list, **kwargs) + for i, id in enumerate(mols.keys()): + mols[id][MOL_KEY] = mol_list[i] + logging.debug(f"{mols[id][MOL_KEY].GetNumConformers()} conformers generated for {id}") + return mols + + +def _generate_conformers_cpu(mols, **kwargs): + """RDKit based conformer generation on CPU.""" + logging.info("Generating conformers with CPU enabled RDKit...") + for id, mol_properties in mols.items(): + mol = mol_properties[MOL_KEY] + AllChem.EmbedMultipleConfs( + mol, + **kwargs, + ) + logging.debug(f"{mol.GetNumConformers()} conformers generated for {id}") + return mols diff --git a/src/strain_relief/hydra_config/conformers/default.yaml b/src/strain_relief/hydra_config/conformers/default.yaml index a2df752..1ed84dd 100644 --- a/src/strain_relief/hydra_config/conformers/default.yaml +++ b/src/strain_relief/hydra_config/conformers/default.yaml @@ -1,6 +1,8 @@ -randomSeed: ${seed} -numConfs: 20 -maxAttempts: 10 -pruneRmsThresh: 0.1 -clearConfs: False -numThreads: ${numThreads} +device: ${device} +EmbedMultipleConfs: + randomSeed: ${seed} + numConfs: 20 + maxAttempts: 10 + pruneRmsThresh: 0.1 + clearConfs: False + numThreads: ${numThreads} diff --git a/src/strain_relief/hydra_config/default.yaml b/src/strain_relief/hydra_config/default.yaml index a3206e3..1f77b97 100644 --- a/src/strain_relief/hydra_config/default.yaml +++ b/src/strain_relief/hydra_config/default.yaml 
@@ -11,3 +11,4 @@ defaults: seed: -1 threshold: 16.1 numThreads: 0 +device: cuda diff --git a/src/strain_relief/hydra_config/experiment/mmff94s.yaml b/src/strain_relief/hydra_config/experiment/mmff94s.yaml index 14f178f..fb0b5bc 100644 --- a/src/strain_relief/hydra_config/experiment/mmff94s.yaml +++ b/src/strain_relief/hydra_config/experiment/mmff94s.yaml @@ -4,6 +4,8 @@ defaults: - override /minimisation@local_min: mmff94s - override /minimisation@global_min: mmff94s +device: cpu + conformers: numConfs: 20 diff --git a/src/strain_relief/hydra_config/model/fairchem.yaml b/src/strain_relief/hydra_config/model/fairchem.yaml index e5d7b52..2a98230 100644 --- a/src/strain_relief/hydra_config/model/fairchem.yaml +++ b/src/strain_relief/hydra_config/model/fairchem.yaml @@ -3,4 +3,4 @@ energy_units: eV calculator_kwargs: model_paths: ${..model_paths} default_dtype: float32 - device: cuda + device: ${device} diff --git a/src/strain_relief/hydra_config/model/mace.yaml b/src/strain_relief/hydra_config/model/mace.yaml index e5d7b52..2a98230 100644 --- a/src/strain_relief/hydra_config/model/mace.yaml +++ b/src/strain_relief/hydra_config/model/mace.yaml @@ -3,4 +3,4 @@ energy_units: eV calculator_kwargs: model_paths: ${..model_paths} default_dtype: float32 - device: cuda + device: ${device}