@@ -100,8 +100,8 @@ def robust_rmsd( # noqa: PLR0913
100100 RemoveStereochemistry (mol_ref )
101101
102102 if heavy_only :
103- mol_probe = RemoveHs (mol_probe )
104- mol_ref = RemoveHs (mol_ref )
103+ mol_probe = RemoveHs (mol_probe , sanitize = False )
104+ mol_ref = RemoveHs (mol_ref , sanitize = False )
105105
106106 # combine parameters
107107 params = dict (symmetrizeConjugatedTerminalGroups = symmetrizeConjugatedTerminalGroups , kabsch = kabsch , ** params )
@@ -186,11 +186,15 @@ def intercentroid(
186186 mol_probe : Mol , mol_ref : Mol , conf_id_probe : int = - 1 , conf_id_ref : int = - 1 , heavy_only : bool = True
187187) -> float :
188188 """Distance between centroids of two molecules."""
189- if heavy_only :
190- mol_probe = RemoveHs (mol_probe )
191- mol_ref = RemoveHs (mol_ref )
192-
193- centroid_probe = mol_probe .GetConformer (conf_id_probe ).GetPositions ().mean (axis = 0 )
194- centroid_ref = mol_ref .GetConformer (conf_id_ref ).GetPositions ().mean (axis = 0 )
195189
190+ centroid_probe = get_centroid (mol_probe , heavy_only , conf_id_probe )
191+ centroid_ref = get_centroid (mol_ref , heavy_only , conf_id_ref )
196192 return float (np .linalg .norm (centroid_probe - centroid_ref ))
193+
194+
195+ def get_centroid (mol : Mol , heavy_only : bool = True , conf_id : int = - 1 ) -> np .ndarray :
196+ """Get centroid of molecule."""
197+ pos = mol .GetConformer (conf_id ).GetPositions ()
198+ if heavy_only :
199+ pos = pos [[atom .GetAtomicNum () != 1 for atom in mol .GetAtoms ()], :]
200+ return pos .mean (axis = 0 )
0 commit comments