Skip to content

Commit

Permalink
Allow more pass through arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Dec 21, 2023
1 parent f2107d9 commit 9de93e9
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 12 deletions.
2 changes: 0 additions & 2 deletions docs/modules/cutoff.rst

This file was deleted.

2 changes: 2 additions & 0 deletions docs/modules/defaults.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. automodule:: tad_multicharge.defaults
:members:
2 changes: 1 addition & 1 deletion docs/modules/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ The following modules are contained with `tad_multicharge`.
.. toctree::

param/index
cutoff
defaults
eeq
model
13 changes: 10 additions & 3 deletions src/tad_multicharge/cutoff.py → src/tad_multicharge/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see <https://www.gnu.org/licenses/>.
"""
Cutoff
======
Defaults
========
Real-space cutoffs for the coordination number within the EEQ Model.
Default global parameters of the charge models.
- EEQ: real-space cutoffs for the coordination number
- EEQ: Steepness of CN counting function
"""

EEQ_CN_CUTOFF = 25.0
"""Coordination number cutoff within EEQ (25.0)."""

EEQ_CN_MAX = 8.0
"""Maximum coordination number (8.0)."""

EEQ_KCN = 7.5
"""Steepness of counting function in EEQ model (7.5)."""
36 changes: 30 additions & 6 deletions src/tad_multicharge/eeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@
import torch
from tad_mctc import storch
from tad_mctc.batch import real_atoms, real_pairs
from tad_mctc.ncoord import cn_eeq
from tad_mctc.typing import DD, Tensor
from tad_mctc.ncoord import cn_eeq, erf_count
from tad_mctc.typing import DD, Any, CountingFunction, Tensor

from . import defaults
from .model import ChargeModel
from .param import eeq2019

Expand Down Expand Up @@ -240,7 +241,13 @@ def get_eeq(
numbers: Tensor,
positions: Tensor,
chrg: Tensor,
cutoff: Tensor | None = None,
*,
counting_function: CountingFunction = erf_count,
rcov: Tensor | None = None,
cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
kcn: Tensor | float | int = defaults.EEQ_KCN,
**kwargs: Any,
) -> tuple[Tensor, Tensor]:
"""
Calculate atomic EEQ charges and energies.
Expand All @@ -253,16 +260,33 @@ def get_eeq(
Cartesian coordinates of the atoms in the system (batch, natoms, 3).
chrg : Tensor
Total charge of system.
cutoff : Tensor | None, optional
Real-space cutoff. Defaults to `None`.
counting_function : CountingFunction
Calculate weight for pairs. Defaults to `erf_count`.
rcov : Tensor | None, optional
Covalent radii for each species. Defaults to `None`.
cutoff : Tensor | float | int | None, optional
Real-space cutoff. Defaults to `defaults.CUTOFF_EEQ`.
cn_max : Tensor | float | int | None, optional
Maximum coordination number. Defaults to `defaults.CUTOFF_EEQ_MAX`.
kcn : Tensor | float | int, optional
Steepness of the counting function.
Returns
-------
(Tensor, Tensor)
Tuple of electrostatic energies and partial charges.
"""
eeq = EEQModel.param2019(device=positions.device, dtype=positions.dtype)
cn = cn_eeq(numbers, positions, cutoff=cutoff)
cn = cn_eeq(
numbers,
positions,
counting_function=counting_function,
rcov=rcov,
cutoff=cutoff,
cn_max=cn_max,
kcn=kcn,
**kwargs,
)
return solve(numbers, positions, chrg, eeq, cn)


Expand Down

0 comments on commit 9de93e9

Please sign in to comment.