Use upstream NequipTorchSimModel #400
Conversation
@cw-tan Can you give this a look over and okay it?
r_max: float,
type_names: list[str],
device: torch.device | None = None,
neighbor_list_fn: Callable = torchsim_nl,
This is the main regression: by using the upstream NequIP implementation we lose the ability to vary the neighbor list (the neighbor_list_fn argument shown above).
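For illustration, a minimal sketch of the flexibility the old signature gave callers; the class below is a stand-in mirroring the parameters in the diff above, not the real wrapper or the upstream API.

from collections.abc import Callable

import torch


# Sketch only: a stand-in with the same parameters as the signature shown in
# the diff, to show what callers could previously vary.
class _OldStyleNequIPWrapper:
    def __init__(
        self,
        r_max: float,
        type_names: list[str],
        device: torch.device | None = None,
        neighbor_list_fn: Callable | None = None,  # previously defaulted to torchsim_nl
    ) -> None:
        self.r_max = r_max
        self.type_names = type_names
        self.device = device or torch.device("cpu")
        # Callers could inject any neighbor-list implementation here; the
        # upstream NequipTorchSimModel computes its own, so this knob is lost.
        self.neighbor_list_fn = neighbor_list_fn


model = _OldStyleNequIPWrapper(r_max=5.0, type_names=["Si"])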
| """Checkpoint download URLs for NequIP models.""" | ||
|
|
||
| Si = "https://github.com/abhijeetgangan/pt_model_checkpoints/raw/refs/heads/main/nequip/Si.nequip.pth" | ||
| # Cache directory for compiled models (under tests/ for easy cleanup) |
Introduced just to speed up repeat testing locally. NequIP tests will take a long time in CI because of the compile step; we should consider whether we can cache compiled models in CI as well.
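A hedged sketch of what such a local cache can look like as a session-scoped pytest fixture; the directory, file names, and the _compile_checkpoint helper are placeholders, not the code added in this PR.

from pathlib import Path

import pytest

# Cache directory for compiled models (under tests/ for easy cleanup).
CACHE_DIR = Path(__file__).parent / ".nequip_compile_cache"


def _compile_checkpoint(target: Path) -> None:
    """Placeholder for the expensive NequIP compile step."""
    target.write_bytes(b"")  # stand-in; the real step would write the compiled model


@pytest.fixture(scope="session")
def compiled_si_model() -> Path:
    """Compile once per machine; repeat local runs hit the cache."""
    CACHE_DIR.mkdir(exist_ok=True)
    compiled = CACHE_DIR / "Si.nequip.compiled.pt"
    if not compiled.exists():  # miss on a fresh CI runner unless CI also caches this dir
        _compile_checkpoint(compiled)
    return compiled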
assert model._device == DEVICE  # noqa: SLF001

# NOTE: we take [:-1] to skip benzene. This is because the stress calculation in NequIP
The benzene test failed on the GitHub CI runner. A similar numerical issue existed for SevenNet when PBC was turned off; stress is not meaningful when PBC is off.
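For reference, a minimal sketch of the filtering the NOTE above describes; the system list and test body are illustrative placeholders.

import pytest

# Hypothetical parametrization list; only the ordering (benzene last) matters here.
all_test_systems = ["Si", "SiO2", "benzene"]

# Take [:-1] to drop benzene: it is non-periodic, and stress is not meaningful without PBC.
periodic_systems = all_test_systems[:-1]


@pytest.mark.parametrize("system_name", periodic_systems)
def test_stress(system_name: str) -> None:
    ...  # stress comparison against reference values would go here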
Signed-off-by: Rhys Goodall <[email protected]>
Use the TorchSim calculator from https://github.com/mir-group/nequip rather than duplicating code. It is not clear how to get the dtype from AOTInductor; upstream, to match the contract, they simply set the dtype to float64, so a warning is added for end users. Also implements a basic cache for the compiled models to speed up repeat tests locally.
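A hedged sketch of the kind of end-user warning described above; the function name and message text are illustrative, not the wording added in this PR.

import warnings

import torch


def resolve_compiled_dtype() -> torch.dtype:
    # AOTInductor-compiled models do not expose their dtype, so fall back to
    # float64 (matching the upstream contract) and tell the user about the assumption.
    warnings.warn(
        "Could not determine dtype from the AOTInductor-compiled NequIP model; "
        "assuming torch.float64.",
        UserWarning,
        stacklevel=2,
    )
    return torch.float64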